diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnectionFactory.java index e0ba84ab895..021eef155ee 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnectionFactory.java @@ -20,6 +20,7 @@ package org.eclipse.jetty.server; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; import org.eclipse.jetty.io.AbstractConnection; @@ -76,6 +77,26 @@ public abstract class AbstractConnectionFactory extends ContainerLifeCycle imple _inputbufferSize = size; } + protected String findNextProtocol(Connector connector) + { + return findNextProtocol(connector, getProtocol()); + } + + protected static String findNextProtocol(Connector connector, String currentProtocol) + { + String nextProtocol = null; + for (Iterator it = connector.getProtocols().iterator(); it.hasNext(); ) + { + String protocol = it.next(); + if (currentProtocol.equalsIgnoreCase(protocol)) + { + nextProtocol = it.hasNext() ? it.next() : null; + break; + } + } + return nextProtocol; + } + protected AbstractConnection configure(AbstractConnection connection, Connector connector, EndPoint endPoint) { connection.setInputBufferSize(getInputBufferSize()); diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnector.java b/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnector.java index 51624b0fceb..e9dfdcf08ba 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnector.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/AbstractConnector.java @@ -34,6 +34,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.locks.Condition; +import java.util.stream.Collectors; import org.eclipse.jetty.io.ArrayByteBufferPool; import org.eclipse.jetty.io.ByteBufferPool; @@ -798,9 +799,9 @@ public abstract class AbstractConnector extends ContainerLifeCycle implements Co @Override public String toString() { - return String.format("%s@%x{%s,%s}", + return String.format("%s@%x{%s, %s}", _name == null ? getClass().getSimpleName() : _name, hashCode(), - getDefaultProtocol(), getProtocols()); + getDefaultProtocol(), getProtocols().stream().collect(Collectors.joining(", ", "(", ")"))); } } diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/ConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/ConnectionFactory.java index 04d103ac043..ebba5f00a6c 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/ConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/ConnectionFactory.java @@ -18,6 +18,7 @@ package org.eclipse.jetty.server; +import java.nio.ByteBuffer; import java.util.List; import org.eclipse.jetty.http.BadMessageException; @@ -85,4 +86,43 @@ public interface ConnectionFactory */ public Connection upgradeConnection(Connector connector, EndPoint endPoint, MetaData.Request upgradeRequest, HttpFields responseFields) throws BadMessageException; } + + /** + *

Connections created by this factory MUST implement {@link Connection.UpgradeTo}.

+ */ + interface Detecting extends ConnectionFactory + { + /** + * The possible outcomes of the {@link #detect(ByteBuffer)} method. + */ + enum Detection + { + /** + * A {@link Detecting} can work with the given bytes. + */ + RECOGNIZED, + /** + * A {@link Detecting} cannot work with the given bytes. + */ + NOT_RECOGNIZED, + /** + * A {@link Detecting} requires more bytes to make a decision. + */ + NEED_MORE_BYTES + } + + /** + *

Check the bytes in the given {@code buffer} to figure out if this {@link Detecting} instance + * can work with them or not.

+ *

The {@code buffer} MUST be left untouched by this method: bytes MUST NOT be consumed and MUST NOT be modified.

+ * @param buffer the buffer. + * @return One of: + * + */ + Detection detect(ByteBuffer buffer); + } } diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/DetectorConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/DetectorConnectionFactory.java new file mode 100644 index 00000000000..74867150ffe --- /dev/null +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/DetectorConnectionFactory.java @@ -0,0 +1,301 @@ +// +// ======================================================================== +// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under +// the terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0 +// +// This Source Code may also be made available under the following +// Secondary Licenses when the conditions for such availability set +// forth in the Eclipse Public License, v. 2.0 are satisfied: +// the Apache License v2.0 which is available at +// https://www.apache.org/licenses/LICENSE-2.0 +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.stream.Collectors; + +import org.eclipse.jetty.io.AbstractConnection; +import org.eclipse.jetty.io.Connection; +import org.eclipse.jetty.io.EndPoint; +import org.eclipse.jetty.util.BufferUtil; +import org.eclipse.jetty.util.log.Log; +import org.eclipse.jetty.util.log.Logger; + +/** + * A {@link ConnectionFactory} combining multiple {@link Detecting} instances that will upgrade to + * the first one recognizing the bytes in the buffer. + */ +public class DetectorConnectionFactory extends AbstractConnectionFactory implements ConnectionFactory.Detecting +{ + private static final Logger LOG = Log.getLogger(DetectorConnectionFactory.class); + + private final List _detectingConnectionFactories; + + /** + *

When the first bytes are not recognized by the {@code detectingConnectionFactories}, the default behavior is to + * upgrade to the protocol returned by {@link #findNextProtocol(Connector)}.

+ * @param detectingConnectionFactories the {@link Detecting} instances. + */ + public DetectorConnectionFactory(Detecting... detectingConnectionFactories) + { + super(toProtocolString(detectingConnectionFactories)); + _detectingConnectionFactories = Arrays.asList(detectingConnectionFactories); + for (Detecting detectingConnectionFactory : detectingConnectionFactories) + { + addBean(detectingConnectionFactory); + } + } + + private static String toProtocolString(Detecting... detectingConnectionFactories) + { + if (detectingConnectionFactories.length == 0) + throw new IllegalArgumentException("At least one detecting instance is required"); + + // remove protocol duplicates while keeping their ordering -> use LinkedHashSet + LinkedHashSet protocols = Arrays.stream(detectingConnectionFactories).map(ConnectionFactory::getProtocol).collect(Collectors.toCollection(LinkedHashSet::new)); + + String protocol = protocols.stream().collect(Collectors.joining("|", "[", "]")); + if (LOG.isDebugEnabled()) + LOG.debug("Detector generated protocol name : {}", protocol); + return protocol; + } + + /** + * Performs a detection using multiple {@link ConnectionFactory.Detecting} instances and returns the aggregated outcome. + * @param buffer the buffer to perform a detection against. + * @return A {@link Detecting.Detection} value with the detection outcome of the {@code detectingConnectionFactories}. + */ + @Override + public Detecting.Detection detect(ByteBuffer buffer) + { + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} detecting from buffer {} using {}", getProtocol(), BufferUtil.toHexString(buffer), _detectingConnectionFactories); + boolean needMoreBytes = true; + for (Detecting detectingConnectionFactory : _detectingConnectionFactories) + { + Detecting.Detection detection = detectingConnectionFactory.detect(buffer); + if (detection == Detecting.Detection.RECOGNIZED) + { + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} recognized bytes using {}", getProtocol(), detection); + return Detecting.Detection.RECOGNIZED; + } + needMoreBytes &= detection == Detection.NEED_MORE_BYTES; + } + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} {}", getProtocol(), (needMoreBytes ? "requires more bytes" : "failed to recognize bytes")); + return needMoreBytes ? Detection.NEED_MORE_BYTES : Detection.NOT_RECOGNIZED; + } + + /** + * Utility method that performs an upgrade to the specified connection factory, disposing of the given resources when needed. + * @param connectionFactory the connection factory to upgrade to. + * @param connector the connector. + * @param endPoint the endpoint. + */ + protected static void upgradeToConnectionFactory(ConnectionFactory connectionFactory, Connector connector, EndPoint endPoint) throws IllegalStateException + { + if (LOG.isDebugEnabled()) + LOG.debug("Upgrading to connection factory {}", connectionFactory); + if (connectionFactory == null) + throw new IllegalStateException("Cannot upgrade: connection factory must not be null for " + endPoint); + Connection nextConnection = connectionFactory.newConnection(connector, endPoint); + if (!(nextConnection instanceof Connection.UpgradeTo)) + throw new IllegalStateException("Cannot upgrade: " + nextConnection + " does not implement " + Connection.UpgradeTo.class.getName() + " for " + endPoint); + endPoint.upgrade(nextConnection); + if (LOG.isDebugEnabled()) + LOG.debug("Upgraded to connection factory {} and released buffer", connectionFactory); + } + + /** + *

Callback method called when detection was unsuccessful. + * This implementation upgrades to the protocol returned by {@link #findNextProtocol(Connector)}.

+ * @param connector the connector. + * @param endPoint the endpoint. + * @param buffer the buffer. + */ + protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer) throws IllegalStateException + { + String nextProtocol = findNextProtocol(connector); + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} detection unsuccessful, found '{}' as the next protocol to upgrade to", getProtocol(), nextProtocol); + if (nextProtocol == null) + throw new IllegalStateException("Cannot find protocol following '" + getProtocol() + "' in connector's protocol list " + connector.getProtocols() + " for " + endPoint); + upgradeToConnectionFactory(connector.getConnectionFactory(nextProtocol), connector, endPoint); + } + + @Override + public Connection newConnection(Connector connector, EndPoint endPoint) + { + return configure(new DetectorConnection(endPoint, connector), connector, endPoint); + } + + private class DetectorConnection extends AbstractConnection implements Connection.UpgradeFrom, Connection.UpgradeTo + { + private final Connector _connector; + private final ByteBuffer _buffer; + + private DetectorConnection(EndPoint endp, Connector connector) + { + super(endp, connector.getExecutor()); + _connector = connector; + _buffer = connector.getByteBufferPool().acquire(getInputBufferSize(), true); + } + + @Override + public void onUpgradeTo(ByteBuffer prefilled) + { + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} copying prefilled buffer {}", getProtocol(), BufferUtil.toDetailString(prefilled)); + if (BufferUtil.hasContent(prefilled)) + BufferUtil.append(_buffer, prefilled); + } + + @Override + public ByteBuffer onUpgradeFrom() + { + return _buffer; + } + + @Override + public void onOpen() + { + super.onOpen(); + if (!detectAndUpgrade()) + fillInterested(); + } + + @Override + public void onFillable() + { + try + { + while (BufferUtil.space(_buffer) > 0) + { + // Read data + int fill = getEndPoint().fill(_buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} filled buffer with {} bytes", getProtocol(), fill); + if (fill < 0) + { + _connector.getByteBufferPool().release(_buffer); + getEndPoint().shutdownOutput(); + return; + } + if (fill == 0) + { + fillInterested(); + return; + } + + if (detectAndUpgrade()) + return; + } + + // all Detecting instances want more bytes than this buffer can store + LOG.warn("Detector {} failed to detect upgrade target on {} for {}", getProtocol(), _detectingConnectionFactories, getEndPoint()); + releaseAndClose(); + } + catch (Throwable x) + { + LOG.warn("Detector {} error for {}", getProtocol(), getEndPoint(), x); + releaseAndClose(); + } + } + + /** + * @return true when upgrade was performed, false otherwise. + */ + private boolean detectAndUpgrade() + { + if (BufferUtil.isEmpty(_buffer)) + { + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} skipping detection on an empty buffer", getProtocol()); + return false; + } + + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} performing detection with {} bytes", getProtocol(), _buffer.remaining()); + boolean notRecognized = true; + for (Detecting detectingConnectionFactory : _detectingConnectionFactories) + { + Detecting.Detection detection = detectingConnectionFactory.detect(_buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} performed detection from {} with {} which returned {}", getProtocol(), BufferUtil.toDetailString(_buffer), detectingConnectionFactory, detection); + if (detection == Detecting.Detection.RECOGNIZED) + { + try + { + // This DetectingConnectionFactory recognized those bytes -> upgrade to the next one. + Connection nextConnection = detectingConnectionFactory.newConnection(_connector, getEndPoint()); + if (!(nextConnection instanceof UpgradeTo)) + throw new IllegalStateException("Cannot upgrade: " + nextConnection + " does not implement " + UpgradeTo.class.getName()); + getEndPoint().upgrade(nextConnection); + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} upgraded to {}", getProtocol(), nextConnection); + return true; + } + catch (DetectionFailureException e) + { + // It's just bubbling up from a nested Detector, so it's already handled, just rethrow it. + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} failed to upgrade, rethrowing", getProtocol(), e); + throw e; + } + catch (Exception e) + { + // Two reasons that can make us end up here: + // 1) detectingConnectionFactory.newConnection() failed? probably because it cannot find the next protocol + // 2) nextConnection is not instanceof UpgradeTo + // -> release the resources before rethrowing as DetectionFailureException + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} failed to upgrade", getProtocol()); + releaseAndClose(); + throw new DetectionFailureException(e); + } + } + notRecognized &= detection == Detecting.Detection.NOT_RECOGNIZED; + } + + if (notRecognized) + { + // No DetectingConnectionFactory recognized those bytes -> call unsuccessful detection callback. + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} failed to detect a known protocol, falling back to nextProtocol()", getProtocol()); + nextProtocol(_connector, getEndPoint(), _buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} call to nextProtocol() succeeded, assuming upgrade performed", getProtocol()); + return true; + } + + return false; + } + + private void releaseAndClose() + { + if (LOG.isDebugEnabled()) + LOG.debug("Detector {} releasing buffer and closing", getProtocol()); + _connector.getByteBufferPool().release(_buffer); + close(); + } + } + + private static class DetectionFailureException extends RuntimeException + { + public DetectionFailureException(Throwable cause) + { + super(cause); + } + } +} diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java index 91fd5c9ff9f..911afadfe6b 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/OptionalSslConnectionFactory.java @@ -18,14 +18,10 @@ package org.eclipse.jetty.server; -import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import org.eclipse.jetty.io.AbstractConnection; -import org.eclipse.jetty.io.Connection; import org.eclipse.jetty.io.EndPoint; -import org.eclipse.jetty.util.BufferUtil; import org.eclipse.jetty.util.Callback; import org.eclipse.jetty.util.log.Log; import org.eclipse.jetty.util.log.Logger; @@ -34,62 +30,70 @@ import org.eclipse.jetty.util.log.Logger; *

A ConnectionFactory whose connections detect whether the first bytes are * TLS bytes and upgrades to either a TLS connection or to another configurable * connection.

+ * + * @deprecated Use {@link DetectorConnectionFactory} with a {@link SslConnectionFactory} instead. */ -public class OptionalSslConnectionFactory extends AbstractConnectionFactory +@Deprecated +public class OptionalSslConnectionFactory extends DetectorConnectionFactory { - private static final Logger LOG = Log.getLogger(OptionalSslConnection.class); - private static final int TLS_ALERT_FRAME_TYPE = 0x15; - private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16; - private static final int TLS_MAJOR_VERSION = 3; - - private final SslConnectionFactory sslConnectionFactory; - private final String otherProtocol; + private static final Logger LOG = Log.getLogger(OptionalSslConnectionFactory.class); + private final String _nextProtocol; /** *

Creates a new ConnectionFactory whose connections can upgrade to TLS or another protocol.

- *

If {@code otherProtocol} is {@code null}, and the first bytes are not TLS, then - * {@link #otherProtocol(ByteBuffer, EndPoint)} is called.

* - * @param sslConnectionFactory The SslConnectionFactory to use if the first bytes are TLS - * @param otherProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS, + * @param sslConnectionFactory The {@link SslConnectionFactory} to use if the first bytes are TLS + * @param nextProtocol the protocol of the {@link ConnectionFactory} to use if the first bytes are not TLS, * or null to explicitly handle the non-TLS case */ - public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String otherProtocol) + public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String nextProtocol) { - super("ssl|other"); - this.sslConnectionFactory = sslConnectionFactory; - this.otherProtocol = otherProtocol; - } - - @Override - public Connection newConnection(Connector connector, EndPoint endPoint) - { - return configure(new OptionalSslConnection(endPoint, connector), connector, endPoint); + super(sslConnectionFactory); + _nextProtocol = nextProtocol; } /** - * @param buffer The buffer with the first bytes of the connection - * @return whether the bytes seem TLS bytes - */ - protected boolean seemsTLS(ByteBuffer buffer) - { - int tlsFrameType = buffer.get(0) & 0xFF; - int tlsMajorVersion = buffer.get(1) & 0xFF; - return (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION; - } - - /** - *

Callback method invoked when {@code otherProtocol} is {@code null} - * and the first bytes are not TLS.

+ *

Callback method invoked when the detected bytes are not TLS.

*

This typically happens when a client is trying to connect to a TLS * port using the {@code http} scheme (and not the {@code https} scheme).

* + * @param connector The connector object + * @param endPoint The connection EndPoint object + * @param buffer The buffer with the first bytes of the connection + */ + protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer) + { + if (LOG.isDebugEnabled()) + LOG.debug("OptionalSSL TLS detection unsuccessful, attempting to upgrade to {}", _nextProtocol); + if (_nextProtocol != null) + { + ConnectionFactory connectionFactory = connector.getConnectionFactory(_nextProtocol); + if (connectionFactory == null) + throw new IllegalStateException("Cannot find protocol '" + _nextProtocol + "' in connector's protocol list " + connector.getProtocols() + " for " + endPoint); + upgradeToConnectionFactory(connectionFactory, connector, endPoint); + } + else + { + otherProtocol(buffer, endPoint); + } + } + + /** + *

Legacy callback method invoked when {@code nextProtocol} is {@code null} + * and the first bytes are not TLS.

+ *

This typically happens when a client is trying to connect to a TLS + * port using the {@code http} scheme (and not the {@code https} scheme).

+ *

This method is kept around for backward compatibility.

+ * * @param buffer The buffer with the first bytes of the connection * @param endPoint The connection EndPoint object - * @see #seemsTLS(ByteBuffer) + * @deprecated Override {@link #nextProtocol(Connector, EndPoint, ByteBuffer)} instead. */ + @Deprecated protected void otherProtocol(ByteBuffer buffer, EndPoint endPoint) { + LOG.warn("Detected non-TLS bytes, but no other protocol to upgrade to for {}", endPoint); + // There are always at least 2 bytes. int byte1 = buffer.get(0) & 0xFF; int byte2 = buffer.get(1) & 0xFF; @@ -122,105 +126,4 @@ public class OptionalSslConnectionFactory extends AbstractConnectionFactory endPoint.close(); } } - - private class OptionalSslConnection extends AbstractConnection implements Connection.UpgradeFrom - { - private final Connector connector; - private final ByteBuffer buffer; - - public OptionalSslConnection(EndPoint endPoint, Connector connector) - { - super(endPoint, connector.getExecutor()); - this.connector = connector; - this.buffer = BufferUtil.allocateDirect(1536); - } - - @Override - public void onOpen() - { - super.onOpen(); - fillInterested(); - } - - @Override - public void onFillable() - { - try - { - while (true) - { - int filled = getEndPoint().fill(buffer); - if (filled > 0) - { - // Always have at least 2 bytes. - if (BufferUtil.length(buffer) >= 2) - { - upgrade(buffer); - break; - } - } - else if (filled == 0) - { - fillInterested(); - break; - } - else - { - close(); - break; - } - } - } - catch (IOException x) - { - LOG.warn(x); - close(); - } - } - - @Override - public ByteBuffer onUpgradeFrom() - { - return buffer; - } - - private void upgrade(ByteBuffer buffer) - { - if (LOG.isDebugEnabled()) - LOG.debug("Read {}", BufferUtil.toDetailString(buffer)); - - EndPoint endPoint = getEndPoint(); - if (seemsTLS(buffer)) - { - if (LOG.isDebugEnabled()) - LOG.debug("Detected TLS bytes, upgrading to {}", sslConnectionFactory); - endPoint.upgrade(sslConnectionFactory.newConnection(connector, endPoint)); - } - else - { - if (otherProtocol != null) - { - ConnectionFactory connectionFactory = connector.getConnectionFactory(otherProtocol); - if (connectionFactory != null) - { - if (LOG.isDebugEnabled()) - LOG.debug("Detected non-TLS bytes, upgrading to {}", connectionFactory); - Connection next = connectionFactory.newConnection(connector, endPoint); - endPoint.upgrade(next); - } - else - { - LOG.warn("Missing {} {} in {}", otherProtocol, ConnectionFactory.class.getSimpleName(), connector); - close(); - } - } - else - { - if (LOG.isDebugEnabled()) - LOG.debug("Detected non-TLS bytes, but no other protocol to upgrade to"); - otherProtocol(buffer, endPoint); - } - } - } - } } diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/ProxyConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/ProxyConnectionFactory.java index fb7dc99d622..2460b5803eb 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/ProxyConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/ProxyConnectionFactory.java @@ -27,7 +27,6 @@ import java.nio.ByteBuffer; import java.nio.channels.ReadPendingException; import java.nio.channels.WritePendingException; import java.nio.charset.StandardCharsets; -import java.util.Iterator; import org.eclipse.jetty.io.AbstractConnection; import org.eclipse.jetty.io.Connection; @@ -46,252 +45,283 @@ import org.eclipse.jetty.util.log.Logger; * * @see http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt */ -public class ProxyConnectionFactory extends AbstractConnectionFactory +public class ProxyConnectionFactory extends DetectorConnectionFactory { - private static final Logger LOG = Log.getLogger(ProxyConnectionFactory.class); public static final String TLS_VERSION = "TLS_VERSION"; + private static final Logger LOG = Log.getLogger(ProxyConnectionFactory.class); - private final String _next; - private int _maxProxyHeader = 1024; - - /** - * Proxy Connection Factory that uses the next ConnectionFactory - * on the connector as the next protocol - */ public ProxyConnectionFactory() { - super("proxy"); - _next = null; + this(null); } public ProxyConnectionFactory(String nextProtocol) { - super("proxy"); - _next = nextProtocol; + super(new ProxyV1ConnectionFactory(nextProtocol), new ProxyV2ConnectionFactory(nextProtocol)); + } + + private static ConnectionFactory findNextConnectionFactory(String nextProtocol, Connector connector, String currentProtocol, EndPoint endp) + { + currentProtocol = "[" + currentProtocol + "]"; + if (LOG.isDebugEnabled()) + LOG.debug("finding connection factory following {} for protocol {}", currentProtocol, nextProtocol); + String nextProtocolToFind = nextProtocol; + if (nextProtocol == null) + nextProtocolToFind = AbstractConnectionFactory.findNextProtocol(connector, currentProtocol); + if (nextProtocolToFind == null) + throw new IllegalStateException("Cannot find protocol following '" + currentProtocol + "' in connector's protocol list " + connector.getProtocols() + " for " + endp); + ConnectionFactory connectionFactory = connector.getConnectionFactory(nextProtocolToFind); + if (connectionFactory == null) + throw new IllegalStateException("Cannot find protocol '" + nextProtocol + "' in connector's protocol list " + connector.getProtocols() + " for " + endp); + if (LOG.isDebugEnabled()) + LOG.debug("found next connection factory {} for protocol {}", connectionFactory, nextProtocol); + return connectionFactory; } public int getMaxProxyHeader() { - return _maxProxyHeader; + ProxyV2ConnectionFactory v2 = getBean(ProxyV2ConnectionFactory.class); + return v2.getMaxProxyHeader(); } public void setMaxProxyHeader(int maxProxyHeader) { - _maxProxyHeader = maxProxyHeader; + ProxyV2ConnectionFactory v2 = getBean(ProxyV2ConnectionFactory.class); + v2.setMaxProxyHeader(maxProxyHeader); } - @Override - public Connection newConnection(Connector connector, EndPoint endp) + private static class ProxyV1ConnectionFactory extends AbstractConnectionFactory implements Detecting { - String next = _next; - if (next == null) + private static final byte[] SIGNATURE = "PROXY".getBytes(StandardCharsets.US_ASCII); + + private final String _nextProtocol; + + private ProxyV1ConnectionFactory(String nextProtocol) { - for (Iterator i = connector.getProtocols().iterator(); i.hasNext(); ) - { - String p = i.next(); - if (getProtocol().equalsIgnoreCase(p)) - { - next = i.next(); - break; - } - } - } - - return new ProxyProtocolV1orV2Connection(endp, connector, next); - } - - public class ProxyProtocolV1orV2Connection extends AbstractConnection - { - // Only do a tiny read to figure out what PROXY version it is. - private final ByteBuffer _buffer = BufferUtil.allocate(16); - private final Connector _connector; - private final String _next; - - protected ProxyProtocolV1orV2Connection(EndPoint endp, Connector connector, String next) - { - super(endp, connector.getExecutor()); - _connector = connector; - _next = next; + super("proxy"); + this._nextProtocol = nextProtocol; } @Override - public void onOpen() + public Detection detect(ByteBuffer buffer) { - super.onOpen(); - fillInterested(); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 attempting detection with {} bytes", buffer.remaining()); + if (buffer.remaining() < SIGNATURE.length) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 detection requires more bytes"); + return Detection.NEED_MORE_BYTES; + } + + for (int i = 0; i < SIGNATURE.length; i++) + { + byte signatureByte = SIGNATURE[i]; + byte byteInBuffer = buffer.get(i); + if (byteInBuffer != signatureByte) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 detection unsuccessful"); + return Detection.NOT_RECOGNIZED; + } + } + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 detection succeeded"); + return Detection.RECOGNIZED; } @Override - public void onFillable() + public Connection newConnection(Connector connector, EndPoint endp) { - try + ConnectionFactory nextConnectionFactory = findNextConnectionFactory(_nextProtocol, connector, getProtocol(), endp); + return configure(new ProxyProtocolV1Connection(endp, connector, nextConnectionFactory), connector, endp); + } + + private static class ProxyProtocolV1Connection extends AbstractConnection implements Connection.UpgradeFrom, Connection.UpgradeTo + { + // 0 1 2 3 4 5 6 + // 98765432109876543210987654321 + // PROXY P R.R.R.R L.L.L.L R Lrn + private static final int CR_INDEX = 6; + private static final int LF_INDEX = 7; + + private final Connector _connector; + private final ConnectionFactory _next; + private final ByteBuffer _buffer; + private final StringBuilder _builder = new StringBuilder(); + private final String[] _fields = new String[6]; + private int _index; + private int _length; + + private ProxyProtocolV1Connection(EndPoint endp, Connector connector, ConnectionFactory next) { - while (BufferUtil.space(_buffer) > 0) + super(endp, connector.getExecutor()); + _connector = connector; + _next = next; + _buffer = _connector.getByteBufferPool().acquire(getInputBufferSize(), true); + } + + @Override + public void onFillable() + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 onFillable current index = ", _index); + try { - // Read data - int fill = getEndPoint().fill(_buffer); - if (fill < 0) + while (_index < LF_INDEX) { - getEndPoint().shutdownOutput(); - return; + // Read data + int fill = getEndPoint().fill(_buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 filled buffer with {} bytes", fill); + if (fill < 0) + { + _connector.getByteBufferPool().release(_buffer); + getEndPoint().shutdownOutput(); + return; + } + if (fill == 0) + { + fillInterested(); + return; + } + + if (parse()) + break; } - if (fill == 0) + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 onFillable parsing done, now upgrading"); + upgrade(); + } + catch (Throwable x) + { + LOG.warn("Proxy v1 error for {}", getEndPoint(), x); + releaseAndClose(); + } + } + + @Override + public void onOpen() + { + super.onOpen(); + + try + { + while (_index < LF_INDEX) { - fillInterested(); - return; + if (!parse()) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 onOpen parsing ran out of bytes, marking as fillInterested"); + fillInterested(); + return; + } + } + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 onOpen parsing done, now upgrading"); + upgrade(); + } + catch (Throwable x) + { + LOG.warn("Proxy v1 error for {}", getEndPoint(), x); + releaseAndClose(); + } + } + + @Override + public ByteBuffer onUpgradeFrom() + { + return _buffer; + } + + @Override + public void onUpgradeTo(ByteBuffer prefilled) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 copying prefilled buffer {}", BufferUtil.toDetailString(prefilled)); + if (BufferUtil.hasContent(prefilled)) + BufferUtil.append(_buffer, prefilled); + } + + /** + * @return true when parsing is done, false when more bytes are needed. + */ + private boolean parse() throws IOException + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 parsing {}", BufferUtil.toDetailString(_buffer)); + _length += _buffer.remaining(); + + // Parse fields + while (_buffer.hasRemaining()) + { + byte b = _buffer.get(); + if (_index < CR_INDEX) + { + if (b == ' ' || b == '\r') + { + _fields[_index++] = _builder.toString(); + _builder.setLength(0); + if (b == '\r') + _index = CR_INDEX; + } + else if (b < ' ') + { + throw new IOException("Proxy v1 bad character " + (b & 0xFF)); + } + else + { + _builder.append((char)b); + } + } + else + { + if (b == '\n') + { + _index = LF_INDEX; + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 parsing is done"); + return true; + } + + throw new IOException("Proxy v1 bad CRLF " + (b & 0xFF)); } } - // Is it a V1? - switch (_buffer.get(0)) - { - case 'P': - { - ProxyProtocolV1Connection v1 = new ProxyProtocolV1Connection(getEndPoint(), _connector, _next, _buffer); - getEndPoint().upgrade(v1); - return; - } - case 0x0D: - { - ProxyProtocolV2Connection v2 = new ProxyProtocolV2Connection(getEndPoint(), _connector, _next, _buffer); - getEndPoint().upgrade(v2); - return; - } - default: - { - LOG.warn("Not PROXY protocol for {}", getEndPoint()); - close(); - break; - } - } + // Not enough bytes. + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 parsing requires more bytes"); + return false; } - catch (Throwable x) + + private void releaseAndClose() { - LOG.warn("PROXY error for " + getEndPoint(), x); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 releasing buffer and closing"); + _connector.getByteBufferPool().release(_buffer); close(); } - } - } - public static class ProxyProtocolV1Connection extends AbstractConnection - { - // 0 1 2 3 4 5 6 - // 98765432109876543210987654321 - // PROXY P R.R.R.R L.L.L.L R Lrn - - private static final int[] SIZE = {29, 23, 21, 13, 5, 3, 1}; - private final Connector _connector; - private final String _next; - private final StringBuilder _builder = new StringBuilder(); - private final String[] _fields = new String[6]; - private int _index; - private int _length; - - protected ProxyProtocolV1Connection(EndPoint endp, Connector connector, String next, ByteBuffer buffer) - { - super(endp, connector.getExecutor()); - _connector = connector; - _next = next; - _length = buffer.remaining(); - parse(buffer); - } - - @Override - public void onOpen() - { - super.onOpen(); - fillInterested(); - } - - private boolean parse(ByteBuffer buffer) - { - // Parse fields - while (buffer.hasRemaining()) + private void upgrade() { - byte b = buffer.get(); - if (_index < 6) + int proxyLineLength = _length - _buffer.remaining(); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v1 pre-upgrade packet length (including CRLF) is {}", proxyLineLength); + if (proxyLineLength >= 110) { - if (b == ' ' || b == '\r') - { - _fields[_index++] = _builder.toString(); - _builder.setLength(0); - if (b == '\r') - _index = 6; - } - else if (b < ' ') - { - LOG.warn("Bad character {} for {}", b & 0xFF, getEndPoint()); - close(); - return false; - } - else - { - _builder.append((char)b); - } - } - else - { - if (b == '\n') - { - _index = 7; - return true; - } - - LOG.warn("Bad CRLF for {}", getEndPoint()); - close(); - return false; - } - } - return true; - } - - @Override - public void onFillable() - { - try - { - ByteBuffer buffer = null; - while (_index < 7) - { - // Create a buffer that will not read too much data - // since once read it is impossible to push back for the - // real connection to read it. - int size = Math.max(1, SIZE[_index] - _builder.length()); - if (buffer == null || buffer.capacity() != size) - buffer = BufferUtil.allocate(size); - else - BufferUtil.clear(buffer); - - // Read data - int fill = getEndPoint().fill(buffer); - if (fill < 0) - { - getEndPoint().shutdownOutput(); - return; - } - if (fill == 0) - { - fillInterested(); - return; - } - - _length += fill; - if (_length >= 108) - { - LOG.warn("PROXY line too long {} for {}", _length, getEndPoint()); - close(); - return; - } - - if (!parse(buffer)) - return; + LOG.warn("Proxy v1 PROXY line too long {} for {}", proxyLineLength, getEndPoint()); + releaseAndClose(); + return; } // Check proxy if (!"PROXY".equals(_fields[0])) { - LOG.warn("Not PROXY protocol for {}", getEndPoint()); - close(); + LOG.warn("Proxy v1 not PROXY protocol for {}", getEndPoint()); + releaseAndClose(); return; } @@ -311,184 +341,221 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory InetSocketAddress remote = new InetSocketAddress(srcIP, Integer.parseInt(srcPort)); InetSocketAddress local = new InetSocketAddress(dstIP, Integer.parseInt(dstPort)); - // Create the next protocol - ConnectionFactory connectionFactory = _connector.getConnectionFactory(_next); - if (connectionFactory == null) - { - LOG.warn("No next protocol '{}' for {}", _next, getEndPoint()); - close(); - return; - } - if (LOG.isDebugEnabled()) - LOG.warn("Next protocol '{}' for {} r={} l={}", _next, getEndPoint(), remote, local); + LOG.debug("Proxy v1 next protocol '{}' for {} r={} l={}", _next, getEndPoint(), remote, local); EndPoint endPoint = new ProxyEndPoint(getEndPoint(), remote, local); - Connection newConnection = connectionFactory.newConnection(_connector, endPoint); - endPoint.upgrade(newConnection); - } - catch (Throwable x) - { - LOG.warn("PROXY error for " + getEndPoint(), x); - close(); + upgradeToConnectionFactory(_next, _connector, endPoint); } } } - private enum Family + private static class ProxyV2ConnectionFactory extends AbstractConnectionFactory implements Detecting { - UNSPEC, INET, INET6, UNIX - } - - private enum Transport - { - UNSPEC, STREAM, DGRAM - } - - private static final byte[] MAGIC = new byte[]{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}; - - public class ProxyProtocolV2Connection extends AbstractConnection - { - private final Connector _connector; - private final String _next; - private final boolean _local; - private final Family _family; - private final Transport _transport; - private final int _length; - private final ByteBuffer _buffer; - - protected ProxyProtocolV2Connection(EndPoint endp, Connector connector, String next, ByteBuffer buffer) throws IOException + private enum Family { - super(endp, connector.getExecutor()); - _connector = connector; - _next = next; + UNSPEC, INET, INET6, UNIX + } - if (buffer.remaining() != 16) - throw new IllegalStateException(); + private enum Transport + { + UNSPEC, STREAM, DGRAM + } + + private static final byte[] SIGNATURE = new byte[] + { + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A + }; + private final String _nextProtocol; + private int _maxProxyHeader = 1024; + + private ProxyV2ConnectionFactory(String nextProtocol) + { + super("proxy"); + this._nextProtocol = nextProtocol; + } + + @Override + public Detection detect(ByteBuffer buffer) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 attempting detection with {} bytes", buffer.remaining()); + if (buffer.remaining() < SIGNATURE.length) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 detection requires more bytes"); + return Detection.NEED_MORE_BYTES; + } + + for (int i = 0; i < SIGNATURE.length; i++) + { + byte signatureByte = SIGNATURE[i]; + byte byteInBuffer = buffer.get(i); + if (byteInBuffer != signatureByte) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 detection unsuccessful"); + return Detection.NOT_RECOGNIZED; + } + } if (LOG.isDebugEnabled()) - LOG.debug("PROXYv2 header {} for {}", BufferUtil.toHexSummary(buffer), this); + LOG.debug("Proxy v2 detection succeeded"); + return Detection.RECOGNIZED; + } - // struct proxy_hdr_v2 { - // uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ - // uint8_t ver_cmd; /* protocol version and command */ - // uint8_t fam; /* protocol family and address */ - // uint16_t len; /* number of following bytes part of the header */ - // }; - for (byte magic : MAGIC) - { - if (buffer.get() != magic) - throw new IOException("Bad PROXY protocol v2 signature"); - } + public int getMaxProxyHeader() + { + return _maxProxyHeader; + } - int versionAndCommand = 0xff & buffer.get(); - if ((versionAndCommand & 0xf0) != 0x20) - throw new IOException("Bad PROXY protocol v2 version"); - _local = (versionAndCommand & 0xf) == 0x00; - - int transportAndFamily = 0xff & buffer.get(); - switch (transportAndFamily >> 4) - { - case 0: - _family = Family.UNSPEC; - break; - case 1: - _family = Family.INET; - break; - case 2: - _family = Family.INET6; - break; - case 3: - _family = Family.UNIX; - break; - default: - throw new IOException("Bad PROXY protocol v2 family"); - } - - switch (0xf & transportAndFamily) - { - case 0: - _transport = Transport.UNSPEC; - break; - case 1: - _transport = Transport.STREAM; - break; - case 2: - _transport = Transport.DGRAM; - break; - default: - throw new IOException("Bad PROXY protocol v2 family"); - } - - _length = buffer.getChar(); - - if (!_local && (_family == Family.UNSPEC || _family == Family.UNIX || _transport != Transport.STREAM)) - throw new IOException(String.format("Unsupported PROXY protocol v2 mode 0x%x,0x%x", versionAndCommand, transportAndFamily)); - - if (_length > getMaxProxyHeader()) - throw new IOException(String.format("Unsupported PROXY protocol v2 mode 0x%x,0x%x,0x%x", versionAndCommand, transportAndFamily, _length)); - - _buffer = _length > 0 ? BufferUtil.allocate(_length) : BufferUtil.EMPTY_BUFFER; + public void setMaxProxyHeader(int maxProxyHeader) + { + _maxProxyHeader = maxProxyHeader; } @Override - public void onOpen() + public Connection newConnection(Connector connector, EndPoint endp) { - super.onOpen(); - if (_buffer.remaining() == _length) - next(); - else - fillInterested(); + ConnectionFactory nextConnectionFactory = findNextConnectionFactory(_nextProtocol, connector, getProtocol(), endp); + return configure(new ProxyProtocolV2Connection(endp, connector, nextConnectionFactory), connector, endp); } - @Override - public void onFillable() + private class ProxyProtocolV2Connection extends AbstractConnection implements Connection.UpgradeFrom, Connection.UpgradeTo { - try + private static final int HEADER_LENGTH = 16; + + private final Connector _connector; + private final ConnectionFactory _next; + private final ByteBuffer _buffer; + private boolean _local; + private Family _family; + private int _length; + private boolean _headerParsed; + + protected ProxyProtocolV2Connection(EndPoint endp, Connector connector, ConnectionFactory next) { - while (_buffer.remaining() < _length) + super(endp, connector.getExecutor()); + _connector = connector; + _next = next; + _buffer = _connector.getByteBufferPool().acquire(getInputBufferSize(), true); + } + + @Override + public void onUpgradeTo(ByteBuffer prefilled) + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 copying prefilled buffer {}", BufferUtil.toDetailString(prefilled)); + if (BufferUtil.hasContent(prefilled)) + BufferUtil.append(_buffer, prefilled); + } + + @Override + public void onOpen() + { + super.onOpen(); + + try { - // Read data - int fill = getEndPoint().fill(_buffer); - if (fill < 0) + parseHeader(); + if (_headerParsed && _buffer.remaining() >= _length) { - getEndPoint().shutdownOutput(); - return; + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 onOpen parsing fixed length packet part done, now upgrading"); + parseBodyAndUpgrade(); } - if (fill == 0) + else { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 onOpen parsing fixed length packet ran out of bytes, marking as fillInterested"); fillInterested(); - return; } } - next(); - } - catch (Throwable x) - { - LOG.warn("PROXY error for " + getEndPoint(), x); - close(); - } - } - - private void next() - { - if (LOG.isDebugEnabled()) - LOG.debug("PROXYv2 next {} from {} for {}", _next, BufferUtil.toHexSummary(_buffer), this); - - // Create the next protocol - ConnectionFactory connectionFactory = _connector.getConnectionFactory(_next); - if (connectionFactory == null) - { - LOG.info("Next protocol '{}' for {}", _next, getEndPoint()); - close(); - return; + catch (Exception x) + { + LOG.warn("Proxy v2 error for {}", getEndPoint(), x); + releaseAndClose(); + } } - // Do we need to wrap the endpoint? - EndPoint endPoint = getEndPoint(); - if (!_local) + @Override + public void onFillable() { try + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 onFillable header parsed? ", _headerParsed); + while (!_headerParsed) + { + // Read data + int fill = getEndPoint().fill(_buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 filled buffer with {} bytes", fill); + if (fill < 0) + { + _connector.getByteBufferPool().release(_buffer); + getEndPoint().shutdownOutput(); + return; + } + if (fill == 0) + { + fillInterested(); + return; + } + + parseHeader(); + } + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 onFillable header parsed, length = {}, buffer = {}", _length, BufferUtil.toDetailString(_buffer)); + + while (_buffer.remaining() < _length) + { + // Read data + int fill = getEndPoint().fill(_buffer); + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 filled buffer with {} bytes", fill); + if (fill < 0) + { + _connector.getByteBufferPool().release(_buffer); + getEndPoint().shutdownOutput(); + return; + } + if (fill == 0) + { + fillInterested(); + return; + } + } + + parseBodyAndUpgrade(); + } + catch (Throwable x) + { + LOG.warn("Proxy v2 error for " + getEndPoint(), x); + releaseAndClose(); + } + } + + @Override + public ByteBuffer onUpgradeFrom() + { + return _buffer; + } + + private void parseBodyAndUpgrade() throws IOException + { + // stop reading when bufferRemainingReserve bytes are remaining in the buffer + int nonProxyRemaining = _buffer.remaining() - _length; + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 parsing body, length = {}, buffer = {}", _length, BufferUtil.toHexSummary(_buffer)); + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 body {} from {} for {}", _next, BufferUtil.toHexSummary(_buffer), this); + + // Do we need to wrap the endpoint? + EndPoint endPoint = getEndPoint(); + if (!_local) { InetAddress src; InetAddress dst; @@ -532,15 +599,15 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory endPoint = proxyEndPoint; // Any additional info? - while (_buffer.hasRemaining()) + while (_buffer.remaining() > nonProxyRemaining) { int type = 0xff & _buffer.get(); - int length = _buffer.getShort(); + int length = _buffer.getChar(); byte[] value = new byte[length]; _buffer.get(value); if (LOG.isDebugEnabled()) - LOG.debug(String.format("T=%x L=%d V=%s for %s", type, length, TypeUtil.toHexString(value), this)); + LOG.debug(String.format("Proxy v2 T=%x L=%d V=%s for %s", type, length, TypeUtil.toHexString(value), this)); switch (type) { @@ -593,16 +660,94 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory } if (LOG.isDebugEnabled()) - LOG.debug("{} {}", getEndPoint(), proxyEndPoint.toString()); - } - catch (Exception e) - { - LOG.warn(e); + LOG.debug("Proxy v2 {} {}", getEndPoint(), proxyEndPoint.toString()); } + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 parsing dynamic packet part is now done, upgrading to {}", _nextProtocol); + upgradeToConnectionFactory(_next, _connector, endPoint); } - Connection newConnection = connectionFactory.newConnection(_connector, endPoint); - endPoint.upgrade(newConnection); + private void parseHeader() throws IOException + { + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 parsing fixed length packet part, buffer = {}", BufferUtil.toDetailString(_buffer)); + if (_buffer.remaining() < HEADER_LENGTH) + return; + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 header {} for {}", BufferUtil.toHexSummary(_buffer), this); + + // struct proxy_hdr_v2 { + // uint8_t sig[12]; /* hex 0D 0A 0D 0A 00 0D 0A 51 55 49 54 0A */ + // uint8_t ver_cmd; /* protocol version and command */ + // uint8_t fam; /* protocol family and address */ + // uint16_t len; /* number of following bytes part of the header */ + // }; + for (byte signatureByte : SIGNATURE) + { + if (_buffer.get() != signatureByte) + throw new IOException("Proxy v2 bad PROXY signature"); + } + + int versionAndCommand = 0xFF & _buffer.get(); + if ((versionAndCommand & 0xF0) != 0x20) + throw new IOException("Proxy v2 bad PROXY version"); + _local = (versionAndCommand & 0xF) == 0x00; + + int transportAndFamily = 0xFF & _buffer.get(); + switch (transportAndFamily >> 4) + { + case 0: + _family = Family.UNSPEC; + break; + case 1: + _family = Family.INET; + break; + case 2: + _family = Family.INET6; + break; + case 3: + _family = Family.UNIX; + break; + default: + throw new IOException("Proxy v2 bad PROXY family"); + } + + Transport transport; + switch (0xF & transportAndFamily) + { + case 0: + transport = Transport.UNSPEC; + break; + case 1: + transport = Transport.STREAM; + break; + case 2: + transport = Transport.DGRAM; + break; + default: + throw new IOException("Proxy v2 bad PROXY family"); + } + + _length = _buffer.getChar(); + + if (!_local && (_family == Family.UNSPEC || _family == Family.UNIX || transport != Transport.STREAM)) + throw new IOException(String.format("Proxy v2 unsupported PROXY mode 0x%x,0x%x", versionAndCommand, transportAndFamily)); + + if (_length > getMaxProxyHeader()) + throw new IOException(String.format("Proxy v2 Unsupported PROXY mode 0x%x,0x%x,0x%x", versionAndCommand, transportAndFamily, _length)); + + if (LOG.isDebugEnabled()) + LOG.debug("Proxy v2 fixed length packet part is now parsed"); + _headerParsed = true; + } + + private void releaseAndClose() + { + _connector.getByteBufferPool().release(_buffer); + close(); + } } } @@ -619,59 +764,6 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory _local = local; } - @Override - public InetSocketAddress getLocalAddress() - { - return _local; - } - - @Override - public InetSocketAddress getRemoteAddress() - { - return _remote; - } - - @Override - public String toString() - { - return String.format("%s@%x[remote=%s,local=%s,endpoint=%s]", - getClass().getSimpleName(), - hashCode(), - _remote, - _local, - _endp); - } - - @Override - public boolean isOpen() - { - return _endp.isOpen(); - } - - @Override - public long getCreatedTimeStamp() - { - return _endp.getCreatedTimeStamp(); - } - - @Override - public void shutdownOutput() - { - _endp.shutdownOutput(); - } - - @Override - public boolean isOutputShutdown() - { - return _endp.isOutputShutdown(); - } - - @Override - public boolean isInputShutdown() - { - return _endp.isInputShutdown(); - } - @Override public void close(Throwable cause) { @@ -684,30 +776,6 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory return _endp.fill(buffer); } - @Override - public boolean flush(ByteBuffer... buffer) throws IOException - { - return _endp.flush(buffer); - } - - @Override - public Object getTransport() - { - return _endp.getTransport(); - } - - @Override - public long getIdleTimeout() - { - return _endp.getIdleTimeout(); - } - - @Override - public void setIdleTimeout(long idleTimeout) - { - _endp.setIdleTimeout(idleTimeout); - } - @Override public void fillInterested(Callback callback) throws ReadPendingException { @@ -715,21 +783,9 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory } @Override - public boolean tryFillInterested(Callback callback) + public boolean flush(ByteBuffer... buffer) throws IOException { - return _endp.tryFillInterested(callback); - } - - @Override - public boolean isFillInterested() - { - return _endp.isFillInterested(); - } - - @Override - public void write(Callback callback, ByteBuffer... buffers) throws WritePendingException - { - _endp.write(callback, buffers); + return _endp.flush(buffer); } @Override @@ -745,9 +801,63 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory } @Override - public void onOpen() + public long getCreatedTimeStamp() { - _endp.onOpen(); + return _endp.getCreatedTimeStamp(); + } + + @Override + public long getIdleTimeout() + { + return _endp.getIdleTimeout(); + } + + @Override + public void setIdleTimeout(long idleTimeout) + { + _endp.setIdleTimeout(idleTimeout); + } + + @Override + public InetSocketAddress getLocalAddress() + { + return _local; + } + + @Override + public InetSocketAddress getRemoteAddress() + { + return _remote; + } + + @Override + public Object getTransport() + { + return _endp.getTransport(); + } + + @Override + public boolean isFillInterested() + { + return _endp.isFillInterested(); + } + + @Override + public boolean isInputShutdown() + { + return _endp.isInputShutdown(); + } + + @Override + public boolean isOpen() + { + return _endp.isOpen(); + } + + @Override + public boolean isOutputShutdown() + { + return _endp.isOutputShutdown(); } @Override @@ -756,10 +866,45 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory _endp.onClose(cause); } + @Override + public void onOpen() + { + _endp.onOpen(); + } + + @Override + public void shutdownOutput() + { + _endp.shutdownOutput(); + } + + @Override + public String toString() + { + return String.format("%s@%x[remote=%s,local=%s,endpoint=%s]", + getClass().getSimpleName(), + hashCode(), + _remote, + _local, + _endp); + } + + @Override + public boolean tryFillInterested(Callback callback) + { + return _endp.tryFillInterested(callback); + } + @Override public void upgrade(Connection newConnection) { _endp.upgrade(newConnection); } + + @Override + public void write(Callback callback, ByteBuffer... buffers) throws WritePendingException + { + _endp.write(callback, buffers); + } } } diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/SslConnectionFactory.java b/jetty-server/src/main/java/org/eclipse/jetty/server/SslConnectionFactory.java index 8cf2e0b0225..d51d1dcce97 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/SslConnectionFactory.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/SslConnectionFactory.java @@ -18,6 +18,7 @@ package org.eclipse.jetty.server; +import java.nio.ByteBuffer; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLSession; @@ -31,8 +32,12 @@ import org.eclipse.jetty.util.annotation.Name; import org.eclipse.jetty.util.component.ContainerLifeCycle; import org.eclipse.jetty.util.ssl.SslContextFactory; -public class SslConnectionFactory extends AbstractConnectionFactory +public class SslConnectionFactory extends AbstractConnectionFactory implements ConnectionFactory.Detecting { + private static final int TLS_ALERT_FRAME_TYPE = 0x15; + private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16; + private static final int TLS_MAJOR_VERSION = 3; + private final SslContextFactory.Server _sslContextFactory; private final String _nextProtocol; private boolean _directBuffersForEncryption = false; @@ -99,6 +104,17 @@ public class SslConnectionFactory extends AbstractConnectionFactory setInputBufferSize(session.getPacketBufferSize()); } + @Override + public Detection detect(ByteBuffer buffer) + { + if (buffer.remaining() < 2) + return Detection.NEED_MORE_BYTES; + int tlsFrameType = buffer.get(0) & 0xFF; + int tlsMajorVersion = buffer.get(1) & 0xFF; + boolean seemsSsl = (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION; + return seemsSsl ? Detection.RECOGNIZED : Detection.NOT_RECOGNIZED; + } + @Override public Connection newConnection(Connector connector, EndPoint endPoint) { diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/DetectorConnectionTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/DetectorConnectionTest.java new file mode 100644 index 00000000000..79bb28fecc2 --- /dev/null +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/DetectorConnectionTest.java @@ -0,0 +1,710 @@ +// +// ======================================================================== +// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others. +// +// This program and the accompanying materials are made available under +// the terms of the Eclipse Public License 2.0 which is available at +// https://www.eclipse.org/legal/epl-2.0 +// +// This Source Code may also be made available under the following +// Secondary Licenses when the conditions for such availability set +// forth in the Eclipse Public License, v. 2.0 are satisfied: +// the Apache License v2.0 which is available at +// https://www.apache.org/licenses/LICENSE-2.0 +// +// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0 +// ======================================================================== +// + +package org.eclipse.jetty.server; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.Socket; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.net.ssl.SSLSocketFactory; + +import org.eclipse.jetty.http.HttpVersion; +import org.eclipse.jetty.io.AbstractConnection; +import org.eclipse.jetty.io.Connection; +import org.eclipse.jetty.io.EndPoint; +import org.eclipse.jetty.toolchain.test.MavenTestingUtils; +import org.eclipse.jetty.util.Callback; +import org.eclipse.jetty.util.TypeUtil; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DetectorConnectionTest +{ + private Server _server; + + private static String inputStreamToString(InputStream is) throws IOException + { + StringBuilder sb = new StringBuilder(); + BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.US_ASCII)); + + while (true) + { + String line = reader.readLine(); + if (line == null) + { + // remove the last '\n' + if (sb.length() != 0) + sb.deleteCharAt(sb.length() - 1); + break; + } + sb.append(line).append('\n'); + } + + return sb.length() == 0 ? null : sb.toString(); + } + + private String getResponse(String request) throws Exception + { + return getResponse(request.getBytes(StandardCharsets.US_ASCII)); + } + + private String getResponse(byte[]... requests) throws Exception + { + try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort())) + { + for (byte[] request : requests) + { + socket.getOutputStream().write(request); + } + return inputStreamToString(socket.getInputStream()); + } + } + + private String getResponseOverSsl(String request) throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + sslContextFactory.start(); + + SSLSocketFactory socketFactory = sslContextFactory.getSslContext().getSocketFactory(); + try (Socket socket = socketFactory.createSocket(_server.getURI().getHost(), _server.getURI().getPort())) + { + socket.getOutputStream().write(request.getBytes(StandardCharsets.US_ASCII)); + return inputStreamToString(socket.getInputStream()); + } + finally + { + sslContextFactory.stop(); + } + } + + private void start(ConnectionFactory... connectionFactories) throws Exception + { + _server = new Server(); + _server.addConnector(new ServerConnector(_server, 1, 1, connectionFactories)); + _server.setHandler(new DumpHandler()); + _server.start(); + } + + @AfterEach + public void destroy() throws Exception + { + if (_server != null) + _server.stop(); + } + + @Test + public void testConnectionClosedDuringDetection() throws Exception + { + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, http); + + try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort())) + { + socket.getOutputStream().write("PR".getBytes(StandardCharsets.US_ASCII)); + Thread.sleep(100); // make sure the onFillable callback gets called + socket.getOutputStream().write("OX".getBytes(StandardCharsets.US_ASCII)); + socket.getOutputStream().close(); + + assertThrows(SocketException.class, () -> socket.getInputStream().read()); + } + } + + @Test + public void testConnectionClosedDuringProxyV1Handling() throws Exception + { + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, http); + + try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort())) + { + socket.getOutputStream().write("PROXY".getBytes(StandardCharsets.US_ASCII)); + Thread.sleep(100); // make sure the onFillable callback gets called + socket.getOutputStream().write(" ".getBytes(StandardCharsets.US_ASCII)); + socket.getOutputStream().close(); + + assertThrows(SocketException.class, () -> socket.getInputStream().read()); + } + } + + @Test + public void testConnectionClosedDuringProxyV2HandlingFixedLengthPart() throws Exception + { + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, http); + + try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort())) + { + socket.getOutputStream().write(TypeUtil.fromHexString("0D0A0D0A000D0A515549540A")); // proxy V2 Preamble + Thread.sleep(100); // make sure the onFillable callback gets called + socket.getOutputStream().write(TypeUtil.fromHexString("21")); // V2, PROXY + socket.getOutputStream().close(); + + assertThrows(SocketException.class, () -> socket.getInputStream().read()); + } + } + + @Test + public void testConnectionClosedDuringProxyV2HandlingDynamicLengthPart() throws Exception + { + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, http); + + try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort())) + { + socket.getOutputStream().write(TypeUtil.fromHexString( + // proxy V2 Preamble + "0D0A0D0A000D0A515549540A" + + // V2, PROXY + "21" + + // 0x1 : AF_INET 0x1 : STREAM. + "11" + + // Address length is 2*4 + 2*2 = 12 bytes. + // length of remaining header (4+4+2+2 = 12) + "000C" + )); + Thread.sleep(100); // make sure the onFillable callback gets called + socket.getOutputStream().write(TypeUtil.fromHexString( + // uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; + "C0A80001" // 8080 + )); + socket.getOutputStream().close(); + + assertThrows(SocketException.class, () -> socket.getInputStream().read()); + } + } + + @Test + public void testDetectingSslProxyToHttpNoSslWithProxy() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy); + + start(detector, http); + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + assertThat(response, Matchers.containsString("pathInfo=/path")); + assertThat(response, Matchers.containsString("local=5.6.7.8:222")); + assertThat(response, Matchers.containsString("remote=1.2.3.4:111")); + } + + @Test + public void testDetectingSslProxyToHttpWithSslNoProxy() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy); + + start(detector, http); + + String request = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponseOverSsl(request); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + } + + @Test + public void testDetectingSslProxyToHttpWithSslWithProxy() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy); + + start(detector, http); + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponseOverSsl(request); + + // SSL matched, so the upgrade was made to HTTP which does not understand the proxy request + assertThat(response, Matchers.containsString("HTTP/1.1 400")); + } + + @Test + public void testDetectionUnsuccessfulUpgradesToNextProtocol() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy); + + start(detector, http); + + String request = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + } + + @Test + void testDetectorToNextDetector() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory proxyDetector = new DetectorConnectionFactory(proxy); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, proxyDetector.getProtocol()); + DetectorConnectionFactory sslDetector = new DetectorConnectionFactory(ssl); + + start(sslDetector, proxyDetector, http); + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponseOverSsl(request); + + // SSL matched, so the upgrade was made to proxy which itself upgraded to HTTP + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + assertThat(response, Matchers.containsString("pathInfo=/path")); + assertThat(response, Matchers.containsString("local=5.6.7.8:222")); + assertThat(response, Matchers.containsString("remote=1.2.3.4:111")); + } + + @Test + void testDetectorWithDetectionUnsuccessful() throws Exception + { + AtomicBoolean detectionSuccessful = new AtomicBoolean(true); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy) + { + @Override + protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer) + { + if (!detectionSuccessful.compareAndSet(true, false)) + throw new AssertionError("DetectionUnsuccessful callback should only have been called once"); + + // omitting this will leak the buffer + connector.getByteBufferPool().release(buffer); + + Callback.Completable completable = new Callback.Completable(); + endPoint.write(completable, ByteBuffer.wrap("No upgrade for you".getBytes(StandardCharsets.US_ASCII))); + completable.whenComplete((r, x) -> endPoint.close()); + } + }; + HttpConnectionFactory http = new HttpConnectionFactory(); + + start(detector, http); + + String request = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + assertEquals("No upgrade for you", response); + assertFalse(detectionSuccessful.get()); + } + + @Test + void testDetectorWithProxyThatHasNoNextProto() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy); + + start(detector, http); + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + // ProxyConnectionFactory has no next protocol -> it cannot upgrade + assertThat(response, Matchers.nullValue()); + } + + @Test + void testOptionalSsl() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConnectionFactory http = new HttpConnectionFactory(); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl); + + start(detector, http); + + String request = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String clearTextResponse = getResponse(request); + String sslResponse = getResponseOverSsl(request); + + // both clear text and SSL can be responded to just fine + assertThat(clearTextResponse, Matchers.containsString("HTTP/1.1 200")); + assertThat(sslResponse, Matchers.containsString("HTTP/1.1 200")); + } + + @Test + void testDetectorThatHasNoConfiguredNextProto() throws Exception + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, HttpVersion.HTTP_1_1.asString()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl); + + start(detector); + + String request = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + assertThat(response, Matchers.nullValue()); + } + + @Test + void testDetectorWithNextProtocolThatDoesNotExist() throws Exception + { + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory("does-not-exist"); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, http); + + String proxyReq = + // proxy V2 Preamble + "0D0A0D0A000D0A515549540A" + + // V2, PROXY + "21" + + // 0x1 : AF_INET 0x1 : STREAM. + "11" + + // Address length is 2*4 + 2*2 = 12 bytes. + // length of remaining header (4+4+2+2 = 12) + "000C" + + // uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; + "C0A80001" + // 192.168.0.1 + "7f000001" + // 127.0.0.1 + "3039" + // 12345 + "1F90"; // 8080 + + String httpReq = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(TypeUtil.fromHexString(proxyReq), httpReq.getBytes(StandardCharsets.US_ASCII)); + + assertThat(response, Matchers.nullValue()); + } + + @Test + void testDetectingWithNextProtocolThatDoesNotImplementUpgradeTo() throws Exception + { + ConnectionFactory.Detecting noUpgradeTo = new ConnectionFactory.Detecting() + { + @Override + public Detection detect(ByteBuffer buffer) + { + return Detection.RECOGNIZED; + } + + @Override + public String getProtocol() + { + return "noUpgradeTo"; + } + + @Override + public List getProtocols() + { + return Collections.singletonList(getProtocol()); + } + + @Override + public Connection newConnection(Connector connector, EndPoint endPoint) + { + return new AbstractConnection(null, connector.getExecutor()) + { + @Override + public void onFillable() + { + } + }; + } + }; + + HttpConnectionFactory http = new HttpConnectionFactory(); + DetectorConnectionFactory detector = new DetectorConnectionFactory(noUpgradeTo); + + start(detector, http); + + String proxyReq = + // proxy V2 Preamble + "0D0A0D0A000D0A515549540A" + + // V2, PROXY + "21" + + // 0x1 : AF_INET 0x1 : STREAM. + "11" + + // Address length is 2*4 + 2*2 = 12 bytes. + // length of remaining header (4+4+2+2 = 12) + "000C" + + // uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; + "C0A80001" + // 192.168.0.1 + "7f000001" + // 127.0.0.1 + "3039" + // 12345 + "1F90"; // 8080 + + String httpReq = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(TypeUtil.fromHexString(proxyReq), httpReq.getBytes(StandardCharsets.US_ASCII)); + + assertThat(response, Matchers.nullValue()); + } + + @Test + void testDetectorWithNextProtocolThatDoesNotImplementUpgradeTo() throws Exception + { + ConnectionFactory noUpgradeTo = new ConnectionFactory() + { + @Override + public String getProtocol() + { + return "noUpgradeTo"; + } + + @Override + public List getProtocols() + { + return Collections.singletonList(getProtocol()); + } + + @Override + public Connection newConnection(Connector connector, EndPoint endPoint) + { + return new AbstractConnection(null, connector.getExecutor()) + { + @Override + public void onFillable() + { + } + }; + } + }; + + HttpConnectionFactory http = new HttpConnectionFactory(); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol()); + DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy); + + start(detector, noUpgradeTo); + + String request = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = getResponse(request); + + assertThat(response, Matchers.nullValue()); + } + + @Test + void testGeneratedProtocolNames() + { + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString()); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, HttpVersion.HTTP_1_1.asString()); + + assertEquals("[SSL|[proxy]]", new DetectorConnectionFactory(ssl, proxy).getProtocol()); + assertEquals("[[proxy]|SSL]", new DetectorConnectionFactory(proxy, ssl).getProtocol()); + } + + @Test + void testDetectorWithNoDetectingFails() + { + assertThrows(IllegalArgumentException.class, DetectorConnectionFactory::new); + } + + @Test + void testExerciseDetectorNotEnoughBytes() throws Exception + { + ConnectionFactory.Detecting detectingNeverRecognizes = new ConnectionFactory.Detecting() + { + @Override + public Detection detect(ByteBuffer buffer) + { + return Detection.NOT_RECOGNIZED; + } + + @Override + public String getProtocol() + { + return "nevergood"; + } + + @Override + public List getProtocols() + { + throw new AssertionError(); + } + + @Override + public Connection newConnection(Connector connector, EndPoint endPoint) + { + throw new AssertionError(); + } + }; + + ConnectionFactory.Detecting detectingAlwaysNeedMoreBytes = new ConnectionFactory.Detecting() + { + @Override + public Detection detect(ByteBuffer buffer) + { + return Detection.NEED_MORE_BYTES; + } + + @Override + public String getProtocol() + { + return "neverenough"; + } + + @Override + public List getProtocols() + { + throw new AssertionError(); + } + + @Override + public Connection newConnection(Connector connector, EndPoint endPoint) + { + throw new AssertionError(); + } + }; + + DetectorConnectionFactory detector = new DetectorConnectionFactory(detectingNeverRecognizes, detectingAlwaysNeedMoreBytes); + HttpConnectionFactory http = new HttpConnectionFactory(); + + start(detector, http); + + String request = "AAAA".repeat(32768); + String response = getResponse(request); + + assertThat(response, Matchers.nullValue()); + } +} diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java index 04cd2f680c9..6f2d0cf0402 100644 --- a/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/OptionalSslConnectionTest.java @@ -38,6 +38,7 @@ import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; public class OptionalSslConnectionTest { @@ -59,7 +60,7 @@ public class OptionalSslConnectionTest HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); OptionalSslConnectionFactory sslOrOther = configFn.apply(ssl); - connector = new ServerConnector(server, 1, 1, sslOrOther, ssl, http); + connector = new ServerConnector(server, 1, 1, sslOrOther, http); server.addConnector(connector); server.setHandler(handler); @@ -203,6 +204,47 @@ public class OptionalSslConnectionTest } } + @Test + void testNextProtocolIsNotNullButNotConfiguredEither() throws Exception + { + QueuedThreadPool serverThreads = new QueuedThreadPool(); + serverThreads.setName("server"); + server = new Server(serverThreads); + + String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath(); + SslContextFactory.Server sslContextFactory = new SslContextFactory.Server(); + sslContextFactory.setKeyStorePath(keystore); + sslContextFactory.setKeyStorePassword("storepwd"); + sslContextFactory.setKeyManagerPassword("keypwd"); + + HttpConfiguration httpConfig = new HttpConfiguration(); + HttpConnectionFactory http = new HttpConnectionFactory(httpConfig); + SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol()); + OptionalSslConnectionFactory optSsl = new OptionalSslConnectionFactory(ssl, "no-such-protocol"); + connector = new ServerConnector(server, 1, 1, optSsl, http); + server.addConnector(connector); + server.setHandler(new EmptyServerHandler()); + server.start(); + + try (Socket socket = new Socket(server.getURI().getHost(), server.getURI().getPort())) + { + OutputStream sslOutput = socket.getOutputStream(); + String request = + "GET / HTTP/1.1\r\n" + + "Host: localhost\r\n" + + "\r\n"; + byte[] requestBytes = request.getBytes(StandardCharsets.US_ASCII); + + sslOutput.write(requestBytes); + sslOutput.flush(); + + socket.setSoTimeout(5000); + InputStream sslInput = socket.getInputStream(); + HttpTester.Response response = HttpTester.parseResponse(sslInput); + assertNull(response); + } + } + private static class EmptyServerHandler extends AbstractHandler { @Override diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/ProxyConnectionTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/ProxyConnectionTest.java index 6c8f8cdf09f..c0ce3392953 100644 --- a/jetty-server/src/test/java/org/eclipse/jetty/server/ProxyConnectionTest.java +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/ProxyConnectionTest.java @@ -18,75 +18,94 @@ package org.eclipse.jetty.server; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import org.eclipse.jetty.http.HttpVersion; import org.eclipse.jetty.server.handler.ErrorHandler; import org.eclipse.jetty.toolchain.test.Net; +import org.eclipse.jetty.util.TypeUtil; import org.eclipse.jetty.util.log.StacklessLogging; import org.hamcrest.Matchers; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assumptions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertNull; public class ProxyConnectionTest { - private Server _server; - private LocalConnector _connector; - - @BeforeEach - public void init() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + public void testBadCRLF(RequestProcessor p) throws Exception { - _server = new Server(); - - HttpConnectionFactory http = new HttpConnectionFactory(); - http.getHttpConfiguration().setRequestHeaderSize(1024); - http.getHttpConfiguration().setResponseHeaderSize(1024); - - ProxyConnectionFactory proxy = new ProxyConnectionFactory(); - - _connector = new LocalConnector(_server, null, null, null, 1, proxy, http); - _connector.setIdleTimeout(1000); - _server.addConnector(_connector); - _server.setHandler(new DumpHandler()); - ErrorHandler eh = new ErrorHandler(); - eh.setServer(_server); - _server.addBean(eh); - _server.start(); - } - - @AfterEach - public void destroy() throws Exception - { - _server.stop(); - _server.join(); - } - - @Test - public void testSimple() throws Exception - { - String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r \n" + "GET /path HTTP/1.1\n" + "Host: server:80\n" + "Connection: close\n" + - "\n"); - - assertThat(response, Matchers.containsString("HTTP/1.1 200")); - assertThat(response, Matchers.containsString("pathInfo=/path")); - assertThat(response, Matchers.containsString("local=5.6.7.8:222")); - assertThat(response, Matchers.containsString("remote=1.2.3.4:111")); + "\n"; + String response = p.sendRequestWaitingForResponse(request); + assertNull(response); } - @Test - public void testIPv6() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + public void testBadChar(RequestProcessor p) throws Exception + { + String request = "PROXY\tTCP 1.2.3.4 5.6.7.8 111 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = p.sendRequestWaitingForResponse(request); + assertNull(response); + } + + @ParameterizedTest + @MethodSource("requestProcessors") + public void testBadPort(RequestProcessor p) throws Exception + { + try (StacklessLogging stackless = new StacklessLogging(ProxyConnectionFactory.class)) + { + String request = "PROXY TCP 1.2.3.4 5.6.7.8 9999999999999 222\r\n" + + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = p.sendRequestWaitingForResponse(request); + assertNull(response); + } + } + + @ParameterizedTest + @MethodSource("requestProcessors") + public void testHttp(RequestProcessor p) throws Exception + { + String request = + "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + String response = p.sendRequestWaitingForResponse(request); + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + } + + @ParameterizedTest + @MethodSource("requestProcessors") + public void testIPv6(RequestProcessor p) throws Exception { Assumptions.assumeTrue(Net.isIpv6InterfaceAvailable()); - String response = _connector.getResponse("PROXY TCP6 eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" + + String request = "PROXY TCP6 eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" + "GET /path HTTP/1.1\n" + "Host: server:80\n" + "Connection: close\n" + - "\n"); + "\n"; + + String response = p.sendRequestWaitingForResponse(request); assertThat(response, Matchers.containsString("HTTP/1.1 200")); assertThat(response, Matchers.containsString("pathInfo=/path")); @@ -94,83 +113,272 @@ public class ProxyConnectionTest assertThat(response, Matchers.containsString("local=ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff:65535")); } - @Test - public void testTooLong() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + public void testIPv6V2(RequestProcessor p) throws Exception { - String response = _connector.getResponse("PROXY TOOLONG!!! eeee:eeee:eeee:eeee:0000:0000:0000:0000 ffff:ffff:ffff:ffff:0000:0000:0000:0000 65535 65535\r\n" + + Assumptions.assumeTrue(Net.isIpv6InterfaceAvailable()); + + String proxy = + // Preamble + "0D0A0D0A000D0A515549540A" + + + // V2, PROXY + "21" + + + // 0x1 : AF_INET6 0x1 : STREAM. + "21" + + + // Address length is 2*16 + 2*2 = 36 bytes. + // length of remaining header (16+16+2+2 = 36) + "0024" + + + // uint8_t src_addr[16]; uint8_t dst_addr[16]; uint16_t src_port; uint16_t dst_port; + "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" + // ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff + "EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE" + // eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee + "3039" + // 12345 + "1F90"; // 8080 + String http = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + + String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII)); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + assertThat(response, Matchers.containsString("pathInfo=/path")); + assertThat(response, Matchers.containsString("local=eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee:8080")); + assertThat(response, Matchers.containsString("remote=ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff:12345")); + } + + @ParameterizedTest + @MethodSource("requestProcessors") + public void testMissingField(RequestProcessor p) throws Exception + { + String request = "PROXY TCP 1.2.3.4 5.6.7.8 222\r\n" + "GET /path HTTP/1.1\n" + "Host: server:80\n" + "Connection: close\n" + - "\n"); - + "\n"; + String response = p.sendRequestWaitingForResponse(request); assertNull(response); } - @Test - public void testNotComplete() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + public void testNotComplete(RequestProcessor p) throws Exception { - _connector.setIdleTimeout(100); - String response = _connector.getResponse("PROXY TIMEOUT"); + String response = p.customize(connector -> connector.setIdleTimeout(100)).sendRequestWaitingForResponse("PROXY TIMEOUT"); assertNull(response); } - @Test - public void testBadChar() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + public void testTooLong(RequestProcessor p) throws Exception { - String response = _connector.getResponse("PROXY\tTCP 1.2.3.4 5.6.7.8 111 222\r\n" + + String request = "PROXY TOOLONG!!! eeee:eeee:eeee:eeee:0000:0000:0000:0000 ffff:ffff:ffff:ffff:0000:0000:0000:0000 65535 65535\r\n" + "GET /path HTTP/1.1\n" + "Host: server:80\n" + "Connection: close\n" + - "\n"); + "\n"; + + String response = p.sendRequestWaitingForResponse(request); + assertNull(response); } - @Test - public void testBadCRLF() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + void testSimple(RequestProcessor p) throws Exception { - String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 111 222\r \n" + + String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" + "GET /path HTTP/1.1\n" + "Host: server:80\n" + "Connection: close\n" + - "\n"); - assertNull(response); + "\n"; + + String response = p.sendRequestWaitingForResponse(request); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + assertThat(response, Matchers.containsString("pathInfo=/path")); + assertThat(response, Matchers.containsString("local=5.6.7.8:222")); + assertThat(response, Matchers.containsString("remote=1.2.3.4:111")); } - @Test - public void testBadPort() throws Exception + @ParameterizedTest + @MethodSource("requestProcessors") + void testSimpleV2(RequestProcessor p) throws Exception { - try (StacklessLogging stackless = new StacklessLogging(ProxyConnectionFactory.class)) + String proxy = + // Preamble + "0D0A0D0A000D0A515549540A" + + + // V2, PROXY + "21" + + + // 0x1 : AF_INET 0x1 : STREAM. + "11" + + + // Address length is 2*4 + 2*2 = 12 bytes. + // length of remaining header (4+4+2+2 = 12) + "000C" + + + // uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; + "C0A80001" + // 192.168.0.1 + "7f000001" + // 127.0.0.1 + "3039" + // 12345 + "1F90"; // 8080 + String http = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + + String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII)); + + assertThat(response, Matchers.containsString("HTTP/1.1 200")); + assertThat(response, Matchers.containsString("pathInfo=/path")); + assertThat(response, Matchers.containsString("local=127.0.0.1:8080")); + assertThat(response, Matchers.containsString("remote=192.168.0.1:12345")); + } + + @ParameterizedTest + @MethodSource("requestProcessors") + void testMaxHeaderLengthV2(RequestProcessor p) throws Exception + { + p.customize((connector) -> { - String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 9999999999999 222\r\n" + - "GET /path HTTP/1.1\n" + - "Host: server:80\n" + - "Connection: close\n" + - "\n"); - assertNull(response); + ProxyConnectionFactory factory = (ProxyConnectionFactory)connector.getConnectionFactory("[proxy]"); + factory.setMaxProxyHeader(11); // just one byte short + }); + String proxy = + // Preamble + "0D0A0D0A000D0A515549540A" + + + // V2, PROXY + "21" + + + // 0x1 : AF_INET 0x1 : STREAM. + "11" + + + // Address length is 2*4 + 2*2 = 12 bytes. + // length of remaining header (4+4+2+2 = 12) + "000C" + + + // uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port; + "C0A80001" + + "7f000001" + + "3039" + + "1F90"; + String http = "GET /path HTTP/1.1\n" + + "Host: server:80\n" + + "Connection: close\n" + + "\n"; + + String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII)); + + assertThat(response, Matchers.is(Matchers.nullValue())); + } + + abstract static class RequestProcessor + { + protected LocalConnector _connector; + private Server _server; + + public RequestProcessor() + { + _server = new Server(); + HttpConnectionFactory http = new HttpConnectionFactory(); + http.getHttpConfiguration().setRequestHeaderSize(1024); + http.getHttpConfiguration().setResponseHeaderSize(1024); + ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString()); + + _connector = new LocalConnector(_server, null, null, null, 1, proxy, http); + _connector.setIdleTimeout(1000); + _server.addConnector(_connector); + _server.setHandler(new DumpHandler()); + ErrorHandler eh = new ErrorHandler(); + eh.setServer(_server); + _server.addBean(eh); + } + + public RequestProcessor customize(Consumer consumer) + { + consumer.accept(_connector); + return this; + } + + public final String sendRequestWaitingForResponse(String request) throws Exception + { + return sendRequestWaitingForResponse(request.getBytes(StandardCharsets.US_ASCII)); + } + + public final String sendRequestWaitingForResponse(byte[]... requests) throws Exception + { + try + { + _server.start(); + return process(requests); + } + finally + { + destroy(); + } + } + + protected abstract String process(byte[]... requests) throws Exception; + + private void destroy() throws Exception + { + _server.stop(); + _server.join(); } } - @Test - public void testMissingField() throws Exception + static Stream requestProcessors() { - String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 222\r\n" + - "GET /path HTTP/1.1\n" + - "Host: server:80\n" + - "Connection: close\n" + - "\n"); - assertNull(response); + return Stream.of( + Arguments.of(new RequestProcessor() + { + @Override + public String process(byte[]... requests) throws Exception + { + LocalConnector.LocalEndPoint endPoint = _connector.connect(); + for (byte[] request : requests) + { + endPoint.addInput(ByteBuffer.wrap(request)); + } + return endPoint.getResponse(); + } + + @Override + public String toString() + { + return "All bytes at once"; + } + }), + Arguments.of(new RequestProcessor() + { + @Override + public String process(byte[]... requests) throws Exception + { + LocalConnector.LocalEndPoint endPoint = _connector.connect(); + for (byte[] request : requests) + { + for (byte b : request) + { + endPoint.addInput(ByteBuffer.wrap(new byte[]{b})); + } + } + return endPoint.getResponse(); + } + + @Override + public String toString() + { + return "Byte by byte"; + } + }) + ); } - @Test - public void testHTTP() throws Exception - { - String response = _connector.getResponse( - "GET /path HTTP/1.1\n" + - "Host: server:80\n" + - "Connection: close\n" + - "\n"); - assertNull(response); - } } - -