add scheme to AsyncQueryForwardingServlet (#4688)

* add scheme to AsyncQueryForwardingServlet

* add sslContext binding for Router
This commit is contained in:
Parag Jain 2017-08-28 17:03:43 -05:00 committed by Jonathan Wei
parent daf3c5f927
commit 594a66f3c0
4 changed files with 90 additions and 65 deletions

View File

@ -26,6 +26,7 @@ import io.druid.guice.JsonConfigProvider;
import io.druid.guice.annotations.Client;
import io.druid.guice.annotations.Global;
import io.druid.initialization.DruidModule;
import io.druid.server.router.Router;
import javax.net.ssl.SSLContext;
import java.util.List;
@ -46,5 +47,6 @@ public class SSLContextModule implements DruidModule
binder.bind(SSLContext.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Global.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Client.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Router.class).toProvider(SSLContextProvider.class);
}
}

View File

@ -28,6 +28,7 @@ import com.google.inject.Inject;
import com.google.inject.Provider;
import com.metamx.emitter.EmittingLogger;
import com.metamx.emitter.service.ServiceEmitter;
import io.druid.client.selector.Server;
import io.druid.guice.annotations.Json;
import io.druid.guice.annotations.Smile;
import io.druid.guice.http.DruidHttpClientConfig;
@ -72,6 +73,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
private static final String APPLICATION_SMILE = "application/smile";
private static final String HOST_ATTRIBUTE = "io.druid.proxy.to.host";
private static final String SCHEME_ATTRIBUTE = "io.druid.proxy.to.host.scheme";
private static final String QUERY_ATTRIBUTE = "io.druid.proxy.query";
private static final String OBJECTMAPPER_ATTRIBUTE = "io.druid.proxy.objectMapper";
@ -169,35 +171,31 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper;
request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper);
final String defaultHost = hostFinder.getDefaultHost();
request.setAttribute(HOST_ATTRIBUTE, defaultHost);
final Server defaultServer = hostFinder.getDefaultServer();
request.setAttribute(HOST_ATTRIBUTE, defaultServer.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, defaultServer.getScheme());
final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2");
if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) {
// query cancellation request
for (final String host : hostFinder.getAllHosts()) {
for (final Server server: hostFinder.getAllServers()) {
// send query cancellation to all brokers this query may have gone to
// to keep the code simple, the proxy servlet will also send a request to one of the default brokers
if (!host.equals(defaultHost)) {
if (!server.getHost().equals(defaultServer.getHost())) {
// issue async requests
broadcastClient
.newRequest(rewriteURI(request, host))
.newRequest(rewriteURI(request, server.getScheme(), server.getHost()))
.method(HttpMethod.DELETE)
.timeout(CANCELLATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS)
.send(
new Response.CompleteListener()
{
@Override
public void onComplete(Result result)
{
if (result.isFailed()) {
log.warn(
result.getFailure(),
"Failed to forward cancellation request to [%s]",
host
);
}
result -> {
if (result.isFailed()) {
log.warn(
result.getFailure(),
"Failed to forward cancellation request to [%s]",
server.getHost()
);
}
}
);
@ -209,7 +207,9 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
try {
Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class);
if (inputQuery != null) {
request.setAttribute(HOST_ATTRIBUTE, hostFinder.getHost(inputQuery));
final Server server = hostFinder.getServer(inputQuery);
request.setAttribute(HOST_ATTRIBUTE, server.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, server.getScheme());
if (inputQuery.getId() == null) {
inputQuery = inputQuery.withId(UUID.randomUUID().toString());
}
@ -289,19 +289,19 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
@Override
protected String rewriteTarget(HttpServletRequest request)
{
return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE)).toString();
return rewriteURI(request, (String) request.getAttribute(SCHEME_ATTRIBUTE), (String) request.getAttribute(HOST_ATTRIBUTE)).toString();
}
protected URI rewriteURI(HttpServletRequest request, String host)
protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{
return makeURI(host, request.getRequestURI(), request.getQueryString());
return makeURI(scheme, host, request.getRequestURI(), request.getQueryString());
}
protected static URI makeURI(String host, String requestURI, String rawQueryString)
protected static URI makeURI(String scheme, String host, String requestURI, String rawQueryString)
{
try {
return new URI(
"http",
scheme,
host,
requestURI,
rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"),

View File

@ -19,8 +19,6 @@
package io.druid.server.router;
import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.inject.Inject;
import com.metamx.emitter.EmittingLogger;
import io.druid.client.selector.Server;
@ -31,6 +29,7 @@ import io.druid.query.Query;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/**
*/
@ -62,30 +61,14 @@ public class QueryHostFinder
return findServerInner(selected);
}
public Collection<String> getAllHosts()
public Collection<Server> getAllServers()
{
return FluentIterable
.from((Collection<List<Server>>) hostSelector.getAllBrokers().values())
.transformAndConcat(
new Function<List<Server>, Iterable<Server>>()
{
@Override
public Iterable<Server> apply(List<Server> input)
{
return input;
}
}
).transform(new Function<Server, String>()
{
@Override
public String apply(Server input)
{
return input.getHost();
}
}).toList();
return ((Collection<List<Server>>) hostSelector.getAllBrokers().values()).stream()
.flatMap(Collection::stream)
.collect(Collectors.toList());
}
public <T> String getHost(Query<T> query)
public <T> Server getServer(Query<T> query)
{
Server server = findServer(query);
@ -97,13 +80,12 @@ public class QueryHostFinder
throw new ISE("No server found for query[%s]", query);
}
final String host = server.getHost();
log.debug("Selected [%s]", host);
log.debug("Selected [%s]", server.getHost());
return host;
return server;
}
public String getDefaultHost()
public Server getDefaultServer()
{
Server server = findDefaultServer();
@ -115,7 +97,7 @@ public class QueryHostFinder
throw new ISE("No default server found!");
}
return server.getHost();
return server;
}
private Server findServerInner(final Pair<String, Server> selected)

View File

@ -29,7 +29,6 @@ import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.servlet.GuiceFilter;
import io.druid.common.utils.SocketUtil;
import io.druid.guice.GuiceInjectors;
import io.druid.guice.Jerseys;
@ -108,7 +107,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
public void configure(Binder binder)
{
JsonConfigProvider.bindInstance(
binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null, null, new ServerConfig())
binder,
Key.get(DruidNode.class, Self.class),
new DruidNode("test", "localhost", null, null, new ServerConfig())
);
binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class);
Jerseys.addResource(binder, SlowResource.class);
@ -197,24 +198,24 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
final QueryHostFinder hostFinder = new QueryHostFinder(null)
{
@Override
public String getHost(Query query)
public io.druid.client.selector.Server getServer(Query query)
{
return "localhost:" + node.getPlaintextPort();
return new TestServer("http", "localhost", node.getPlaintextPort());
}
@Override
public String getDefaultHost()
public io.druid.client.selector.Server getDefaultServer()
{
return "localhost:" + node.getPlaintextPort();
return new TestServer("http", "localhost", node.getPlaintextPort());
}
@Override
public Collection<String> getAllHosts()
public Collection<io.druid.client.selector.Server> getAllServers()
{
return ImmutableList.of(
"localhost:" + node.getPlaintextPort(),
"localhost:" + port1,
"localhost:" + port2
new TestServer("http", "localhost", node.getPlaintextPort()),
new TestServer("http", "localhost", port1),
new TestServer("http", "localhost", port2)
);
}
};
@ -241,9 +242,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
)
{
@Override
protected URI rewriteURI(HttpServletRequest request, String host)
protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{
String uri = super.rewriteURI(request, host).toString();
String uri = super.rewriteURI(request, scheme, host).toString();
if (uri.contains("/druid/v2")) {
return URI.create(uri.replace("/druid/v2", "/default"));
}
@ -272,7 +273,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
// test params
Assert.assertEquals(
new URI("http://localhost:1234/some/path?param=1"),
AsyncQueryForwardingServlet.makeURI("localhost:1234", "/some/path", "param=1")
AsyncQueryForwardingServlet.makeURI("http", "localhost:1234", "/some/path", "param=1")
);
// HttpServletRequest.getQueryString returns encoded form
@ -280,6 +281,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
Assert.assertEquals(
"http://[2a00:1450:4007:805::1007]:1234/some/path?param=1&param2=%E2%82%AC",
AsyncQueryForwardingServlet.makeURI(
"http",
HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(),
"/some/path",
"param=1&param2=%E2%82%AC"
@ -289,7 +291,46 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
// test null query
Assert.assertEquals(
new URI("http://localhost/"),
AsyncQueryForwardingServlet.makeURI("localhost", "/", null)
AsyncQueryForwardingServlet.makeURI("http", "localhost", "/", null)
);
}
private static class TestServer implements io.druid.client.selector.Server
{
private final String scheme;
private final String address;
private final int port;
public TestServer(String scheme, String address, int port)
{
this.scheme = scheme;
this.address = address;
this.port = port;
}
@Override
public String getScheme()
{
return scheme;
}
@Override
public String getHost()
{
return address + ":" + port;
}
@Override
public String getAddress()
{
return address;
}
@Override
public int getPort()
{
return port;
}
}
}