Support routing data through an HTTP proxy (#11891)

* Support routing data through an HTTP proxy

* Support routing data through an HTTP proxy

This adds the ability for the HttpClient to connect through an HTTP proxy.  We
augment the channel factory to check if it is supposed to be proxied and, if so,
we connect to the proxy host first, issue a CONNECT command through to the final
recipient host and *then* give the channel to the normal http client for usage.

* add docs

* address comments

Co-authored-by: imply-cheddar <86940447+imply-cheddar@users.noreply.github.com>
This commit is contained in:
Maytas Monsereenusorn 2021-11-09 17:24:06 -08:00 committed by GitHub
parent 6c196a5ea2
commit a36a41da73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 330 additions and 25 deletions

View File

@ -82,6 +82,7 @@ public class HttpClientConfig
private final int numConnections; private final int numConnections;
private final SSLContext sslContext; private final SSLContext sslContext;
private final HttpClientProxyConfig proxyConfig;
private final Duration readTimeout; private final Duration readTimeout;
private final Duration sslHandshakeTimeout; private final Duration sslHandshakeTimeout;
private final int bossPoolSize; private final int bossPoolSize;
@ -92,6 +93,7 @@ public class HttpClientConfig
private HttpClientConfig( private HttpClientConfig(
int numConnections, int numConnections,
SSLContext sslContext, SSLContext sslContext,
HttpClientProxyConfig proxyConfig,
Duration readTimeout, Duration readTimeout,
Duration sslHandshakeTimeout, Duration sslHandshakeTimeout,
int bossPoolSize, int bossPoolSize,
@ -102,6 +104,7 @@ public class HttpClientConfig
{ {
this.numConnections = numConnections; this.numConnections = numConnections;
this.sslContext = sslContext; this.sslContext = sslContext;
this.proxyConfig = proxyConfig;
this.readTimeout = readTimeout; this.readTimeout = readTimeout;
this.sslHandshakeTimeout = sslHandshakeTimeout; this.sslHandshakeTimeout = sslHandshakeTimeout;
this.bossPoolSize = bossPoolSize; this.bossPoolSize = bossPoolSize;
@ -120,6 +123,11 @@ public class HttpClientConfig
return sslContext; return sslContext;
} }
public HttpClientProxyConfig getProxyConfig()
{
return proxyConfig;
}
public Duration getReadTimeout() public Duration getReadTimeout()
{ {
return readTimeout; return readTimeout;
@ -154,6 +162,7 @@ public class HttpClientConfig
{ {
private int numConnections = 1; private int numConnections = 1;
private SSLContext sslContext = null; private SSLContext sslContext = null;
private HttpClientProxyConfig proxyConfig = null;
private Duration readTimeout = null; private Duration readTimeout = null;
private Duration sslHandshakeTimeout = null; private Duration sslHandshakeTimeout = null;
private int bossCount = DEFAULT_BOSS_COUNT; private int bossCount = DEFAULT_BOSS_COUNT;
@ -177,6 +186,12 @@ public class HttpClientConfig
return this; return this;
} }
public Builder withHttpProxyConfig(HttpClientProxyConfig config)
{
this.proxyConfig = config;
return this;
}
public Builder withReadTimeout(Duration readTimeout) public Builder withReadTimeout(Duration readTimeout)
{ {
this.readTimeout = readTimeout; this.readTimeout = readTimeout;
@ -212,6 +227,7 @@ public class HttpClientConfig
return new HttpClientConfig( return new HttpClientConfig(
numConnections, numConnections,
sslContext, sslContext,
proxyConfig,
readTimeout, readTimeout,
sslHandshakeTimeout, sslHandshakeTimeout,
bossCount, bossCount,

View File

@ -84,6 +84,7 @@ public class HttpClientInit
new ChannelResourceFactory( new ChannelResourceFactory(
createBootstrap(lifecycle, timer, config.getBossPoolSize(), config.getWorkerPoolSize()), createBootstrap(lifecycle, timer, config.getBossPoolSize(), config.getWorkerPoolSize()),
config.getSslContext(), config.getSslContext(),
config.getProxyConfig(),
timer, timer,
config.getSslHandshakeTimeout() == null ? -1 : config.getSslHandshakeTimeout().getMillis() config.getSslHandshakeTimeout() == null ? -1 : config.getSslHandshakeTimeout().getMillis()
), ),

View File

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.java.util.http.client;
import com.fasterxml.jackson.annotation.JsonProperty;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
public class HttpClientProxyConfig
{
@JsonProperty("host")
private String host;
@JsonProperty("port")
@Min(0)
@Max(65_535)
private int port;
@JsonProperty("user")
private String user;
@JsonProperty("password")
private String password;
public HttpClientProxyConfig()
{
}
public HttpClientProxyConfig(String host, int port, String user, String password)
{
this.host = host;
this.port = port;
this.user = user;
this.password = password;
}
public String getHost()
{
return host;
}
public int getPort()
{
return port;
}
public String getUser()
{
return user;
}
public String getPassword()
{
return password;
}
@SuppressWarnings("VariableNotUsedInsideIf")
@Override
public String toString()
{
return "HttpClientProxyConfig{" +
"proxyHost='" + host + '\'' +
", proxyPort=" + port +
", user='" + user + '\'' +
", password='" + ((password == null) ? "__is_null__" : "***") + '\'' +
'}';
}
}

View File

@ -164,12 +164,16 @@ public class Request
public Request setBasicAuthentication(String username, String password) public Request setBasicAuthentication(String username, String password)
{ {
final String base64Value = base64Encode(username + ":" + password); setHeader(HttpHeaders.Names.AUTHORIZATION, makeBasicAuthenticationString(username, password));
setHeader(HttpHeaders.Names.AUTHORIZATION, "Basic " + base64Value);
return this; return this;
} }
private String base64Encode(final String value) public static String makeBasicAuthenticationString(String username, String password)
{
return "Basic " + base64Encode(username + ":" + password);
}
private static String base64Encode(final String value)
{ {
final ChannelBufferFactory bufferFactory = HeapChannelBufferFactory.getInstance(); final ChannelBufferFactory bufferFactory = HeapChannelBufferFactory.getInstance();

View File

@ -22,6 +22,8 @@ package org.apache.druid.java.util.http.client.pool;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.http.client.HttpClientProxyConfig;
import org.apache.druid.java.util.http.client.Request;
import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelException; import org.jboss.netty.channel.ChannelException;
@ -31,7 +33,14 @@ import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.Channels; import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.handler.codec.http.DefaultHttpRequest;
import org.jboss.netty.handler.codec.http.HttpClientCodec;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpResponse;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import org.jboss.netty.handler.codec.http.HttpVersion;
import org.jboss.netty.handler.ssl.SslHandler; import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.util.Timer; import org.jboss.netty.util.Timer;
@ -50,21 +59,25 @@ public class ChannelResourceFactory implements ResourceFactory<String, ChannelFu
private static final Logger log = new Logger(ChannelResourceFactory.class); private static final Logger log = new Logger(ChannelResourceFactory.class);
private static final long DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(10); private static final long DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(10);
private static final String DRUID_PROXY_HANDLER = "druid_proxyHandler";
private final ClientBootstrap bootstrap; private final ClientBootstrap bootstrap;
private final SSLContext sslContext; private final SSLContext sslContext;
private final HttpClientProxyConfig proxyConfig;
private final Timer timer; private final Timer timer;
private final long sslHandshakeTimeout; private final long sslHandshakeTimeout;
public ChannelResourceFactory( public ChannelResourceFactory(
ClientBootstrap bootstrap, ClientBootstrap bootstrap,
SSLContext sslContext, SSLContext sslContext,
HttpClientProxyConfig proxyConfig,
Timer timer, Timer timer,
long sslHandshakeTimeout long sslHandshakeTimeout
) )
{ {
this.bootstrap = Preconditions.checkNotNull(bootstrap, "bootstrap"); this.bootstrap = Preconditions.checkNotNull(bootstrap, "bootstrap");
this.sslContext = sslContext; this.sslContext = sslContext;
this.proxyConfig = proxyConfig;
this.timer = timer; this.timer = timer;
this.sslHandshakeTimeout = sslHandshakeTimeout >= 0 ? sslHandshakeTimeout : DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS; this.sslHandshakeTimeout = sslHandshakeTimeout >= 0 ? sslHandshakeTimeout : DEFAULT_SSL_HANDSHAKE_TIMEOUT_MILLIS;
@ -88,7 +101,99 @@ public class ChannelResourceFactory implements ResourceFactory<String, ChannelFu
final String host = url.getHost(); final String host = url.getHost();
final int port = url.getPort() == -1 ? url.getDefaultPort() : url.getPort(); final int port = url.getPort() == -1 ? url.getDefaultPort() : url.getPort();
final ChannelFuture retVal; final ChannelFuture retVal;
final ChannelFuture connectFuture = bootstrap.connect(new InetSocketAddress(host, port)); final ChannelFuture connectFuture;
if (proxyConfig != null) {
final ChannelFuture proxyFuture = bootstrap.connect(
new InetSocketAddress(proxyConfig.getHost(), proxyConfig.getPort())
);
connectFuture = Channels.future(proxyFuture.getChannel());
final String proxyUri = StringUtils.format("%s:%d", host, port);
DefaultHttpRequest connectRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, proxyUri);
if (proxyConfig.getUser() != null) {
connectRequest.headers().add(
"Proxy-Authorization", Request.makeBasicAuthenticationString(
proxyConfig.getUser(), proxyConfig.getPassword()
)
);
}
proxyFuture.addListener(new ChannelFutureListener()
{
@Override
public void operationComplete(ChannelFuture f1)
{
if (f1.isSuccess()) {
final Channel channel = f1.getChannel();
channel.getPipeline().addLast(
DRUID_PROXY_HANDLER,
new SimpleChannelUpstreamHandler()
{
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e)
{
Object msg = e.getMessage();
final ChannelPipeline pipeline = ctx.getPipeline();
pipeline.remove(DRUID_PROXY_HANDLER);
if (msg instanceof HttpResponse) {
HttpResponse httpResponse = (HttpResponse) msg;
if (HttpResponseStatus.OK.equals(httpResponse.getStatus())) {
// When the HttpClientCodec sees the CONNECT response complete, it goes into a "done"
// mode which makes it just do nothing. Swap it with a new instance that will cover
// subsequent requests
pipeline.replace("codec", "codec", new HttpClientCodec());
connectFuture.setSuccess();
} else {
connectFuture.setFailure(
new ChannelException(
StringUtils.format(
"Got status[%s] from CONNECT request to proxy[%s]",
httpResponse.getStatus(),
proxyUri
)
)
);
}
} else {
connectFuture.setFailure(new ChannelException(StringUtils.format(
"Got message of type[%s], don't know what to do.", msg.getClass()
)));
}
}
}
);
channel.write(connectRequest).addListener(
new ChannelFutureListener()
{
@Override
public void operationComplete(ChannelFuture f2)
{
if (!f2.isSuccess()) {
connectFuture.setFailure(
new ChannelException(
StringUtils.format("Problem with CONNECT request to proxy[%s]", proxyUri), f2.getCause()
)
);
}
}
}
);
} else {
connectFuture.setFailure(
new ChannelException(
StringUtils.format("Problem connecting to proxy[%s]", proxyUri), f1.getCause()
)
);
}
}
});
} else {
connectFuture = bootstrap.connect(new InetSocketAddress(host, port));
}
if ("https".equals(url.getProtocol())) { if ("https".equals(url.getProtocol())) {
if (sslContext == null) { if (sslContext == null) {
@ -111,11 +216,9 @@ public class ChannelResourceFactory implements ResourceFactory<String, ChannelFu
// https://github.com/netty/netty/issues/160 // https://github.com/netty/netty/issues/160
sslHandler.setCloseOnSSLException(true); sslHandler.setCloseOnSSLException(true);
final ChannelPipeline pipeline = connectFuture.getChannel().getPipeline();
pipeline.addFirst("ssl", sslHandler);
final ChannelFuture handshakeFuture = Channels.future(connectFuture.getChannel()); final ChannelFuture handshakeFuture = Channels.future(connectFuture.getChannel());
pipeline.addLast("connectionErrorHandler", new SimpleChannelUpstreamHandler() connectFuture.getChannel().getPipeline().addLast(
"connectionErrorHandler", new SimpleChannelUpstreamHandler()
{ {
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e)
@ -133,7 +236,8 @@ public class ChannelResourceFactory implements ResourceFactory<String, ChannelFu
channel.close(); channel.close();
} }
} }
}); }
);
connectFuture.addListener( connectFuture.addListener(
new ChannelFutureListener() new ChannelFutureListener()
{ {
@ -141,6 +245,8 @@ public class ChannelResourceFactory implements ResourceFactory<String, ChannelFu
public void operationComplete(ChannelFuture f) public void operationComplete(ChannelFuture f)
{ {
if (f.isSuccess()) { if (f.isSuccess()) {
final ChannelPipeline pipeline = f.getChannel().getPipeline();
pipeline.addFirst("ssl", sslHandler);
sslHandler.handshake().addListener( sslHandler.handshake().addListener(
new ChannelFutureListener() new ChannelFutureListener()
{ {

View File

@ -52,6 +52,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/** /**
* Tests with servers that are at least moderately well-behaving. * Tests with servers that are at least moderately well-behaving.
@ -113,6 +114,81 @@ public class FriendlyServersTest
} }
} }
@Test
public void testFriendlyProxyHttpServer() throws Exception
{
final AtomicReference<String> requestContent = new AtomicReference<>();
final ExecutorService exec = Executors.newSingleThreadExecutor();
final ServerSocket serverSocket = new ServerSocket(0);
exec.submit(
new Runnable()
{
@Override
public void run()
{
while (!Thread.currentThread().isInterrupted()) {
try (
Socket clientSocket = serverSocket.accept();
BufferedReader in = new BufferedReader(
new InputStreamReader(clientSocket.getInputStream(), StandardCharsets.UTF_8)
);
OutputStream out = clientSocket.getOutputStream()
) {
StringBuilder request = new StringBuilder();
String line;
while (!"".equals((line = in.readLine()))) {
request.append(line).append("\r\n");
}
requestContent.set(request.toString());
out.write("HTTP/1.1 200 OK\r\n\r\n".getBytes(StandardCharsets.UTF_8));
while (!in.readLine().equals("")) {
// skip lines
}
out.write("HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nhello!".getBytes(StandardCharsets.UTF_8));
}
catch (Exception e) {
Assert.fail(e.toString());
}
}
}
}
);
final Lifecycle lifecycle = new Lifecycle();
try {
final HttpClientConfig config = HttpClientConfig
.builder()
.withHttpProxyConfig(
new HttpClientProxyConfig("localhost", serverSocket.getLocalPort(), "bob", "sally")
)
.build();
final HttpClient client = HttpClientInit.createClient(config, lifecycle);
final StatusResponseHolder response = client
.go(
new Request(
HttpMethod.GET,
new URL("http://anotherHost:8080/")
),
StatusResponseHandler.getInstance()
).get();
Assert.assertEquals(200, response.getStatus().getCode());
Assert.assertEquals("hello!", response.getContent());
Assert.assertEquals(
"CONNECT anotherHost:8080 HTTP/1.1\r\nProxy-Authorization: Basic Ym9iOnNhbGx5\r\n",
requestContent.get()
);
}
finally {
exec.shutdownNow();
serverSocket.close();
lifecycle.stop();
}
}
@Test @Test
public void testCompressionCodecConfig() throws Exception public void testCompressionCodecConfig() throws Exception
{ {

View File

@ -356,6 +356,22 @@ druid.coordinator.cleanupMetadata.duty.killSupervisors.retainDuration=PT0M
druid.coordinator.cleanupMetadata.period=PT10S druid.coordinator.cleanupMetadata.period=PT10S
``` ```
### Routing data through a HTTP proxy for your extension
You can add the ability for the `HttpClient` of your extension to connect through an HTTP proxy.
To support proxy connection for your extension's HTTP client:
1. Add `HttpClientProxyConfig` as a `@JsonProperty` to the HTTP config class of your extension.
2. In the extension's module class, add `HttpProxyConfig` config to `HttpClientConfig`.
For example, where `config` variable is the extension's HTTP config from step 1:
```
final HttpClientConfig.Builder builder = HttpClientConfig
.builder()
.withNumConnections(1)
.withReadTimeout(config.getReadTimeout().toStandardDuration())
.withHttpProxyConfig(config.getProxyConfig());
```
### Bundle your extension with all the other Druid extensions ### Bundle your extension with all the other Druid extensions
When you do `mvn install`, Druid extensions will be packaged within the Druid tarball and `extensions` directory, which are both underneath `distribution/target/`. When you do `mvn install`, Druid extensions will be packaged within the Druid tarball and `extensions` directory, which are both underneath `distribution/target/`.