Fixes #4421 - HttpClient support for PROXY protocol.

Improved support for Type-Length-Value (TLV) objects.

Signed-off-by: Simone Bordet <simone.bordet@gmail.com>
This commit is contained in:
Simone Bordet 2019-12-17 23:26:28 +01:00
parent 129a51c7a2
commit bea7f1a5cf
2 changed files with 101 additions and 33 deletions

View File

@ -23,8 +23,8 @@ import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
@ -206,7 +206,7 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
InetSocketAddress local = endPoint.getLocalAddress();
InetSocketAddress remote = endPoint.getRemoteAddress();
boolean ipv4 = local.getAddress() instanceof Inet4Address;
tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort());
tag = new Tag(Tag.Command.PROXY, ipv4 ? Tag.Family.INET4 : Tag.Family.INET6, Tag.Protocol.STREAM, local.getAddress().getHostAddress(), local.getPort(), remote.getAddress().getHostAddress(), remote.getPort(), null);
}
return new ProxyProtocolConnectionV2(endPoint, executor, getClientConnectionFactory(), context, tag);
}
@ -223,7 +223,7 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
/**
* The PROXY V2 Tag typically used to "ping" the server.
*/
public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0);
public static final Tag LOCAL = new Tag(Command.LOCAL, Family.UNSPEC, Protocol.UNSPEC, null, 0, null, 0, null);
private Command command;
private Family family;
@ -232,7 +232,7 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
private int srcPort;
private String dstIP;
private int dstPort;
private Map<Integer, byte[]> vectors;
private List<TLV> tlvs;
/**
* <p>Creates a Tag whose metadata will be derived from the underlying EndPoint.</p>
@ -251,7 +251,20 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
*/
public Tag(String srcIP, int srcPort)
{
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0);
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, null);
}
/**
* <p>Creates a Tag with the given source metadata and Type-Length-Value (TLV) objects.</p>
* <p>The destination metadata will be derived from the underlying EndPoint.</p>
*
* @param srcIP the source IP address
* @param srcPort the source port
* @param tlvs the TLV objects
*/
public Tag(String srcIP, int srcPort, List<TLV> tlvs)
{
this(Command.PROXY, null, Protocol.STREAM, srcIP, srcPort, null, 0, tlvs);
}
/**
@ -264,8 +277,9 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
* @param srcPort the source port
* @param dstIP the destination IP address
* @param dstPort the destination port
* @param tlvs the TLV objects
*/
public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort)
public Tag(Command command, Family family, Protocol protocol, String srcIP, int srcPort, String dstIP, int dstPort, List<TLV> tlvs)
{
this.command = command;
this.family = family;
@ -274,17 +288,7 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
this.srcPort = srcPort;
this.dstIP = dstIP;
this.dstPort = dstPort;
}
public void put(int type, byte[] data)
{
if (type < 0 || type > 255)
throw new IllegalArgumentException("Invalid type: " + type);
if (data != null && data.length > 65535)
throw new IllegalArgumentException("Invalid data length: " + data.length);
if (vectors == null)
vectors = new HashMap<>();
vectors.put(type, data);
this.tlvs = tlvs;
}
public Command getCommand()
@ -322,9 +326,9 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
return dstPort;
}
public Map<Integer, byte[]> getVectors()
public List<TLV> getTLVs()
{
return vectors != null ? vectors : Collections.emptyMap();
return tlvs;
}
@Override
@ -347,13 +351,14 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
Objects.equals(srcIP, that.srcIP) &&
srcPort == that.srcPort &&
Objects.equals(dstIP, that.dstIP) &&
dstPort == that.dstPort;
dstPort == that.dstPort &&
Objects.equals(tlvs, that.tlvs);
}
@Override
public int hashCode()
{
return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort);
return Objects.hash(command, family, protocol, srcIP, srcPort, dstIP, dstPort, tlvs);
}
public enum Command
@ -370,6 +375,51 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
{
UNSPEC, STREAM, DGRAM
}
public static class TLV
{
private final int type;
private final byte[] value;
public TLV(int type, byte[] value)
{
if (type < 0 || type > 255)
throw new IllegalArgumentException("Invalid type: " + type);
if (value != null && value.length > 65535)
throw new IllegalArgumentException("Invalid value length: " + value.length);
this.type = type;
this.value = Objects.requireNonNull(value);
}
public int getType()
{
return type;
}
public byte[] getValue()
{
return value;
}
@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null || getClass() != obj.getClass())
return false;
TLV that = (TLV)obj;
return type == that.type && Arrays.equals(value, that.value);
}
@Override
public int hashCode()
{
int result = Objects.hash(type);
result = 31 * result + Arrays.hashCode(value);
return result;
}
}
}
}
@ -533,9 +583,9 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
capacity += 1; // family and protocol
capacity += 2; // length
capacity += 216; // max address length
Map<Integer, byte[]> vectors = tag.getVectors();
int vectorsLength = vectors.values().stream()
.mapToInt(data -> 1 + 2 + data.length)
List<V2.Tag.TLV> tlvs = tag.getTLVs();
int vectorsLength = tlvs == null ? 0 : tlvs.stream()
.mapToInt(tlv -> 1 + 2 + tlv.getValue().length)
.sum();
capacity += vectorsLength;
ByteBuffer buffer = ByteBuffer.allocateDirect(capacity);
@ -602,12 +652,15 @@ public abstract class ProxyProtocolClientConnectionFactory implements ClientConn
default:
throw new IllegalStateException();
}
for (Map.Entry<Integer, byte[]> entry : vectors.entrySet())
if (tlvs != null)
{
buffer.put(entry.getKey().byteValue());
byte[] data = entry.getValue();
buffer.putShort((short)data.length);
buffer.put(data);
for (V2.Tag.TLV tlv : tlvs)
{
buffer.put((byte)tlv.getType());
byte[] data = tlv.getValue();
buffer.putShort((short)data.length);
buffer.put(data);
}
}
buffer.flip();
endPoint.write(callback, buffer);

View File

@ -20,6 +20,7 @@ package org.eclipse.jetty.client;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import javax.servlet.http.HttpServletRequest;
@ -174,7 +175,8 @@ public class HttpClientProxyProtocolTest
EndPoint endPoint = jettyRequest.getHttpChannel().getEndPoint();
assertTrue(endPoint instanceof ProxyConnectionFactory.ProxyEndPoint);
ProxyConnectionFactory.ProxyEndPoint proxyEndPoint = (ProxyConnectionFactory.ProxyEndPoint)endPoint;
assertEquals(tlsVersion, proxyEndPoint.getAttribute(ProxyConnectionFactory.TLS_VERSION));
if (target.equals("/tls_version"))
assertEquals(tlsVersion, proxyEndPoint.getAttribute(ProxyConnectionFactory.TLS_VERSION));
response.setContentType(MimeTypes.Type.TEXT_PLAIN.asString());
response.getOutputStream().print(request.getRemotePort());
}
@ -184,7 +186,6 @@ public class HttpClientProxyProtocolTest
int serverPort = connector.getLocalPort();
int clientPort = ThreadLocalRandom.current().nextInt(1024, 65536);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort);
int typeTLS = 0x20;
byte[] dataTLS = new byte[1 + 4 + (1 + 2 + tlsVersionBytes.length)];
dataTLS[0] = 0x01; // CLIENT_SSL
@ -192,13 +193,27 @@ public class HttpClientProxyProtocolTest
dataTLS[6] = 0x00; // Length, hi byte
dataTLS[7] = (byte)tlsVersionBytes.length; // Length, lo byte
System.arraycopy(tlsVersionBytes, 0, dataTLS, 8, tlsVersionBytes.length);
tag.put(typeTLS, dataTLS);
V2.Tag.TLV tlv = new V2.Tag.TLV(typeTLS, dataTLS);
V2.Tag tag = new V2.Tag("127.0.0.1", clientPort, Collections.singletonList(tlv));
ContentResponse response = client.newRequest("localhost", serverPort)
.path("/tls_version")
.tag(tag)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
// Make another request with the same address information, but different TLV.
V2.Tag.TLV tlv2 = new V2.Tag.TLV(0x01, "http/1.1".getBytes(StandardCharsets.UTF_8));
V2.Tag tag2 = new V2.Tag("127.0.0.1", clientPort, Collections.singletonList(tlv2));
response = client.newRequest("localhost", serverPort)
.tag(tag2)
.send();
assertEquals(HttpStatus.OK_200, response.getStatus());
assertEquals(String.valueOf(clientPort), response.getContentAsString());
// Make sure the two TLVs created two destinations.
assertEquals(2, client.getDestinations().size());
}
@Test