Implement detector connection factory with protocol detection mechanism

Signed-off-by: Ludovic Orban <lorban@bitronix.be>
This commit is contained in:
Ludovic Orban 2020-01-23 16:46:15 +01:00
parent ecd0fe97f7
commit 75b4719592
9 changed files with 2065 additions and 689 deletions

View File

@ -20,6 +20,7 @@ package org.eclipse.jetty.server;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.eclipse.jetty.io.AbstractConnection;
@ -87,6 +88,21 @@ public abstract class AbstractConnectionFactory extends ContainerLifeCycle imple
_inputbufferSize = size;
}
protected String findNextProtocol(Connector connector)
{
String nextProtocol = null;
for (Iterator<String> it = connector.getProtocols().iterator(); it.hasNext(); )
{
String protocol = it.next();
if (getProtocol().equalsIgnoreCase(protocol))
{
nextProtocol = it.hasNext() ? it.next() : null;
break;
}
}
return nextProtocol;
}
protected AbstractConnection configure(AbstractConnection connection, Connector connector, EndPoint endPoint)
{
connection.setInputBufferSize(getInputBufferSize());

View File

@ -18,6 +18,7 @@
package org.eclipse.jetty.server;
import java.nio.ByteBuffer;
import java.util.List;
import org.eclipse.jetty.http.BadMessageException;
@ -85,4 +86,43 @@ public interface ConnectionFactory
*/
Connection upgradeConnection(Connector connector, EndPoint endPoint, MetaData.Request upgradeRequest, HttpFields responseFields) throws BadMessageException;
}
/**
* <p>Connections created by this factory MUST implement {@link Connection.UpgradeTo}.</p>
*/
interface Detecting extends ConnectionFactory
{
/**
* The possible outcomes of the {@link #detect(ByteBuffer)} method.
*/
enum Detection
{
/**
* A {@link Detecting} can work with the given bytes.
*/
RECOGNIZED,
/**
* A {@link Detecting} cannot work with the given bytes.
*/
NOT_RECOGNIZED,
/**
* A {@link Detecting} requires more bytes to make a decision.
*/
NEED_MORE_BYTES
}
/**
* <p>Check the bytes in the given {@code buffer} to figure out if this {@link Detecting} instance
* can work with them or not.</p>
* <p>The {@code buffer} MUST be left untouched by this method: bytes MUST NOT be consumed and MUST NOT be modified.</p>
* @param buffer the buffer.
* @return One of:
* <ul>
* <li>{@link Detection#RECOGNIZED} if this {@link Detecting} instance can work with the bytes in the buffer</li>
* <li>{@link Detection#NOT_RECOGNIZED} if this {@link Detecting} instance cannot work with the bytes in the buffer</li>
* <li>{@link Detection#NEED_MORE_BYTES} if this {@link Detecting} instance requires more bytes to make a decision</li>
* </ul>
*/
Detection detect(ByteBuffer buffer);
}
}

View File

@ -0,0 +1,294 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.server;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
/**
* A {@link ConnectionFactory} combining multiple {@link Detecting} instances that will upgrade to
* the first one recognizing the bytes in the buffer.
*/
public class DetectorConnectionFactory extends AbstractConnectionFactory implements ConnectionFactory.Detecting
{
private static final Logger LOG = Log.getLogger(DetectorConnectionFactory.class);
private final List<Detecting> _detectingConnectionFactories;
/**
* <p>When the first bytes are not recognized by the {@code detectingConnectionFactories}, the default behavior is to
* upgrade to the protocol returned by {@link #findNextProtocol(Connector)}.</p>
* @param detectingConnectionFactories the {@link Detecting} instances.
*/
public DetectorConnectionFactory(Detecting... detectingConnectionFactories)
{
super(toProtocolString(detectingConnectionFactories));
_detectingConnectionFactories = Arrays.asList(detectingConnectionFactories);
for (Detecting detectingConnectionFactory : detectingConnectionFactories)
{
addBean(detectingConnectionFactory);
}
}
private static String toProtocolString(Detecting... detectingConnectionFactories)
{
if (detectingConnectionFactories.length == 0)
throw new IllegalArgumentException("At least one detecting instance is required");
// remove protocol duplicates while keeping their ordering -> use LinkedHashSet
LinkedHashSet<String> protocols = Arrays.stream(detectingConnectionFactories).map(ConnectionFactory::getProtocol).collect(Collectors.toCollection(LinkedHashSet::new));
String protocol = String.join("|", protocols);
if (LOG.isDebugEnabled())
LOG.debug("Detector generated protocol name : {}", protocol);
return protocol;
}
/**
* Performs a detection using multiple {@link ConnectionFactory.Detecting} instances and returns the aggregated outcome.
* @param buffer the buffer to perform a detection against.
* @return A {@link Detecting.Detection} value with the detection outcome of the {@code detectingConnectionFactories}.
*/
@Override
public Detecting.Detection detect(ByteBuffer buffer)
{
if (LOG.isDebugEnabled())
LOG.debug("Attempting detection from buffer {} using {}", buffer, _detectingConnectionFactories);
boolean needMoreBytes = true;
for (Detecting detectingConnectionFactory : _detectingConnectionFactories)
{
Detecting.Detection detection = detectingConnectionFactory.detect(buffer);
if (detection == Detecting.Detection.RECOGNIZED)
{
if (LOG.isDebugEnabled())
LOG.debug("Detection recognized bytes from buffer {} using {}", buffer, _detectingConnectionFactories);
return Detecting.Detection.RECOGNIZED;
}
needMoreBytes &= detection == Detection.NEED_MORE_BYTES;
}
if (LOG.isDebugEnabled())
LOG.debug("Detection {} from buffer {} using {}", (needMoreBytes ? "requires more bytes" : "not recognized"), buffer, _detectingConnectionFactories);
return needMoreBytes ? Detection.NEED_MORE_BYTES : Detection.NOT_RECOGNIZED;
}
/**
* Utility method that performs an upgrade to the specified connection factory, disposing of the given resources when needed.
* @param connectionFactory the connection factory to upgrade to.
* @param connector the connector.
* @param endPoint the endpoint.
*/
protected static void upgradeToConnectionFactory(ConnectionFactory connectionFactory, Connector connector, EndPoint endPoint) throws IllegalStateException
{
if (LOG.isDebugEnabled())
LOG.debug("Upgrading to connection factory {}", connectionFactory);
if (connectionFactory == null)
throw new IllegalStateException("Cannot upgrade: connection factory must not be null for " + endPoint);
Connection nextConnection = connectionFactory.newConnection(connector, endPoint);
if (!(nextConnection instanceof Connection.UpgradeTo))
throw new IllegalStateException("Cannot upgrade: " + nextConnection + " does not implement " + Connection.UpgradeTo.class.getName() + " for " + endPoint);
endPoint.upgrade(nextConnection);
if (LOG.isDebugEnabled())
LOG.debug("Upgraded to connection factory {} and released buffer", connectionFactory);
}
/**
* <p>Callback method called when detection was unsuccessful.
* This implementation upgrades to the protocol returned by {@link #findNextProtocol(Connector)}.</p>
* @param connector the connector.
* @param endPoint the endpoint.
* @param buffer the buffer.
*/
protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer) throws IllegalStateException
{
String nextProtocol = findNextProtocol(connector);
if (LOG.isDebugEnabled())
LOG.debug("Detector {} detection unsuccessful, found '{}' as the next protocol to upgrade to", getProtocol(), nextProtocol);
if (nextProtocol == null)
throw new IllegalStateException("Cannot find protocol following '" + getProtocol() + "' in connector's protocol list " + connector.getProtocols() + " for " + endPoint);
upgradeToConnectionFactory(connector.getConnectionFactory(nextProtocol), connector, endPoint);
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
return configure(new DetectorConnection(endPoint, connector), connector, endPoint);
}
private class DetectorConnection extends AbstractConnection implements Connection.UpgradeFrom, Connection.UpgradeTo
{
private final Connector _connector;
private final ByteBuffer _buffer;
private DetectorConnection(EndPoint endp, Connector connector)
{
super(endp, connector.getExecutor());
_connector = connector;
_buffer = connector.getByteBufferPool().acquire(getInputBufferSize(), true);
}
@Override
public void onUpgradeTo(ByteBuffer prefilled)
{
if (LOG.isDebugEnabled())
LOG.debug("Detector {} copying prefilled buffer {}", getProtocol(), BufferUtil.toDetailString(prefilled));
if (BufferUtil.hasContent(prefilled))
BufferUtil.append(_buffer, prefilled);
}
@Override
public ByteBuffer onUpgradeFrom()
{
return _buffer;
}
@Override
public void onOpen()
{
super.onOpen();
if (!detectAndUpgrade())
fillInterested();
}
@Override
public void onFillable()
{
try
{
while (BufferUtil.space(_buffer) > 0)
{
// Read data
int fill = getEndPoint().fill(_buffer);
if (LOG.isDebugEnabled())
LOG.debug("Detector {} filled buffer with {} bytes", getProtocol(), fill);
if (fill < 0)
{
_connector.getByteBufferPool().release(_buffer);
getEndPoint().shutdownOutput();
return;
}
if (fill == 0)
{
fillInterested();
return;
}
if (detectAndUpgrade())
return;
}
// all Detecting instances want more bytes than this buffer can store
LOG.warn("Detector {} failed to detect upgrade target on {} for {}", getProtocol(), _detectingConnectionFactories, getEndPoint());
releaseAndClose();
}
catch (Throwable x)
{
LOG.warn("Detector {} error for {}", getProtocol(), getEndPoint(), x);
releaseAndClose();
}
}
/**
* @return true when upgrade was performed, false otherwise.
*/
private boolean detectAndUpgrade()
{
if (LOG.isDebugEnabled())
LOG.debug("Detector {} performing detection with {} bytes", getProtocol(), _buffer.remaining());
boolean notRecognized = true;
for (Detecting detectingConnectionFactory : _detectingConnectionFactories)
{
Detecting.Detection detection = detectingConnectionFactory.detect(_buffer);
if (LOG.isDebugEnabled())
LOG.debug("Detector {} performed detection from {} with {} which returned {}", getProtocol(), BufferUtil.toDetailString(_buffer), detectingConnectionFactory, detection);
if (detection == Detecting.Detection.RECOGNIZED)
{
try
{
// This DetectingConnectionFactory recognized those bytes -> upgrade to the next one.
Connection nextConnection = detectingConnectionFactory.newConnection(_connector, getEndPoint());
if (!(nextConnection instanceof UpgradeTo))
throw new IllegalStateException("Cannot upgrade: " + nextConnection + " does not implement " + UpgradeTo.class.getName());
getEndPoint().upgrade(nextConnection);
if (LOG.isDebugEnabled())
LOG.debug("Detector {} upgraded to {}", getProtocol(), nextConnection);
return true;
}
catch (DetectionFailureException e)
{
// It's just bubbling up from a nested Detector, so it's already handled, just rethrow it.
if (LOG.isDebugEnabled())
LOG.debug("Detector {} failed to upgrade, rethrowing", getProtocol(), e);
throw e;
}
catch (Exception e)
{
// Two reasons that can make us end up here:
// 1) detectingConnectionFactory.newConnection() failed? probably because it cannot find the next protocol
// 2) nextConnection is not instanceof UpgradeTo
// -> release the resources before rethrowing as DetectionFailureException
if (LOG.isDebugEnabled())
LOG.debug("Detector {} failed to upgrade", getProtocol());
releaseAndClose();
throw new DetectionFailureException(e);
}
}
notRecognized &= detection == Detecting.Detection.NOT_RECOGNIZED;
}
if (notRecognized)
{
// No DetectingConnectionFactory recognized those bytes -> call unsuccessful detection callback.
if (LOG.isDebugEnabled())
LOG.debug("Detector {} failed to detect a known protocol, falling back to nextProtocol()", getProtocol());
nextProtocol(_connector, getEndPoint(), _buffer);
if (LOG.isDebugEnabled())
LOG.debug("Detector {} call to nextProtocol() succeeded, assuming upgrade performed", getProtocol());
return true;
}
return false;
}
private void releaseAndClose()
{
if (LOG.isDebugEnabled())
LOG.debug("Detector {} releasing buffer and closing", getProtocol());
_connector.getByteBufferPool().release(_buffer);
close();
}
}
private static class DetectionFailureException extends RuntimeException
{
public DetectionFailureException(Throwable cause)
{
super(cause);
}
}
}

View File

@ -18,14 +18,10 @@
package org.eclipse.jetty.server;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
@ -34,62 +30,70 @@ import org.eclipse.jetty.util.log.Logger;
* <p>A ConnectionFactory whose connections detect whether the first bytes are
* TLS bytes and upgrades to either a TLS connection or to another configurable
* connection.</p>
*
* @deprecated Use {@link DetectorConnectionFactory} with a {@link SslConnectionFactory} instead.
*/
public class OptionalSslConnectionFactory extends AbstractConnectionFactory
@Deprecated
public class OptionalSslConnectionFactory extends DetectorConnectionFactory
{
private static final Logger LOG = Log.getLogger(OptionalSslConnection.class);
private static final int TLS_ALERT_FRAME_TYPE = 0x15;
private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16;
private static final int TLS_MAJOR_VERSION = 3;
private final SslConnectionFactory sslConnectionFactory;
private final String otherProtocol;
private static final Logger LOG = Log.getLogger(OptionalSslConnectionFactory.class);
private final String _nextProtocol;
/**
* <p>Creates a new ConnectionFactory whose connections can upgrade to TLS or another protocol.</p>
* <p>If {@code otherProtocol} is {@code null}, and the first bytes are not TLS, then
* {@link #otherProtocol(ByteBuffer, EndPoint)} is called.</p>
*
* @param sslConnectionFactory The SslConnectionFactory to use if the first bytes are TLS
* @param otherProtocol the protocol of the ConnectionFactory to use if the first bytes are not TLS,
* @param sslConnectionFactory The {@link SslConnectionFactory} to use if the first bytes are TLS
* @param nextProtocol the protocol of the {@link ConnectionFactory} to use if the first bytes are not TLS,
* or null to explicitly handle the non-TLS case
*/
public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String otherProtocol)
public OptionalSslConnectionFactory(SslConnectionFactory sslConnectionFactory, String nextProtocol)
{
super("ssl|other");
this.sslConnectionFactory = sslConnectionFactory;
this.otherProtocol = otherProtocol;
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
return configure(new OptionalSslConnection(endPoint, connector), connector, endPoint);
super(sslConnectionFactory);
_nextProtocol = nextProtocol;
}
/**
* @param buffer The buffer with the first bytes of the connection
* @return whether the bytes seem TLS bytes
*/
protected boolean seemsTLS(ByteBuffer buffer)
{
int tlsFrameType = buffer.get(0) & 0xFF;
int tlsMajorVersion = buffer.get(1) & 0xFF;
return (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION;
}
/**
* <p>Callback method invoked when {@code otherProtocol} is {@code null}
* and the first bytes are not TLS.</p>
* <p>Callback method invoked when the detected bytes are not TLS.</p>
* <p>This typically happens when a client is trying to connect to a TLS
* port using the {@code http} scheme (and not the {@code https} scheme).</p>
*
* @param connector The connector object
* @param endPoint The connection EndPoint object
* @param buffer The buffer with the first bytes of the connection
*/
protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer)
{
if (LOG.isDebugEnabled())
LOG.debug("OptionalSSL TLS detection unsuccessful, attempting to upgrade to {}", _nextProtocol);
if (_nextProtocol != null)
{
ConnectionFactory connectionFactory = connector.getConnectionFactory(_nextProtocol);
if (connectionFactory == null)
throw new IllegalStateException("Cannot find protocol '" + _nextProtocol + "' in connector's protocol list " + connector.getProtocols() + " for " + endPoint);
upgradeToConnectionFactory(connectionFactory, connector, endPoint);
}
else
{
otherProtocol(buffer, endPoint);
}
}
/**
* <p>Legacy callback method invoked when {@code nextProtocol} is {@code null}
* and the first bytes are not TLS.</p>
* <p>This typically happens when a client is trying to connect to a TLS
* port using the {@code http} scheme (and not the {@code https} scheme).</p>
* <p>This method is kept around for backward compatibility.</p>
*
* @param buffer The buffer with the first bytes of the connection
* @param endPoint The connection EndPoint object
* @see #seemsTLS(ByteBuffer)
* @deprecated Override {@link #nextProtocol(Connector, EndPoint, ByteBuffer)} instead.
*/
@Deprecated
protected void otherProtocol(ByteBuffer buffer, EndPoint endPoint)
{
LOG.warn("Detected non-TLS bytes, but no other protocol to upgrade to for {}", endPoint);
// There are always at least 2 bytes.
int byte1 = buffer.get(0) & 0xFF;
int byte2 = buffer.get(1) & 0xFF;
@ -122,105 +126,4 @@ public class OptionalSslConnectionFactory extends AbstractConnectionFactory
endPoint.close();
}
}
private class OptionalSslConnection extends AbstractConnection implements Connection.UpgradeFrom
{
private final Connector connector;
private final ByteBuffer buffer;
public OptionalSslConnection(EndPoint endPoint, Connector connector)
{
super(endPoint, connector.getExecutor());
this.connector = connector;
this.buffer = BufferUtil.allocateDirect(1536);
}
@Override
public void onOpen()
{
super.onOpen();
fillInterested();
}
@Override
public void onFillable()
{
try
{
while (true)
{
int filled = getEndPoint().fill(buffer);
if (filled > 0)
{
// Always have at least 2 bytes.
if (BufferUtil.length(buffer) >= 2)
{
upgrade(buffer);
break;
}
}
else if (filled == 0)
{
fillInterested();
break;
}
else
{
close();
break;
}
}
}
catch (IOException x)
{
LOG.warn(x);
close();
}
}
@Override
public ByteBuffer onUpgradeFrom()
{
return buffer;
}
private void upgrade(ByteBuffer buffer)
{
if (LOG.isDebugEnabled())
LOG.debug("Read {}", BufferUtil.toDetailString(buffer));
EndPoint endPoint = getEndPoint();
if (seemsTLS(buffer))
{
if (LOG.isDebugEnabled())
LOG.debug("Detected TLS bytes, upgrading to {}", sslConnectionFactory);
endPoint.upgrade(sslConnectionFactory.newConnection(connector, endPoint));
}
else
{
if (otherProtocol != null)
{
ConnectionFactory connectionFactory = connector.getConnectionFactory(otherProtocol);
if (connectionFactory != null)
{
if (LOG.isDebugEnabled())
LOG.debug("Detected non-TLS bytes, upgrading to {}", connectionFactory);
Connection next = connectionFactory.newConnection(connector, endPoint);
endPoint.upgrade(next);
}
else
{
LOG.warn("Missing {} {} in {}", otherProtocol, ConnectionFactory.class.getSimpleName(), connector);
close();
}
}
else
{
if (LOG.isDebugEnabled())
LOG.debug("Detected non-TLS bytes, but no other protocol to upgrade to");
otherProtocol(buffer, endPoint);
}
}
}
}
}

View File

@ -18,6 +18,7 @@
package org.eclipse.jetty.server;
import java.nio.ByteBuffer;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
@ -31,8 +32,12 @@ import org.eclipse.jetty.util.annotation.Name;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.ssl.SslContextFactory;
public class SslConnectionFactory extends AbstractConnectionFactory
public class SslConnectionFactory extends AbstractConnectionFactory implements ConnectionFactory.Detecting
{
private static final int TLS_ALERT_FRAME_TYPE = 0x15;
private static final int TLS_HANDSHAKE_FRAME_TYPE = 0x16;
private static final int TLS_MAJOR_VERSION = 3;
private final SslContextFactory _sslContextFactory;
private final String _nextProtocol;
private boolean _directBuffersForEncryption = false;
@ -99,6 +104,17 @@ public class SslConnectionFactory extends AbstractConnectionFactory
setInputBufferSize(session.getPacketBufferSize());
}
@Override
public Detection detect(ByteBuffer buffer)
{
if (buffer.remaining() < 2)
return Detection.NEED_MORE_BYTES;
int tlsFrameType = buffer.get(0) & 0xFF;
int tlsMajorVersion = buffer.get(1) & 0xFF;
boolean seemsSsl = (tlsFrameType == TLS_HANDSHAKE_FRAME_TYPE || tlsFrameType == TLS_ALERT_FRAME_TYPE) && tlsMajorVersion == TLS_MAJOR_VERSION;
return seemsSsl ? Detection.RECOGNIZED : Detection.NOT_RECOGNIZED;
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{

View File

@ -0,0 +1,715 @@
//
// ========================================================================
// Copyright (c) 1995-2020 Mort Bay Consulting Pty Ltd and others.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.server;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.net.ssl.SSLSocketFactory;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.toolchain.test.MavenTestingUtils;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
public class DetectorConnectionTest
{
private Server _server;
private static String inputStreamToString(InputStream is) throws IOException
{
StringBuilder sb = new StringBuilder();
BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.US_ASCII));
while (true)
{
String line = reader.readLine();
if (line == null)
{
// remove the last '\n'
if (sb.length() != 0)
sb.deleteCharAt(sb.length() - 1);
break;
}
sb.append(line).append('\n');
}
return sb.length() == 0 ? null : sb.toString();
}
private String getResponse(String request) throws Exception
{
return getResponse(request.getBytes(StandardCharsets.US_ASCII));
}
private String getResponse(byte[]... requests) throws Exception
{
try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort()))
{
for (byte[] request : requests)
{
socket.getOutputStream().write(request);
}
return inputStreamToString(socket.getInputStream());
}
}
private String getResponseOverSsl(String request) throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
sslContextFactory.start();
SSLSocketFactory socketFactory = sslContextFactory.getSslContext().getSocketFactory();
try (Socket socket = socketFactory.createSocket(_server.getURI().getHost(), _server.getURI().getPort()))
{
socket.getOutputStream().write(request.getBytes(StandardCharsets.US_ASCII));
return inputStreamToString(socket.getInputStream());
}
finally
{
sslContextFactory.stop();
}
}
private void start(ConnectionFactory... connectionFactories) throws Exception
{
_server = new Server();
_server.addConnector(new ServerConnector(_server, 1, 1, connectionFactories));
_server.setHandler(new DumpHandler());
_server.start();
}
@AfterEach
public void destroy() throws Exception
{
if (_server != null)
_server.stop();
}
@Test
public void testConnectionClosedDuringDetection() throws Exception
{
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, http);
try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort()))
{
socket.getOutputStream().write("PR".getBytes(StandardCharsets.US_ASCII));
Thread.sleep(100); // make sure the onFillable callback gets called
socket.getOutputStream().write("OX".getBytes(StandardCharsets.US_ASCII));
socket.getOutputStream().close();
assertThrows(SocketException.class, () -> socket.getInputStream().read());
}
}
@Test
public void testConnectionClosedDuringProxyV1Handling() throws Exception
{
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, http);
try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort()))
{
socket.getOutputStream().write("PROXY".getBytes(StandardCharsets.US_ASCII));
Thread.sleep(100); // make sure the onFillable callback gets called
socket.getOutputStream().write(" ".getBytes(StandardCharsets.US_ASCII));
socket.getOutputStream().close();
assertThrows(SocketException.class, () -> socket.getInputStream().read());
}
}
@Test
public void testConnectionClosedDuringProxyV2HandlingFixedLengthPart() throws Exception
{
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, http);
try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort()))
{
socket.getOutputStream().write(TypeUtil.fromHexString("0D0A0D0A000D0A515549540A")); // proxy V2 Preamble
Thread.sleep(100); // make sure the onFillable callback gets called
socket.getOutputStream().write(TypeUtil.fromHexString("21")); // V2, PROXY
socket.getOutputStream().close();
assertThrows(SocketException.class, () -> socket.getInputStream().read());
}
}
@Test
public void testConnectionClosedDuringProxyV2HandlingDynamicLengthPart() throws Exception
{
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, http);
try (Socket socket = new Socket(_server.getURI().getHost(), _server.getURI().getPort()))
{
socket.getOutputStream().write(TypeUtil.fromHexString(
// proxy V2 Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET 0x1 : STREAM.
"11" +
// Address length is 2*4 + 2*2 = 12 bytes.
// length of remaining header (4+4+2+2 = 12)
"000C"
));
Thread.sleep(100); // make sure the onFillable callback gets called
socket.getOutputStream().write(TypeUtil.fromHexString(
// uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port;
"C0A80001" // 8080
));
socket.getOutputStream().close();
assertThrows(SocketException.class, () -> socket.getInputStream().read());
}
}
@Test
public void testDetectingSslProxyToHttpNoSslWithProxy() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy);
start(detector, http);
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=5.6.7.8:222"));
assertThat(response, Matchers.containsString("remote=1.2.3.4:111"));
}
@Test
public void testDetectingSslProxyToHttpWithSslNoProxy() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy);
start(detector, http);
String request = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponseOverSsl(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
}
@Test
public void testDetectingSslProxyToHttpWithSslWithProxy() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy);
start(detector, http);
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponseOverSsl(request);
// SSL matched, so the upgrade was made to HTTP which does not understand the proxy request
assertThat(response, Matchers.containsString("HTTP/1.1 400"));
}
@Test
public void testDetectionUnsuccessfulUpgradesToNextProtocol() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy);
start(detector, http);
String request = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
}
@Test
void testDetectorToNextDetector() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory proxyDetector = new DetectorConnectionFactory(proxy);
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, proxyDetector.getProtocol());
DetectorConnectionFactory sslDetector = new DetectorConnectionFactory(ssl);
start(sslDetector, proxyDetector, http);
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponseOverSsl(request);
// SSL matched, so the upgrade was made to proxy which itself upgraded to HTTP
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=5.6.7.8:222"));
assertThat(response, Matchers.containsString("remote=1.2.3.4:111"));
}
@Test
void testDetectorWithDetectionUnsuccessful() throws Exception
{
AtomicBoolean detectionSuccessful = new AtomicBoolean(true);
ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy)
{
@Override
protected void nextProtocol(Connector connector, EndPoint endPoint, ByteBuffer buffer)
{
if (!detectionSuccessful.compareAndSet(true, false))
throw new AssertionError("DetectionUnsuccessful callback should only have been called once");
// omitting this will leak the buffer
connector.getByteBufferPool().release(buffer);
Callback.Completable completable = new Callback.Completable();
endPoint.write(completable, ByteBuffer.wrap("No upgrade for you".getBytes(StandardCharsets.US_ASCII)));
completable.whenComplete((r, x) -> endPoint.close());
}
};
HttpConnectionFactory http = new HttpConnectionFactory();
start(detector, http);
String request = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
assertEquals("No upgrade for you", response);
assertFalse(detectionSuccessful.get());
}
@Test
void testDetectorWithProxyThatHasNoNextProto() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory();
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl, proxy);
start(detector, http);
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
// ProxyConnectionFactory has no next protocol -> it cannot upgrade
assertThat(response, Matchers.nullValue());
}
@Test
void testOptionalSsl() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConnectionFactory http = new HttpConnectionFactory();
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl);
start(detector, http);
String request =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String clearTextResponse = getResponse(request);
String sslResponse = getResponseOverSsl(request);
// both clear text and SSL can be responded to just fine
assertThat(clearTextResponse, Matchers.containsString("HTTP/1.1 200"));
assertThat(sslResponse, Matchers.containsString("HTTP/1.1 200"));
}
@Test
void testDetectorThatHasNoConfiguredNextProto() throws Exception
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, HttpVersion.HTTP_1_1.asString());
DetectorConnectionFactory detector = new DetectorConnectionFactory(ssl);
start(detector);
String request =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
assertThat(response, Matchers.nullValue());
}
@Test
void testDetectorWithNextProtocolThatDoesNotExist() throws Exception
{
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory("does-not-exist");
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, http);
String proxyReq =
// proxy V2 Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET 0x1 : STREAM.
"11" +
// Address length is 2*4 + 2*2 = 12 bytes.
// length of remaining header (4+4+2+2 = 12)
"000C" +
// uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port;
"C0A80001" + // 192.168.0.1
"7f000001" + // 127.0.0.1
"3039" + // 12345
"1F90"; // 8080
String httpReq =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(TypeUtil.fromHexString(proxyReq), httpReq.getBytes(StandardCharsets.US_ASCII));
assertThat(response, Matchers.nullValue());
}
@Test
void testDetectingWithNextProtocolThatDoesNotImplementUpgradeTo() throws Exception
{
ConnectionFactory.Detecting noUpgradeTo = new ConnectionFactory.Detecting()
{
@Override
public Detection detect(ByteBuffer buffer)
{
return Detection.RECOGNIZED;
}
@Override
public String getProtocol()
{
return "noUpgradeTo";
}
@Override
public List<String> getProtocols()
{
return Collections.singletonList(getProtocol());
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
return new AbstractConnection(null, connector.getExecutor())
{
@Override
public void onFillable()
{
}
};
}
};
HttpConnectionFactory http = new HttpConnectionFactory();
DetectorConnectionFactory detector = new DetectorConnectionFactory(noUpgradeTo);
start(detector, http);
String proxyReq =
// proxy V2 Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET 0x1 : STREAM.
"11" +
// Address length is 2*4 + 2*2 = 12 bytes.
// length of remaining header (4+4+2+2 = 12)
"000C" +
// uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port;
"C0A80001" + // 192.168.0.1
"7f000001" + // 127.0.0.1
"3039" + // 12345
"1F90"; // 8080
String httpReq =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(TypeUtil.fromHexString(proxyReq), httpReq.getBytes(StandardCharsets.US_ASCII));
assertThat(response, Matchers.nullValue());
}
@Test
void testDetectorWithNextProtocolThatDoesNotImplementUpgradeTo() throws Exception
{
ConnectionFactory noUpgradeTo = new ConnectionFactory()
{
@Override
public String getProtocol()
{
return "noUpgradeTo";
}
@Override
public List<String> getProtocols()
{
return Collections.singletonList(getProtocol());
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
return new AbstractConnection(null, connector.getExecutor())
{
@Override
public void onFillable()
{
}
};
}
};
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
DetectorConnectionFactory detector = new DetectorConnectionFactory(proxy);
start(detector, noUpgradeTo);
String request =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = getResponse(request);
assertThat(response, Matchers.nullValue());
}
@Test
void testGeneratedProtocolNames()
{
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString());
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, HttpVersion.HTTP_1_1.asString());
assertEquals("SSL|proxy", new DetectorConnectionFactory(ssl, proxy).getProtocol());
assertEquals("proxy|SSL", new DetectorConnectionFactory(proxy, ssl).getProtocol());
}
@Test
void testDetectorWithNoDetectingFails()
{
assertThrows(IllegalArgumentException.class, DetectorConnectionFactory::new);
}
@Test
void testExerciseDetectorNotEnoughBytes() throws Exception
{
ConnectionFactory.Detecting detectingNeverRecognizes = new ConnectionFactory.Detecting()
{
@Override
public Detection detect(ByteBuffer buffer)
{
return Detection.NOT_RECOGNIZED;
}
@Override
public String getProtocol()
{
return "nevergood";
}
@Override
public List<String> getProtocols()
{
throw new AssertionError();
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
throw new AssertionError();
}
};
ConnectionFactory.Detecting detectingAlwaysNeedMoreBytes = new ConnectionFactory.Detecting()
{
@Override
public Detection detect(ByteBuffer buffer)
{
return Detection.NEED_MORE_BYTES;
}
@Override
public String getProtocol()
{
return "neverenough";
}
@Override
public List<String> getProtocols()
{
throw new AssertionError();
}
@Override
public Connection newConnection(Connector connector, EndPoint endPoint)
{
throw new AssertionError();
}
};
DetectorConnectionFactory detector = new DetectorConnectionFactory(detectingNeverRecognizes, detectingAlwaysNeedMoreBytes);
HttpConnectionFactory http = new HttpConnectionFactory();
start(detector, http);
StringBuilder sb = new StringBuilder();
for (int i = 0; i < 32768; i++)
{
sb.append("AAAA");
}
String request = sb.toString();
String response = getResponse(request);
assertThat(response, Matchers.nullValue());
}
}

View File

@ -38,6 +38,7 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
public class OptionalSslConnectionTest
{
@ -60,7 +61,7 @@ public class OptionalSslConnectionTest
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
OptionalSslConnectionFactory sslOrOther = configFn.apply(ssl);
connector = new ServerConnector(server, 1, 1, sslOrOther, ssl, http);
connector = new ServerConnector(server, 1, 1, sslOrOther, http);
server.addConnector(connector);
server.setHandler(handler);
@ -204,6 +205,47 @@ public class OptionalSslConnectionTest
}
}
@Test
void testNextProtocolIsNotNullButNotConfiguredEither() throws Exception
{
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
String keystore = MavenTestingUtils.getTestResourceFile("keystore").getAbsolutePath();
SslContextFactory sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setKeyStorePath(keystore);
sslContextFactory.setKeyStorePassword("storepwd");
sslContextFactory.setKeyManagerPassword("keypwd");
HttpConfiguration httpConfig = new HttpConfiguration();
HttpConnectionFactory http = new HttpConnectionFactory(httpConfig);
SslConnectionFactory ssl = new SslConnectionFactory(sslContextFactory, http.getProtocol());
OptionalSslConnectionFactory optSsl = new OptionalSslConnectionFactory(ssl, "no-such-protocol");
connector = new ServerConnector(server, 1, 1, optSsl, http);
server.addConnector(connector);
server.setHandler(new EmptyServerHandler());
server.start();
try (Socket socket = new Socket(server.getURI().getHost(), server.getURI().getPort()))
{
OutputStream sslOutput = socket.getOutputStream();
String request =
"GET / HTTP/1.1\r\n" +
"Host: localhost\r\n" +
"\r\n";
byte[] requestBytes = request.getBytes(StandardCharsets.US_ASCII);
sslOutput.write(requestBytes);
sslOutput.flush();
socket.setSoTimeout(5000);
InputStream sslInput = socket.getInputStream();
HttpTester.Response response = HttpTester.parseResponse(sslInput);
assertNull(response);
}
}
private static class EmptyServerHandler extends AbstractHandler
{
@Override

View File

@ -18,75 +18,94 @@
package org.eclipse.jetty.server;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.function.Consumer;
import java.util.stream.Stream;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.server.handler.ErrorHandler;
import org.eclipse.jetty.toolchain.test.Net;
import org.eclipse.jetty.util.TypeUtil;
import org.eclipse.jetty.util.log.StacklessLogging;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNull;
public class ProxyConnectionTest
{
private Server _server;
private LocalConnector _connector;
@BeforeEach
public void init() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
public void testBadCRLF(RequestProcessor p) throws Exception
{
_server = new Server();
HttpConnectionFactory http = new HttpConnectionFactory();
http.getHttpConfiguration().setRequestHeaderSize(1024);
http.getHttpConfiguration().setResponseHeaderSize(1024);
ProxyConnectionFactory proxy = new ProxyConnectionFactory();
_connector = new LocalConnector(_server, null, null, null, 1, proxy, http);
_connector.setIdleTimeout(1000);
_server.addConnector(_connector);
_server.setHandler(new DumpHandler());
ErrorHandler eh = new ErrorHandler();
eh.setServer(_server);
_server.addBean(eh);
_server.start();
}
@AfterEach
public void destroy() throws Exception
{
_server.stop();
_server.join();
}
@Test
public void testSimple() throws Exception
{
String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r \n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=5.6.7.8:222"));
assertThat(response, Matchers.containsString("remote=1.2.3.4:111"));
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertNull(response);
}
@Test
public void testIPv6() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
public void testBadChar(RequestProcessor p) throws Exception
{
String request = "PROXY\tTCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertNull(response);
}
@ParameterizedTest
@MethodSource("requestProcessors")
public void testBadPort(RequestProcessor p) throws Exception
{
try (StacklessLogging stackless = new StacklessLogging(ProxyConnectionFactory.class))
{
String request = "PROXY TCP 1.2.3.4 5.6.7.8 9999999999999 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertNull(response);
}
}
@ParameterizedTest
@MethodSource("requestProcessors")
public void testHttp(RequestProcessor p) throws Exception
{
String request =
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
}
@ParameterizedTest
@MethodSource("requestProcessors")
public void testIPv6(RequestProcessor p) throws Exception
{
Assumptions.assumeTrue(Net.isIpv6InterfaceAvailable());
String response = _connector.getResponse("PROXY TCP6 eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" +
String request = "PROXY TCP6 eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
@ -94,83 +113,272 @@ public class ProxyConnectionTest
assertThat(response, Matchers.containsString("local=ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff:65535"));
}
@Test
public void testTooLong() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
public void testIPv6V2(RequestProcessor p) throws Exception
{
String response = _connector.getResponse("PROXY TOOLONG!!! eeee:eeee:eeee:eeee:0000:0000:0000:0000 ffff:ffff:ffff:ffff:0000:0000:0000:0000 65535 65535\r\n" +
Assumptions.assumeTrue(Net.isIpv6InterfaceAvailable());
String proxy =
// Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET6 0x1 : STREAM.
"21" +
// Address length is 2*16 + 2*2 = 36 bytes.
// length of remaining header (16+16+2+2 = 36)
"0024" +
// uint8_t src_addr[16]; uint8_t dst_addr[16]; uint16_t src_port; uint16_t dst_port;
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF" + // ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff
"EEEEEEEEEEEEEEEEEEEEEEEEEEEEEEEE" + // eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee
"3039" + // 12345
"1F90"; // 8080
String http = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII));
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee:8080"));
assertThat(response, Matchers.containsString("remote=ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff:12345"));
}
@ParameterizedTest
@MethodSource("requestProcessors")
public void testMissingField(RequestProcessor p) throws Exception
{
String request = "PROXY TCP 1.2.3.4 5.6.7.8 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertNull(response);
}
@Test
public void testNotComplete() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
public void testNotComplete(RequestProcessor p) throws Exception
{
_connector.setIdleTimeout(100);
String response = _connector.getResponse("PROXY TIMEOUT");
String response = p.customize(connector -> connector.setIdleTimeout(100)).sendRequestWaitingForResponse("PROXY TIMEOUT");
assertNull(response);
}
@Test
public void testBadChar() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
public void testTooLong(RequestProcessor p) throws Exception
{
String response = _connector.getResponse("PROXY\tTCP 1.2.3.4 5.6.7.8 111 222\r\n" +
String request = "PROXY TOOLONG!!! eeee:eeee:eeee:eeee:0000:0000:0000:0000 ffff:ffff:ffff:ffff:0000:0000:0000:0000 65535 65535\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertNull(response);
}
@Test
public void testBadCRLF() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
void testSimple(RequestProcessor p) throws Exception
{
String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 111 222\r \n" +
String request = "PROXY TCP 1.2.3.4 5.6.7.8 111 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
assertNull(response);
"\n";
String response = p.sendRequestWaitingForResponse(request);
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=5.6.7.8:222"));
assertThat(response, Matchers.containsString("remote=1.2.3.4:111"));
}
@Test
public void testBadPort() throws Exception
@ParameterizedTest
@MethodSource("requestProcessors")
void testSimpleV2(RequestProcessor p) throws Exception
{
try (StacklessLogging stackless = new StacklessLogging(ProxyConnectionFactory.class))
String proxy =
// Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET 0x1 : STREAM.
"11" +
// Address length is 2*4 + 2*2 = 12 bytes.
// length of remaining header (4+4+2+2 = 12)
"000C" +
// uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port;
"C0A80001" + // 192.168.0.1
"7f000001" + // 127.0.0.1
"3039" + // 12345
"1F90"; // 8080
String http = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII));
assertThat(response, Matchers.containsString("HTTP/1.1 200"));
assertThat(response, Matchers.containsString("pathInfo=/path"));
assertThat(response, Matchers.containsString("local=127.0.0.1:8080"));
assertThat(response, Matchers.containsString("remote=192.168.0.1:12345"));
}
@ParameterizedTest
@MethodSource("requestProcessors")
void testMaxHeaderLengthV2(RequestProcessor p) throws Exception
{
p.customize((connector) ->
{
String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 9999999999999 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
assertNull(response);
ProxyConnectionFactory factory = (ProxyConnectionFactory)connector.getConnectionFactory("proxy");
factory.setMaxProxyHeader(11); // just one byte short
});
String proxy =
// Preamble
"0D0A0D0A000D0A515549540A" +
// V2, PROXY
"21" +
// 0x1 : AF_INET 0x1 : STREAM.
"11" +
// Address length is 2*4 + 2*2 = 12 bytes.
// length of remaining header (4+4+2+2 = 12)
"000C" +
// uint32_t src_addr; uint32_t dst_addr; uint16_t src_port; uint16_t dst_port;
"C0A80001" +
"7f000001" +
"3039" +
"1F90";
String http = "GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n";
String response = p.sendRequestWaitingForResponse(TypeUtil.fromHexString(proxy), http.getBytes(StandardCharsets.US_ASCII));
assertThat(response, Matchers.is(Matchers.nullValue()));
}
abstract static class RequestProcessor
{
protected LocalConnector _connector;
private Server _server;
public RequestProcessor()
{
_server = new Server();
HttpConnectionFactory http = new HttpConnectionFactory();
http.getHttpConfiguration().setRequestHeaderSize(1024);
http.getHttpConfiguration().setResponseHeaderSize(1024);
ProxyConnectionFactory proxy = new ProxyConnectionFactory(HttpVersion.HTTP_1_1.asString());
_connector = new LocalConnector(_server, null, null, null, 1, proxy, http);
_connector.setIdleTimeout(1000);
_server.addConnector(_connector);
_server.setHandler(new DumpHandler());
ErrorHandler eh = new ErrorHandler();
eh.setServer(_server);
_server.addBean(eh);
}
public RequestProcessor customize(Consumer<LocalConnector> consumer)
{
consumer.accept(_connector);
return this;
}
public final String sendRequestWaitingForResponse(String request) throws Exception
{
return sendRequestWaitingForResponse(request.getBytes(StandardCharsets.US_ASCII));
}
public final String sendRequestWaitingForResponse(byte[]... requests) throws Exception
{
try
{
_server.start();
return process(requests);
}
finally
{
destroy();
}
}
protected abstract String process(byte[]... requests) throws Exception;
private void destroy() throws Exception
{
_server.stop();
_server.join();
}
}
@Test
public void testMissingField() throws Exception
static Stream<Arguments> requestProcessors()
{
String response = _connector.getResponse("PROXY TCP 1.2.3.4 5.6.7.8 222\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
assertNull(response);
return Stream.of(
Arguments.of(new RequestProcessor()
{
@Override
public String process(byte[]... requests) throws Exception
{
LocalConnector.LocalEndPoint endPoint = _connector.connect();
for (byte[] request : requests)
{
endPoint.addInput(ByteBuffer.wrap(request));
}
return endPoint.getResponse();
}
@Override
public String toString()
{
return "All bytes at once";
}
}),
Arguments.of(new RequestProcessor()
{
@Override
public String process(byte[]... requests) throws Exception
{
LocalConnector.LocalEndPoint endPoint = _connector.connect();
for (byte[] request : requests)
{
for (byte b : request)
{
endPoint.addInput(ByteBuffer.wrap(new byte[]{b}));
}
}
return endPoint.getResponse();
}
@Override
public String toString()
{
return "Byte by byte";
}
})
);
}
@Test
public void testHTTP() throws Exception
{
String response = _connector.getResponse(
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +
"\n");
assertNull(response);
}
}