add scheme to AsyncQueryForwardingServlet (#4688)

* add scheme to AsyncQueryForwardingServlet

* add sslContext binding for Router
This commit is contained in:
Parag Jain 2017-08-28 17:03:43 -05:00 committed by Jonathan Wei
parent daf3c5f927
commit 594a66f3c0
4 changed files with 90 additions and 65 deletions

View File

@ -26,6 +26,7 @@ import io.druid.guice.JsonConfigProvider;
import io.druid.guice.annotations.Client; import io.druid.guice.annotations.Client;
import io.druid.guice.annotations.Global; import io.druid.guice.annotations.Global;
import io.druid.initialization.DruidModule; import io.druid.initialization.DruidModule;
import io.druid.server.router.Router;
import javax.net.ssl.SSLContext; import javax.net.ssl.SSLContext;
import java.util.List; import java.util.List;
@ -46,5 +47,6 @@ public class SSLContextModule implements DruidModule
binder.bind(SSLContext.class).toProvider(SSLContextProvider.class); binder.bind(SSLContext.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Global.class).toProvider(SSLContextProvider.class); binder.bind(SSLContext.class).annotatedWith(Global.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Client.class).toProvider(SSLContextProvider.class); binder.bind(SSLContext.class).annotatedWith(Client.class).toProvider(SSLContextProvider.class);
binder.bind(SSLContext.class).annotatedWith(Router.class).toProvider(SSLContextProvider.class);
} }
} }

View File

@ -28,6 +28,7 @@ import com.google.inject.Inject;
import com.google.inject.Provider; import com.google.inject.Provider;
import com.metamx.emitter.EmittingLogger; import com.metamx.emitter.EmittingLogger;
import com.metamx.emitter.service.ServiceEmitter; import com.metamx.emitter.service.ServiceEmitter;
import io.druid.client.selector.Server;
import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Json;
import io.druid.guice.annotations.Smile; import io.druid.guice.annotations.Smile;
import io.druid.guice.http.DruidHttpClientConfig; import io.druid.guice.http.DruidHttpClientConfig;
@ -72,6 +73,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
private static final String APPLICATION_SMILE = "application/smile"; private static final String APPLICATION_SMILE = "application/smile";
private static final String HOST_ATTRIBUTE = "io.druid.proxy.to.host"; private static final String HOST_ATTRIBUTE = "io.druid.proxy.to.host";
private static final String SCHEME_ATTRIBUTE = "io.druid.proxy.to.host.scheme";
private static final String QUERY_ATTRIBUTE = "io.druid.proxy.query"; private static final String QUERY_ATTRIBUTE = "io.druid.proxy.query";
private static final String OBJECTMAPPER_ATTRIBUTE = "io.druid.proxy.objectMapper"; private static final String OBJECTMAPPER_ATTRIBUTE = "io.druid.proxy.objectMapper";
@ -169,35 +171,31 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper; final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper;
request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper); request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper);
final String defaultHost = hostFinder.getDefaultHost(); final Server defaultServer = hostFinder.getDefaultServer();
request.setAttribute(HOST_ATTRIBUTE, defaultHost); request.setAttribute(HOST_ATTRIBUTE, defaultServer.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, defaultServer.getScheme());
final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2"); final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2");
if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) { if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) {
// query cancellation request // query cancellation request
for (final String host : hostFinder.getAllHosts()) { for (final Server server: hostFinder.getAllServers()) {
// send query cancellation to all brokers this query may have gone to // send query cancellation to all brokers this query may have gone to
// to keep the code simple, the proxy servlet will also send a request to one of the default brokers // to keep the code simple, the proxy servlet will also send a request to one of the default brokers
if (!host.equals(defaultHost)) { if (!server.getHost().equals(defaultServer.getHost())) {
// issue async requests // issue async requests
broadcastClient broadcastClient
.newRequest(rewriteURI(request, host)) .newRequest(rewriteURI(request, server.getScheme(), server.getHost()))
.method(HttpMethod.DELETE) .method(HttpMethod.DELETE)
.timeout(CANCELLATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS) .timeout(CANCELLATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS)
.send( .send(
new Response.CompleteListener() result -> {
{ if (result.isFailed()) {
@Override log.warn(
public void onComplete(Result result) result.getFailure(),
{ "Failed to forward cancellation request to [%s]",
if (result.isFailed()) { server.getHost()
log.warn( );
result.getFailure(),
"Failed to forward cancellation request to [%s]",
host
);
}
} }
} }
); );
@ -209,7 +207,9 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
try { try {
Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class); Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class);
if (inputQuery != null) { if (inputQuery != null) {
request.setAttribute(HOST_ATTRIBUTE, hostFinder.getHost(inputQuery)); final Server server = hostFinder.getServer(inputQuery);
request.setAttribute(HOST_ATTRIBUTE, server.getHost());
request.setAttribute(SCHEME_ATTRIBUTE, server.getScheme());
if (inputQuery.getId() == null) { if (inputQuery.getId() == null) {
inputQuery = inputQuery.withId(UUID.randomUUID().toString()); inputQuery = inputQuery.withId(UUID.randomUUID().toString());
} }
@ -289,19 +289,19 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
@Override @Override
protected String rewriteTarget(HttpServletRequest request) protected String rewriteTarget(HttpServletRequest request)
{ {
return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE)).toString(); return rewriteURI(request, (String) request.getAttribute(SCHEME_ATTRIBUTE), (String) request.getAttribute(HOST_ATTRIBUTE)).toString();
} }
protected URI rewriteURI(HttpServletRequest request, String host) protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{ {
return makeURI(host, request.getRequestURI(), request.getQueryString()); return makeURI(scheme, host, request.getRequestURI(), request.getQueryString());
} }
protected static URI makeURI(String host, String requestURI, String rawQueryString) protected static URI makeURI(String scheme, String host, String requestURI, String rawQueryString)
{ {
try { try {
return new URI( return new URI(
"http", scheme,
host, host,
requestURI, requestURI,
rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"), rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"),

View File

@ -19,8 +19,6 @@
package io.druid.server.router; package io.druid.server.router;
import com.google.common.base.Function;
import com.google.common.collect.FluentIterable;
import com.google.inject.Inject; import com.google.inject.Inject;
import com.metamx.emitter.EmittingLogger; import com.metamx.emitter.EmittingLogger;
import io.druid.client.selector.Server; import io.druid.client.selector.Server;
@ -31,6 +29,7 @@ import io.druid.query.Query;
import java.util.Collection; import java.util.Collection;
import java.util.List; import java.util.List;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
/** /**
*/ */
@ -62,30 +61,14 @@ public class QueryHostFinder
return findServerInner(selected); return findServerInner(selected);
} }
public Collection<String> getAllHosts() public Collection<Server> getAllServers()
{ {
return FluentIterable return ((Collection<List<Server>>) hostSelector.getAllBrokers().values()).stream()
.from((Collection<List<Server>>) hostSelector.getAllBrokers().values()) .flatMap(Collection::stream)
.transformAndConcat( .collect(Collectors.toList());
new Function<List<Server>, Iterable<Server>>()
{
@Override
public Iterable<Server> apply(List<Server> input)
{
return input;
}
}
).transform(new Function<Server, String>()
{
@Override
public String apply(Server input)
{
return input.getHost();
}
}).toList();
} }
public <T> String getHost(Query<T> query) public <T> Server getServer(Query<T> query)
{ {
Server server = findServer(query); Server server = findServer(query);
@ -97,13 +80,12 @@ public class QueryHostFinder
throw new ISE("No server found for query[%s]", query); throw new ISE("No server found for query[%s]", query);
} }
final String host = server.getHost(); log.debug("Selected [%s]", server.getHost());
log.debug("Selected [%s]", host);
return host; return server;
} }
public String getDefaultHost() public Server getDefaultServer()
{ {
Server server = findDefaultServer(); Server server = findDefaultServer();
@ -115,7 +97,7 @@ public class QueryHostFinder
throw new ISE("No default server found!"); throw new ISE("No default server found!");
} }
return server.getHost(); return server;
} }
private Server findServerInner(final Pair<String, Server> selected) private Server findServerInner(final Pair<String, Server> selected)

View File

@ -29,7 +29,6 @@ import com.google.inject.Injector;
import com.google.inject.Key; import com.google.inject.Key;
import com.google.inject.Module; import com.google.inject.Module;
import com.google.inject.servlet.GuiceFilter; import com.google.inject.servlet.GuiceFilter;
import io.druid.common.utils.SocketUtil; import io.druid.common.utils.SocketUtil;
import io.druid.guice.GuiceInjectors; import io.druid.guice.GuiceInjectors;
import io.druid.guice.Jerseys; import io.druid.guice.Jerseys;
@ -108,7 +107,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
public void configure(Binder binder) public void configure(Binder binder)
{ {
JsonConfigProvider.bindInstance( JsonConfigProvider.bindInstance(
binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null, null, new ServerConfig()) binder,
Key.get(DruidNode.class, Self.class),
new DruidNode("test", "localhost", null, null, new ServerConfig())
); );
binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class); binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class);
Jerseys.addResource(binder, SlowResource.class); Jerseys.addResource(binder, SlowResource.class);
@ -197,24 +198,24 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
final QueryHostFinder hostFinder = new QueryHostFinder(null) final QueryHostFinder hostFinder = new QueryHostFinder(null)
{ {
@Override @Override
public String getHost(Query query) public io.druid.client.selector.Server getServer(Query query)
{ {
return "localhost:" + node.getPlaintextPort(); return new TestServer("http", "localhost", node.getPlaintextPort());
} }
@Override @Override
public String getDefaultHost() public io.druid.client.selector.Server getDefaultServer()
{ {
return "localhost:" + node.getPlaintextPort(); return new TestServer("http", "localhost", node.getPlaintextPort());
} }
@Override @Override
public Collection<String> getAllHosts() public Collection<io.druid.client.selector.Server> getAllServers()
{ {
return ImmutableList.of( return ImmutableList.of(
"localhost:" + node.getPlaintextPort(), new TestServer("http", "localhost", node.getPlaintextPort()),
"localhost:" + port1, new TestServer("http", "localhost", port1),
"localhost:" + port2 new TestServer("http", "localhost", port2)
); );
} }
}; };
@ -241,9 +242,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
) )
{ {
@Override @Override
protected URI rewriteURI(HttpServletRequest request, String host) protected URI rewriteURI(HttpServletRequest request, String scheme, String host)
{ {
String uri = super.rewriteURI(request, host).toString(); String uri = super.rewriteURI(request, scheme, host).toString();
if (uri.contains("/druid/v2")) { if (uri.contains("/druid/v2")) {
return URI.create(uri.replace("/druid/v2", "/default")); return URI.create(uri.replace("/druid/v2", "/default"));
} }
@ -272,7 +273,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
// test params // test params
Assert.assertEquals( Assert.assertEquals(
new URI("http://localhost:1234/some/path?param=1"), new URI("http://localhost:1234/some/path?param=1"),
AsyncQueryForwardingServlet.makeURI("localhost:1234", "/some/path", "param=1") AsyncQueryForwardingServlet.makeURI("http", "localhost:1234", "/some/path", "param=1")
); );
// HttpServletRequest.getQueryString returns encoded form // HttpServletRequest.getQueryString returns encoded form
@ -280,6 +281,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
Assert.assertEquals( Assert.assertEquals(
"http://[2a00:1450:4007:805::1007]:1234/some/path?param=1&param2=%E2%82%AC", "http://[2a00:1450:4007:805::1007]:1234/some/path?param=1&param2=%E2%82%AC",
AsyncQueryForwardingServlet.makeURI( AsyncQueryForwardingServlet.makeURI(
"http",
HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(), HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(),
"/some/path", "/some/path",
"param=1&param2=%E2%82%AC" "param=1&param2=%E2%82%AC"
@ -289,7 +291,46 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
// test null query // test null query
Assert.assertEquals( Assert.assertEquals(
new URI("http://localhost/"), new URI("http://localhost/"),
AsyncQueryForwardingServlet.makeURI("localhost", "/", null) AsyncQueryForwardingServlet.makeURI("http", "localhost", "/", null)
); );
} }
private static class TestServer implements io.druid.client.selector.Server
{
private final String scheme;
private final String address;
private final int port;
public TestServer(String scheme, String address, int port)
{
this.scheme = scheme;
this.address = address;
this.port = port;
}
@Override
public String getScheme()
{
return scheme;
}
@Override
public String getHost()
{
return address + ":" + port;
}
@Override
public String getAddress()
{
return address;
}
@Override
public int getPort()
{
return port;
}
}
} }