Fixes #4421 - HttpClient support for PROXY protocol. (#4424)

* Fixes #4421 - HttpClient support for PROXY protocol.

Implemented support for the PROXY protocol in HttpClient.

Introduced Request.tag(Object) to tag requests that belong
to the same group (e.g. a client address) so that they can
generate a different destination.

The tag object may implement ClientConnectionFactory.Decorator
so that it can decorate the HttpDestination ClientConnectionFactory
and therefore work both with and without forward proxy configuration.

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Simone Bordet 2019-12-17 10:36:16 +01:00 committed by GitHub
parent 584e264b0b
commit 129a51c7a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1052 additions and 31 deletions

View File

@ -530,16 +530,29 @@ public class HttpClient extends ContainerLifeCycle
}
protected HttpDestination destinationFor(String scheme, String host, int port)
{
return resolveDestination(scheme, host, port, null);
}
protected HttpDestination resolveDestination(String scheme, String host, int port, Object tag)
{
Origin origin = createOrigin(scheme, host, port, tag);
return resolveDestination(origin);
}
protected Origin createOrigin(String scheme, String host, int port, Object tag)
{
if (!HttpScheme.HTTP.is(scheme) && !HttpScheme.HTTPS.is(scheme) &&
!HttpScheme.WS.is(scheme) && !HttpScheme.WSS.is(scheme))
throw new IllegalArgumentException("Invalid protocol " + scheme);
scheme = scheme.toLowerCase(Locale.ENGLISH);
host = host.toLowerCase(Locale.ENGLISH);
port = normalizePort(scheme, port);
return new Origin(scheme, host, port, tag);
}
Origin origin = new Origin(scheme, host, port);
protected HttpDestination resolveDestination(Origin origin)
{
return destinations.computeIfAbsent(origin, o ->
{
HttpDestination newDestination = getTransport().newHttpDestination(o);
@ -566,7 +579,7 @@ public class HttpClient extends ContainerLifeCycle
protected void send(final HttpRequest request, List<Response.ResponseListener> listeners)
{
HttpDestination destination = destinationFor(request.getScheme(), request.getHost(), request.getPort());
HttpDestination destination = resolveDestination(request.getScheme(), request.getHost(), request.getPort(), request.getTag());
destination.send(request, listeners);
}

View File

@ -94,6 +94,9 @@ public abstract class HttpDestination extends ContainerLifeCycle implements Dest
if (isSecure())
connectionFactory = newSslClientConnectionFactory(null, connectionFactory);
}
Object tag = origin.getTag();
if (tag instanceof ClientConnectionFactory.Decorator)
connectionFactory = ((ClientConnectionFactory.Decorator)tag).apply(connectionFactory);
this.connectionFactory = connectionFactory;
String host = HostPort.normalizeHost(getHost());

View File

@ -72,7 +72,7 @@ public class HttpProxy extends ProxyConfiguration.Proxy
return URI.create(new Origin(scheme, getAddress()).asString());
}
private class HttpProxyClientConnectionFactory implements ClientConnectionFactory
private static class HttpProxyClientConnectionFactory implements ClientConnectionFactory
{
private final ClientConnectionFactory connectionFactory;
@ -127,7 +127,7 @@ public class HttpProxy extends ProxyConfiguration.Proxy
* tunnel after the TCP connection is succeeded, and needs to notify
* the nested promise when the tunnel is established (or failed).</p>
*/
private class CreateTunnelPromise implements Promise<Connection>
private static class CreateTunnelPromise implements Promise<Connection>
{
private final ClientConnectionFactory connectionFactory;
private final EndPoint endPoint;
@ -233,7 +233,7 @@ public class HttpProxy extends ProxyConfiguration.Proxy
}
}
private class ProxyConnection implements Connection
private static class ProxyConnection implements Connection
{
private final Destination destination;
private final Connection connection;
@ -272,7 +272,7 @@ public class HttpProxy extends ProxyConfiguration.Proxy
}
}
private class TunnelPromise implements Promise<Connection>
private static class TunnelPromise implements Promise<Connection>
{
private final Request request;
private final Response.CompleteListener listener;

View File

@ -87,6 +87,7 @@ public class HttpRequest implements Request
private List<RequestListener> requestListeners;
private BiFunction<Request, Request, Response.CompleteListener> pushListener;
private Supplier<HttpFields> trailers;
private Object tag;
protected HttpRequest(HttpClient client, HttpConversation conversation, URI uri)
{
@ -313,6 +314,19 @@ public class HttpRequest implements Request
return this;
}
@Override
public Request tag(Object tag)
{
this.tag = tag;
return this;
}
@Override
public Object getTag()
{
return tag;
}
@Override
public Request attribute(String name, Object value)
{

View File

@ -26,16 +26,28 @@ public class Origin
{
private final String scheme;
private final Address address;
private final Object tag;
public Origin(String scheme, String host, int port)
{
this(scheme, new Address(host, port));
this(scheme, host, port, null);
}
public Origin(String scheme, String host, int port, Object tag)
{
this(scheme, new Address(host, port), tag);
}
public Origin(String scheme, Address address)
{
this(scheme, address, null);
}
public Origin(String scheme, Address address, Object tag)
{
this.scheme = Objects.requireNonNull(scheme);
this.address = address;
this.tag = tag;
}
public String getScheme()
@ -48,6 +60,11 @@ public class Origin
return address;
}
public Object getTag()
{
return tag;
}
public String asString()
{
StringBuilder result = new StringBuilder();
@ -63,14 +80,23 @@ public class Origin
if (obj == null || getClass() != obj.getClass())
return false;
Origin that = (Origin)obj;
return scheme.equals(that.scheme) && address.equals(that.address);
return scheme.equals(that.scheme) &&
address.equals(that.address) &&
Objects.equals(tag, that.tag);
}
@Override
public int hashCode()
{
int result = scheme.hashCode();
result = 31 * result + address.hashCode();
return Objects.hash(scheme, address, tag);
}
@Override
public String toString()
{
String result = asString();
if (tag != null)
result += "[tag=" + tag + "]";
return result;
}

View File

@ -61,7 +61,7 @@ public class ProxyConfiguration
public abstract static class Proxy
{
// TO use IPAddress Map
// TODO use InetAddressSet? Or IncludeExcludeSet?
private final Set<String> included = new HashSet<>();
private final Set<String> excluded = new HashSet<>();
private final Origin.Address address;

View File

@ -0,0 +1,621 @@
//
// ========================================================================
// Copyright (c) 1995-2019 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.client;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import org.eclipse.jetty.io.AbstractConnection;
import org.eclipse.jetty.io.ClientConnectionFactory;
import org.eclipse.jetty.io.Connection;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.Promise;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
/**
* <p>ClientConnectionFactory for the
* <a href="http://www.haproxy.org/download/2.1/doc/proxy-protocol.txt">PROXY protocol</a>.</p>
* <p>Use the {@link V1} or {@link V2} versions of this class to specify what version of the
* PROXY protocol you want to use.</p>
*/
public abstract class ProxyProtocolClientConnectionFactory implements ClientConnectionFactory
{
/**
* A ClientConnectionFactory for the PROXY protocol version 1.
*/
public static class V1 extends ProxyProtocolClientConnectionFactory
{
public V1(ClientConnectionFactory factory)
{
super(factory);
}
@Override
protected ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map<String, Object> context)
{
HttpDestination destination = (HttpDestination)context.get(HttpClientTransport.HTTP_DESTINATION_CONTEXT_KEY);
Executor executor = destination.getHttpClient().getExecutor();
Tag tag = (Tag)destination.getOrigin().getTag();
if (tag == null)
{
InetSocketAddress local = endPoint.getLocalAddress();
InetSocketAddress remote = endPoint.getRemoteAddress();
boolean ipv4 = local.getAddress() instanceof Inet4Address;
tag = new Tag(ipv4 ? "TCP4" : "TCP6", local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort());
}
return new ProxyProtocolConnectionV1(endPoint, executor, getClientConnectionFactory(), context, tag);
}
/**
* <p>PROXY protocol version 1 metadata holder to be used in conjunction
* with {@link org.eclipse.jetty.client.api.Request#tag(Object)}.</p>
* <p>Instances of this class are associated to a destination so that
* all connections of that destination will initiate the communication
* with the PROXY protocol version 1 bytes specified by this metadata.</p>
*/
public static class Tag implements ClientConnectionFactory.Decorator
{
/**
* The PROXY V1 Tag typically used to "ping" the server.
*/
public static final Tag UNKNOWN = new Tag("UNKNOWN", null, 0, null, 0);
private final String family;
private final String srcIP;
private final int srcPort;
private final String dstIP;
private final int dstPort;
/**
* <p>Creates a Tag whose metadata will be derived from the underlying EndPoint.</p>
*/
public Tag()
{
this(null, 0);
}
/**
* <p>Creates a Tag with the given source metadata.</p>
* <p>The destination metadata will be derived from the underlying EndPoint.</p>
*
* @param srcIP the source IP address
* @param srcPort the source port
*/
public Tag(String srcIP, int srcPort)
{
this(null, srcIP, srcPort, null, 0);
}
/**
* <p>Creates a Tag with the given metadata.</p>
*
* @param family the protocol family
* @param srcIP the source IP address
* @param srcPort the source port
* @param dstIP the destination IP address
* @param dstPort the destination port
*/
public Tag(String family, String srcIP, int srcPort, String dstIP, int dstPort)
{
this.family = family;
this.srcIP = srcIP;
this.srcPort = srcPort;
this.dstIP = dstIP;
this.dstPort = dstPort;
}
public String getFamily()
{
return family;
}
public String getSourceAddress()
{
return srcIP;
}
public int getSourcePort()
{
return srcPort;
}
public String getDestinationAddress()
{
return dstIP;
}
public int getDestinationPort()
{
return dstPort;
}
@Override
public ClientConnectionFactory apply(ClientConnectionFactory factory)
{
return new V1(factory);
}
@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null || getClass() != obj.getClass())
return false;
Tag that = (Tag)obj;
return Objects.equals(family, that.family) &&
Objects.equals(srcIP, that.srcIP) &&
srcPort == that.srcPort &&
Objects.equals(dstIP, that.dstIP) &&
dstPort == that.dstPort;
}
@Override
public int hashCode()
{
return Objects.hash(family, srcIP, srcPort, dstIP, dstPort);
}
}
}
/**
* A ClientConnectionFactory for the PROXY protocol version 2.
*/
public static class V2 extends ProxyProtocolClientConnectionFactory
{
public V2(ClientConnectionFactory factory)
{
super(factory);
}
@Override
protected ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map<String, Object> context)
{
HttpDestination destination = (HttpDestination)context.get(HttpClientTransport.HTTP_DESTINATION_CONTEXT_KEY);
Executor executor = destination.getHttpClient().getExecutor();
Tag tag = (Tag)destination.getOrigin().getTag();
if (tag == null)
{
InetSocketAddress local = endPoint.getLocalAddress();
InetSocketAddress remote = endPoint.getRemoteAddress();
boolean ipv4 = local.getAddress() instanceof Inet4Address;
tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort());
}
return new ProxyProtocolConnectionV2(endPoint, executor, getClientConnectionFactory(), context, tag);
}
/**
* <p>PROXY protocol version 2 metadata holder to be used in conjunction
* with {@link org.eclipse.jetty.client.api.Request#tag(Object)}.</p>
* <p>Instances of this class are associated to a destination so that
* all connections of that destination will initiate the communication
* with the PROXY protocol version 2 bytes specified by this metadata.</p>
*/
public static class Tag implements ClientConnectionFactory.Decorator
{
/**
* The PROXY V2 Tag typically used to "ping" the server.
*/
public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0);
private Command command;
private Family family;
private Protocol protocol;
private String srcIP;
private int srcPort;
private String dstIP;
private int dstPort;
private Map<Integer, byte[]> vectors;
/**
* <p>Creates a Tag whose metadata will be derived from the underlying EndPoint.</p>
*/
public Tag()
{
this(null, 0);
}
/**
* <p>Creates a Tag with the given source metadata.</p>
* <p>The destination metadata will be derived from the underlying EndPoint.</p>
*
* @param srcIP the source IP address
* @param srcPort the source port
*/
public Tag(String srcIP, int srcPort)
{
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0);
}
/**
* <p>Creates a Tag with the given metadata.</p>
*
* @param command the LOCAL or PROXY command
* @param family the protocol family
* @param protocol the protocol type
* @param srcIP the source IP address
* @param srcPort the source port
* @param dstIP the destination IP address
* @param dstPort the destination port
*/
public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort)
{
this.command = command;
this.family = family;
this.protocol = protocol;
this.srcIP = srcIP;
this.srcPort = srcPort;
this.dstIP = dstIP;
this.dstPort = dstPort;
}
public void put(int type, byte[] data)
{
if (type < 0 || type > 255)
throw new IllegalArgumentException("Invalid type: " + type);
if (data != null && data.length > 65535)
throw new IllegalArgumentException("Invalid data length: " + data.length);
if (vectors == null)
vectors = new HashMap<>();
vectors.put(type, data);
}
public Command getCommand()
{
return command;
}
public Family getFamily()
{
return family;
}
public Protocol getProtocol()
{
return protocol;
}
public String getSourceAddress()
{
return srcIP;
}
public int getSourcePort()
{
return srcPort;
}
public String getDestinationAddress()
{
return dstIP;
}
public int getDestinationPort()
{
return dstPort;
}
public Map<Integer, byte[]> getVectors()
{
return vectors != null ? vectors : Collections.emptyMap();
}
@Override
public ClientConnectionFactory apply(ClientConnectionFactory factory)
{
return new V2(factory);
}
@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null || getClass() != obj.getClass())
return false;
Tag that = (Tag)obj;
return command == that.command &&
family == that.family &&
protocol == that.protocol &&
Objects.equals(srcIP, that.srcIP) &&
srcPort == that.srcPort &&
Objects.equals(dstIP, that.dstIP) &&
dstPort == that.dstPort;
}
@Override
public int hashCode()
{
return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort);
}
public enum Command
{
LOCAL, PROXY
}
public enum Family
{
UNSPEC, INET4, INET6, UNIX
}
public enum Protocol
{
UNSPEC, STREAM, DGRAM
}
}
}
private final ClientConnectionFactory factory;
private ProxyProtocolClientConnectionFactory(ClientConnectionFactory factory)
{
this.factory = factory;
}
public ClientConnectionFactory getClientConnectionFactory()
{
return factory;
}
@Override
public Connection newConnection(EndPoint endPoint, Map<String, Object> context)
{
ProxyProtocolConnection connection = newProxyProtocolConnection(endPoint, context);
return customize(connection, context);
}
protected abstract ProxyProtocolConnection newProxyProtocolConnection(EndPoint endPoint, Map<String, Object> context);
private abstract static class ProxyProtocolConnection extends AbstractConnection implements Callback
{
protected static final Logger LOG = Log.getLogger(ProxyProtocolConnection.class);
private final ClientConnectionFactory factory;
private final Map<String, Object> context;
private ProxyProtocolConnection(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map<String, Object> context)
{
super(endPoint, executor);
this.factory = factory;
this.context = context;
}
@Override
public void onOpen()
{
super.onOpen();
writePROXYBytes(getEndPoint(), this);
}
protected abstract void writePROXYBytes(EndPoint endPoint, Callback callback);
@Override
public void succeeded()
{
try
{
EndPoint endPoint = getEndPoint();
Connection connection = factory.newConnection(endPoint, context);
if (LOG.isDebugEnabled())
LOG.debug("Written PROXY line, upgrading to {}", connection);
endPoint.upgrade(connection);
}
catch (Throwable x)
{
failed(x);
}
}
@Override
public void failed(Throwable x)
{
close();
Promise<?> promise = (Promise<?>)context.get(HttpClientTransport.HTTP_CONNECTION_PROMISE_CONTEXT_KEY);
promise.failed(x);
}
@Override
public InvocationType getInvocationType()
{
return InvocationType.NON_BLOCKING;
}
@Override
public void onFillable()
{
}
}
private static class ProxyProtocolConnectionV1 extends ProxyProtocolConnection
{
private final V1.Tag tag;
public ProxyProtocolConnectionV1(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map<String, Object> context, V1.Tag tag)
{
super(endPoint, executor, factory, context);
this.tag = tag;
}
@Override
protected void writePROXYBytes(EndPoint endPoint, Callback callback)
{
try
{
InetSocketAddress localAddress = endPoint.getLocalAddress();
InetSocketAddress remoteAddress = endPoint.getRemoteAddress();
String family = tag.getFamily();
String srcIP = tag.getSourceAddress();
int srcPort = tag.getSourcePort();
String dstIP = tag.getDestinationAddress();
int dstPort = tag.getDestinationPort();
if (family == null)
family = localAddress.getAddress() instanceof Inet4Address ? "TCP4" : "TCP6";
family = family.toUpperCase(Locale.ENGLISH);
boolean unknown = family.equals("UNKNOWN");
StringBuilder builder = new StringBuilder(64);
builder.append("PROXY ").append(family);
if (!unknown)
{
if (srcIP == null)
srcIP = localAddress.getAddress().getHostAddress();
builder.append(" ").append(srcIP);
if (dstIP == null)
dstIP = remoteAddress.getAddress().getHostAddress();
builder.append(" ").append(dstIP);
if (srcPort <= 0)
srcPort = localAddress.getPort();
builder.append(" ").append(srcPort);
if (dstPort <= 0)
dstPort = remoteAddress.getPort();
builder.append(" ").append(dstPort);
}
builder.append("\r\n");
String line = builder.toString();
if (LOG.isDebugEnabled())
LOG.debug("Writing PROXY bytes: {}", line.trim());
ByteBuffer buffer = ByteBuffer.wrap(line.getBytes(StandardCharsets.US_ASCII));
endPoint.write(callback, buffer);
}
catch (Throwable x)
{
callback.failed(x);
}
}
}
private static class ProxyProtocolConnectionV2 extends ProxyProtocolConnection
{
private static final byte[] MAGIC = {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
private final V2.Tag tag;
public ProxyProtocolConnectionV2(EndPoint endPoint, Executor executor, ClientConnectionFactory factory, Map<String, Object> context, V2.Tag tag)
{
super(endPoint, executor, factory, context);
this.tag = tag;
}
@Override
protected void writePROXYBytes(EndPoint endPoint, Callback callback)
{
try
{
int capacity = MAGIC.length;
capacity += 1; // version and command
capacity += 1; // family and protocol
capacity += 2; // length
capacity += 216; // max address length
Map<Integer, byte[]> vectors = tag.getVectors();
int vectorsLength = vectors.values().stream()
.mapToInt(data -> 1 + 2 + data.length)
.sum();
capacity += vectorsLength;
ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
buffer.put(MAGIC);
V2.Tag.Command command = tag.getCommand();
int versionAndCommand = (2 << 4) | (command.ordinal() & 0x0F);
buffer.put((byte)versionAndCommand);
V2.Tag.Family family = tag.getFamily();
String srcAddr = tag.getSourceAddress();
if (srcAddr == null)
srcAddr = endPoint.getLocalAddress().getAddress().getHostAddress();
int srcPort = tag.getSourcePort();
if (srcPort <= 0)
srcPort = endPoint.getLocalAddress().getPort();
if (family == null)
family = InetAddress.getByName(srcAddr) instanceof Inet4Address ? V2.Tag.Family.INET4 : V2.Tag.Family.INET6;
V2.Tag.Protocol protocol = tag.getProtocol();
if (protocol == null)
protocol = V2.Tag.Protocol.STREAM;
int familyAndProtocol = (family.ordinal() << 4) | protocol.ordinal();
buffer.put((byte)familyAndProtocol);
int length = 0;
switch (family)
{
case UNSPEC:
break;
case INET4:
length = 12;
break;
case INET6:
length = 36;
break;
case UNIX:
length = 216;
break;
default:
throw new IllegalStateException();
}
length += vectorsLength;
buffer.putShort((short)length);
String dstAddr = tag.getDestinationAddress();
if (dstAddr == null)
dstAddr = endPoint.getRemoteAddress().getAddress().getHostAddress();
int dstPort = tag.getDestinationPort();
if (dstPort <= 0)
dstPort = endPoint.getRemoteAddress().getPort();
switch (family)
{
case UNSPEC:
break;
case INET4:
case INET6:
buffer.put(InetAddress.getByName(srcAddr).getAddress());
buffer.put(InetAddress.getByName(dstAddr).getAddress());
buffer.putShort((short)srcPort);
buffer.putShort((short)dstPort);
break;
case UNIX:
int position = buffer.position();
buffer.put(srcAddr.getBytes(StandardCharsets.US_ASCII));
buffer.position(position + 108);
buffer.put(dstAddr.getBytes(StandardCharsets.US_ASCII));
break;
default:
throw new IllegalStateException();
}
for (Map.Entry<Integer, byte[]> entry : vectors.entrySet())
{
buffer.put(entry.getKey().byteValue());
byte[] data = entry.getValue();
buffer.putShort((short)data.length);
buffer.put(data);
}
buffer.flip();
endPoint.write(callback, buffer);
}
catch (Throwable x)
{
callback.failed(x);
}
}
}
}

View File

@ -180,6 +180,28 @@ public interface Request
*/
Request cookie(HttpCookie cookie);
/**
* <p>Tags this request with the given metadata tag.</p>
* <p>Each different tag will create a different destination,
* even if the destination origin is the same.</p>
* <p>This is particularly useful in proxies, where requests
* for the same origin but from different clients may be tagged
* with client's metadata (e.g. the client remote address).</p>
* <p>The tag metadata class must correctly implement
* {@link Object#hashCode()} and {@link Object#equals(Object)}
* so that it can be used, along with the origin, to identify
* a destination.</p>
*
* @param tag the metadata to tag the request with
* @return this request object
*/
Request tag(Object tag);
/**
* @return the metadata this request has been tagged with
*/
Object getTag();
/**
* @param name the name of the attribute
* @param value the value of the attribute

View File

@ -0,0 +1,264 @@
//
// ========================================================================
// Copyright (c) 1995-2019 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// All rights reserved. This program and the accompanying materials
// are made available under the terms of the Eclipse Public License v1.0
// and Apache License v2.0 which accompanies this distribution.
//
// The Eclipse Public License is available at
// http://www.eclipse.org/legal/epl-v10.html
//
// The Apache License v2.0 is available at
// http://www.opensource.org/licenses/apache2.0.php
//
// You may elect to redistribute this code under either of these licenses.
// ========================================================================
//
package org.eclipse.jetty.client;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.api.Destination;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.MimeTypes;
import org.eclipse.jetty.io.EndPoint;
import org.eclipse.jetty.server.Handler;
import org.eclipse.jetty.server.HttpConnectionFactory;
import org.eclipse.jetty.server.ProxyConnectionFactory;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import static org.eclipse.jetty.client.ProxyProtocolClientConnectionFactory.V1;
import static org.eclipse.jetty.client.ProxyProtocolClientConnectionFactory.V2;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class HttpClientProxyProtocolTest
{
private Server server;
private ServerConnector connector;
private HttpClient client;
private void startServer(Handler handler) throws Exception
{
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
HttpConnectionFactory http = new HttpConnectionFactory();
ProxyConnectionFactory proxy = new ProxyConnectionFactory(http.getProtocol());
connector = new ServerConnector(server, 1, 1, proxy, http);
server.addConnector(connector);
server.setHandler(handler);
server.start();
}
private void startClient() throws Exception
{
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
client = new HttpClient();
client.setExecutor(clientThreads);
client.setRemoveIdleDestinations(false);
client.start();
}
@AfterEach
public void dispose() throws Exception
{
if (server != null)
server.stop();
if (client != null)
client.stop();
}
@Test
public void testClientProxyProtocolV1() throws Exception
{
startServer(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
});
startClient();
int serverPort = connector.getLocalPort();
int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V1.Tag tag = new V1.Tag("127.0.0.1", clientPort);
ContentResponse response = client.newRequest("localhost", serverPort)
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
}
@Test
public void testClientProxyProtocolV1Unknown() throws Exception
{
startServer(new EmptyServerHandler());
startClient();
ContentResponse response = client.newRequest("localhost", connector.getLocalPort())
.tag(V1.Tag.UNKNOWN)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
}
@Test
public void testClientProxyProtocolV2() throws Exception
{
startServer(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
});
startClient();
int serverPort = connector.getLocalPort();
int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort);
ContentResponse response = client.newRequest("localhost", serverPort)
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
}
@Test
public void testClientProxyProtocolV2Local() throws Exception
{
startServer(new EmptyServerHandler());
startClient();
ContentResponse response = client.newRequest("localhost", connector.getLocalPort())
.tag(V2.Tag.LOCAL)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
}
@Test
public void testClientProxyProtocolV2WithVectors() throws Exception
{
String tlsVersion = "TLSv1.3";
byte[] tlsVersionBytes = tlsVersion.getBytes(StandardCharsets.US_ASCII);
startServer(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
EndPoint endPoint = jettyRequest.getHttpChannel().getEndPoint();
assertTrue(endPoint instanceof ProxyConnectionFactory.ProxyEndPoint);
ProxyConnectionFactory.ProxyEndPoint proxyEndPoint = (ProxyConnectionFactory.ProxyEndPoint)endPoint;
assertEquals(tlsVersion, proxyEndPoint.getAttribute(ProxyConnectionFactory.TLS_VERSION));
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
});
startClient();
int serverPort = connector.getLocalPort();
int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort);
int typeTLS = 0x20;
byte[] dataTLS = new byte[1 + 4 + (1 + 2 + tlsVersionBytes.length)];
dataTLS[0] = 0x01; // CLIENT_SSL
dataTLS[5] = 0x21; // SUBTYPE_SSL_VERSION
dataTLS[6] = 0x00; // Length, hi byte
dataTLS[7] = (byte)tlsVersionBytes.length; // Length, lo byte
System.arraycopy(tlsVersionBytes, 0, dataTLS, 8, tlsVersionBytes.length);
tag.put(typeTLS, dataTLS);
ContentResponse response = client.newRequest("localhost", serverPort)
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
}
@Test
public void testProxyProtocolWrappingHTTPProxy() throws Exception
{
startServer(new EmptyServerHandler()
{
@Override
protected void service(String target, Request jettyRequest, HttpServletRequest request, HttpServletResponse response) throws IOException
{
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
});
startClient();
int proxyPort = connector.getLocalPort();
int serverPort = proxyPort + 1; // Any port will do.
client.getProxyConfiguration().getProxies().add(new HttpProxy("localhost", proxyPort));
// We are simulating to be a HttpClient inside a proxy.
// The server is configured with the PROXY protocol to know the socket address of clients.
// The proxy receives a request from the client, and it extracts the client address.
int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V1.Tag tag = new V1.Tag("127.0.0.1", clientPort);
// The proxy maps the client address, then sends the request.
ContentResponse response = client.newRequest("localhost", serverPort)
.tag(tag)
.header(HttpHeader.CONNECTION, "close")
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
List<Destination> destinations = client.getDestinations();
assertEquals(1, destinations.size());
HttpDestination destination = (HttpDestination)destinations.get(0);
assertTrue(destination.getConnectionPool().isEmpty());
// The previous connection has been closed.
// Make another request from the same client address.
response = client.newRequest("localhost", serverPort)
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
destinations = client.getDestinations();
assertEquals(1, destinations.size());
assertSame(destination, destinations.get(0));
// Make another request from a different client address.
int clientPort2 = clientPort + 1;
V1.Tag tag2 = new V1.Tag("127.0.0.1", clientPort2);
response = client.newRequest("localhost", serverPort)
.tag(tag2)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort2), response.getContentAsString());
destinations = client.getDestinations();
assertEquals(2, destinations.size());
}
}

View File

@ -44,4 +44,21 @@ public interface ClientConnectionFactory
connector.getBeans(Connection.Listener.class).forEach(connection::addListener);
return connection;
}
/**
* <p>Wraps another ClientConnectionFactory.</p>
* <p>This is typically done by protocols that send "preface" bytes with some metadata
* before other protocols. The metadata could be, for example, proxying information
* or authentication information.</p>
*/
interface Decorator
{
/**
* <p>Wraps the given {@code factory}.</p>
*
* @param factory the ClientConnectionFactory to wrap
* @return the wrapping ClientConnectionFactory
*/
ClientConnectionFactory apply(ClientConnectionFactory factory);
}
}

View File

@ -102,9 +102,10 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
public class ProxyProtocolV1orV2Connection extends AbstractConnection
{
// Only do a tiny read to figure out what PROXY version it is.
private final ByteBuffer _buffer = BufferUtil.allocate(16);
private final Connector _connector;
private final String _next;
private ByteBuffer _buffer = BufferUtil.allocate(16);
protected ProxyProtocolV1orV2Connection(EndPoint endp, Connector connector, String next)
{
@ -157,8 +158,11 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
return;
}
default:
{
LOG.warn("Not PROXY protocol for {}", getEndPoint());
close();
break;
}
}
}
catch (Throwable x)
@ -179,8 +183,8 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
private final Connector _connector;
private final String _next;
private final StringBuilder _builder = new StringBuilder();
private final String[] _field = new String[6];
private int _fields;
private final String[] _fields = new String[6];
private int _index;
private int _length;
protected ProxyProtocolV1Connection(EndPoint endp, Connector connector, String next, ByteBuffer buffer)
@ -201,16 +205,18 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
private boolean parse(ByteBuffer buffer)
{
// parse fields
// Parse fields
while (buffer.hasRemaining())
{
byte b = buffer.get();
if (_fields < 6)
if (_index < 6)
{
if (b == ' ' || b == '\r' && _fields == 5)
if (b == ' ' || b == '\r')
{
_field[_fields++] = _builder.toString();
_fields[_index++] = _builder.toString();
_builder.setLength(0);
if (b == '\r')
_index = 6;
}
else if (b < ' ')
{
@ -227,7 +233,7 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
{
if (b == '\n')
{
_fields = 7;
_index = 7;
return true;
}
@ -245,12 +251,12 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
try
{
ByteBuffer buffer = null;
while (_fields < 7)
while (_index < 7)
{
// Create a buffer that will not read too much data
// since once read it is impossible to push back for the
// real connection to read it.
int size = Math.max(1, SIZE[_fields] - _builder.length());
int size = Math.max(1, SIZE[_index] - _builder.length());
if (buffer == null || buffer.capacity() != size)
buffer = BufferUtil.allocate(size);
else
@ -282,22 +288,34 @@ public class ProxyConnectionFactory extends AbstractConnectionFactory
}
// Check proxy
if (!"PROXY".equals(_field[0]))
if (!"PROXY".equals(_fields[0]))
{
LOG.warn("Not PROXY protocol for {}", getEndPoint());
close();
return;
}
// Extract Addresses
InetSocketAddress remote = new InetSocketAddress(_field[2], Integer.parseInt(_field[4]));
InetSocketAddress local = new InetSocketAddress(_field[3], Integer.parseInt(_field[5]));
String srcIP = _fields[2];
String srcPort = _fields[4];
String dstIP = _fields[3];
String dstPort = _fields[5];
// If UNKNOWN, we must ignore the information sent, so use the EndPoint's.
boolean unknown = "UNKNOWN".equalsIgnoreCase(_fields[1]);
if (unknown)
{
srcIP = getEndPoint().getRemoteAddress().getAddress().getHostAddress();
srcPort = String.valueOf(getEndPoint().getRemoteAddress().getPort());
dstIP = getEndPoint().getLocalAddress().getAddress().getHostAddress();
dstPort = String.valueOf(getEndPoint().getLocalAddress().getPort());
}
InetSocketAddress remote = new InetSocketAddress(srcIP, Integer.parseInt(srcPort));
InetSocketAddress local = new InetSocketAddress(dstIP, Integer.parseInt(dstPort));
// Create the next protocol
ConnectionFactory connectionFactory = _connector.getConnectionFactory(_next);
if (connectionFactory == null)
{
LOG.warn("No Next protocol '{}' for {}", _next, getEndPoint());
LOG.warn("No next protocol '{}' for {}", _next, getEndPoint());
close();
return;
}

View File

@ -30,9 +30,6 @@ import org.junit.jupiter.api.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertNull;
/**
*
*/
public class ProxyConnectionTest
{
private Server _server;
@ -85,7 +82,7 @@ public class ProxyConnectionTest
public void testIPv6() throws Exception
{
Assumptions.assumeTrue(Net.isIpv6InterfaceAvailable());
String response = _connector.getResponse("PROXY UNKNOWN eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" +
String response = _connector.getResponse("PROXY TCP6 eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff 65535 65535\r\n" +
"GET /path HTTP/1.1\n" +
"Host: server:80\n" +
"Connection: close\n" +

View File

@ -21,6 +21,7 @@ package org.eclipse.jetty.http.client;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
@ -35,6 +36,7 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.api.ContentResponse;
import org.eclipse.jetty.client.api.Destination;
import org.eclipse.jetty.client.api.Response;
import org.eclipse.jetty.client.util.BytesContentProvider;
import org.eclipse.jetty.client.util.FutureResponseListener;
@ -650,6 +652,30 @@ public class HttpClientTest extends AbstractTest<TransportScenario>
assertEquals(0, response.getContent().length);
}
@ParameterizedTest
@ArgumentsSource(TransportProvider.class)
public void testOneDestinationPerUser(Transport transport) throws Exception
{
init(transport);
scenario.start(new EmptyServerHandler());
int runs = 4;
int users = 16;
for (int i = 0; i < runs; ++i)
{
for (int j = 0; j < users; ++j)
{
ContentResponse response = scenario.client.newRequest(scenario.newURI())
.tag(j)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
}
}
List<Destination> destinations = scenario.client.getDestinations();
assertEquals(users, destinations.size());
}
private void sleep(long time) throws IOException
{
try