From 594a66f3c012bcab980232bfd2c9b37a7b8d2388 Mon Sep 17 00:00:00 2001 From: Parag Jain Date: Mon, 28 Aug 2017 17:03:43 -0500 Subject: [PATCH] add scheme to AsyncQueryForwardingServlet (#4688) * add scheme to AsyncQueryForwardingServlet * add sslContext binding for Router --- .../java/io/druid/https/SSLContextModule.java | 2 + .../server/AsyncQueryForwardingServlet.java | 46 ++++++------- .../druid/server/router/QueryHostFinder.java | 38 +++------- .../AsyncQueryForwardingServletTest.java | 69 +++++++++++++++---- 4 files changed, 90 insertions(+), 65 deletions(-) diff --git a/extensions-core/simple-client-sslcontext/src/main/java/io/druid/https/SSLContextModule.java b/extensions-core/simple-client-sslcontext/src/main/java/io/druid/https/SSLContextModule.java index 9705ee90954..9de19dbbd3e 100644 --- a/extensions-core/simple-client-sslcontext/src/main/java/io/druid/https/SSLContextModule.java +++ b/extensions-core/simple-client-sslcontext/src/main/java/io/druid/https/SSLContextModule.java @@ -26,6 +26,7 @@ import io.druid.guice.JsonConfigProvider; import io.druid.guice.annotations.Client; import io.druid.guice.annotations.Global; import io.druid.initialization.DruidModule; +import io.druid.server.router.Router; import javax.net.ssl.SSLContext; import java.util.List; @@ -46,5 +47,6 @@ public class SSLContextModule implements DruidModule binder.bind(SSLContext.class).toProvider(SSLContextProvider.class); binder.bind(SSLContext.class).annotatedWith(Global.class).toProvider(SSLContextProvider.class); binder.bind(SSLContext.class).annotatedWith(Client.class).toProvider(SSLContextProvider.class); + binder.bind(SSLContext.class).annotatedWith(Router.class).toProvider(SSLContextProvider.class); } } diff --git a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java index 84f4485dcf2..849dfd82e68 100644 --- a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java +++ b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java @@ -28,6 +28,7 @@ import com.google.inject.Inject; import com.google.inject.Provider; import com.metamx.emitter.EmittingLogger; import com.metamx.emitter.service.ServiceEmitter; +import io.druid.client.selector.Server; import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Smile; import io.druid.guice.http.DruidHttpClientConfig; @@ -72,6 +73,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu private static final String APPLICATION_SMILE = "application/smile"; private static final String HOST_ATTRIBUTE = "io.druid.proxy.to.host"; + private static final String SCHEME_ATTRIBUTE = "io.druid.proxy.to.host.scheme"; private static final String QUERY_ATTRIBUTE = "io.druid.proxy.query"; private static final String OBJECTMAPPER_ATTRIBUTE = "io.druid.proxy.objectMapper"; @@ -169,35 +171,31 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper; request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper); - final String defaultHost = hostFinder.getDefaultHost(); - request.setAttribute(HOST_ATTRIBUTE, defaultHost); + final Server defaultServer = hostFinder.getDefaultServer(); + request.setAttribute(HOST_ATTRIBUTE, defaultServer.getHost()); + request.setAttribute(SCHEME_ATTRIBUTE, defaultServer.getScheme()); final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2"); if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) { // query cancellation request - for (final String host : hostFinder.getAllHosts()) { + for (final Server server: hostFinder.getAllServers()) { // send query cancellation to all brokers this query may have gone to // to keep the code simple, the proxy servlet will also send a request to one of the default brokers - if (!host.equals(defaultHost)) { + if (!server.getHost().equals(defaultServer.getHost())) { // issue async requests broadcastClient - .newRequest(rewriteURI(request, host)) + .newRequest(rewriteURI(request, server.getScheme(), server.getHost())) .method(HttpMethod.DELETE) .timeout(CANCELLATION_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS) .send( - new Response.CompleteListener() - { - @Override - public void onComplete(Result result) - { - if (result.isFailed()) { - log.warn( - result.getFailure(), - "Failed to forward cancellation request to [%s]", - host - ); - } + result -> { + if (result.isFailed()) { + log.warn( + result.getFailure(), + "Failed to forward cancellation request to [%s]", + server.getHost() + ); } } ); @@ -209,7 +207,9 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu try { Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class); if (inputQuery != null) { - request.setAttribute(HOST_ATTRIBUTE, hostFinder.getHost(inputQuery)); + final Server server = hostFinder.getServer(inputQuery); + request.setAttribute(HOST_ATTRIBUTE, server.getHost()); + request.setAttribute(SCHEME_ATTRIBUTE, server.getScheme()); if (inputQuery.getId() == null) { inputQuery = inputQuery.withId(UUID.randomUUID().toString()); } @@ -289,19 +289,19 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu @Override protected String rewriteTarget(HttpServletRequest request) { - return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE)).toString(); + return rewriteURI(request, (String) request.getAttribute(SCHEME_ATTRIBUTE), (String) request.getAttribute(HOST_ATTRIBUTE)).toString(); } - protected URI rewriteURI(HttpServletRequest request, String host) + protected URI rewriteURI(HttpServletRequest request, String scheme, String host) { - return makeURI(host, request.getRequestURI(), request.getQueryString()); + return makeURI(scheme, host, request.getRequestURI(), request.getQueryString()); } - protected static URI makeURI(String host, String requestURI, String rawQueryString) + protected static URI makeURI(String scheme, String host, String requestURI, String rawQueryString) { try { return new URI( - "http", + scheme, host, requestURI, rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"), diff --git a/server/src/main/java/io/druid/server/router/QueryHostFinder.java b/server/src/main/java/io/druid/server/router/QueryHostFinder.java index 1d8f14087ee..244bd523583 100644 --- a/server/src/main/java/io/druid/server/router/QueryHostFinder.java +++ b/server/src/main/java/io/druid/server/router/QueryHostFinder.java @@ -19,8 +19,6 @@ package io.druid.server.router; -import com.google.common.base.Function; -import com.google.common.collect.FluentIterable; import com.google.inject.Inject; import com.metamx.emitter.EmittingLogger; import io.druid.client.selector.Server; @@ -31,6 +29,7 @@ import io.druid.query.Query; import java.util.Collection; import java.util.List; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; /** */ @@ -62,30 +61,14 @@ public class QueryHostFinder return findServerInner(selected); } - public Collection getAllHosts() + public Collection getAllServers() { - return FluentIterable - .from((Collection>) hostSelector.getAllBrokers().values()) - .transformAndConcat( - new Function, Iterable>() - { - @Override - public Iterable apply(List input) - { - return input; - } - } - ).transform(new Function() - { - @Override - public String apply(Server input) - { - return input.getHost(); - } - }).toList(); + return ((Collection>) hostSelector.getAllBrokers().values()).stream() + .flatMap(Collection::stream) + .collect(Collectors.toList()); } - public String getHost(Query query) + public Server getServer(Query query) { Server server = findServer(query); @@ -97,13 +80,12 @@ public class QueryHostFinder throw new ISE("No server found for query[%s]", query); } - final String host = server.getHost(); - log.debug("Selected [%s]", host); + log.debug("Selected [%s]", server.getHost()); - return host; + return server; } - public String getDefaultHost() + public Server getDefaultServer() { Server server = findDefaultServer(); @@ -115,7 +97,7 @@ public class QueryHostFinder throw new ISE("No default server found!"); } - return server.getHost(); + return server; } private Server findServerInner(final Pair selected) diff --git a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java index 60b5ffd28b8..a3ca079f9cc 100644 --- a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java +++ b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java @@ -29,7 +29,6 @@ import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.servlet.GuiceFilter; - import io.druid.common.utils.SocketUtil; import io.druid.guice.GuiceInjectors; import io.druid.guice.Jerseys; @@ -108,7 +107,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest public void configure(Binder binder) { JsonConfigProvider.bindInstance( - binder, Key.get(DruidNode.class, Self.class), new DruidNode("test", "localhost", null, null, new ServerConfig()) + binder, + Key.get(DruidNode.class, Self.class), + new DruidNode("test", "localhost", null, null, new ServerConfig()) ); binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class); Jerseys.addResource(binder, SlowResource.class); @@ -197,24 +198,24 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest final QueryHostFinder hostFinder = new QueryHostFinder(null) { @Override - public String getHost(Query query) + public io.druid.client.selector.Server getServer(Query query) { - return "localhost:" + node.getPlaintextPort(); + return new TestServer("http", "localhost", node.getPlaintextPort()); } @Override - public String getDefaultHost() + public io.druid.client.selector.Server getDefaultServer() { - return "localhost:" + node.getPlaintextPort(); + return new TestServer("http", "localhost", node.getPlaintextPort()); } @Override - public Collection getAllHosts() + public Collection getAllServers() { return ImmutableList.of( - "localhost:" + node.getPlaintextPort(), - "localhost:" + port1, - "localhost:" + port2 + new TestServer("http", "localhost", node.getPlaintextPort()), + new TestServer("http", "localhost", port1), + new TestServer("http", "localhost", port2) ); } }; @@ -241,9 +242,9 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ) { @Override - protected URI rewriteURI(HttpServletRequest request, String host) + protected URI rewriteURI(HttpServletRequest request, String scheme, String host) { - String uri = super.rewriteURI(request, host).toString(); + String uri = super.rewriteURI(request, scheme, host).toString(); if (uri.contains("/druid/v2")) { return URI.create(uri.replace("/druid/v2", "/default")); } @@ -272,7 +273,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest // test params Assert.assertEquals( new URI("http://localhost:1234/some/path?param=1"), - AsyncQueryForwardingServlet.makeURI("localhost:1234", "/some/path", "param=1") + AsyncQueryForwardingServlet.makeURI("http", "localhost:1234", "/some/path", "param=1") ); // HttpServletRequest.getQueryString returns encoded form @@ -280,6 +281,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest Assert.assertEquals( "http://[2a00:1450:4007:805::1007]:1234/some/path?param=1¶m2=%E2%82%AC", AsyncQueryForwardingServlet.makeURI( + "http", HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(), "/some/path", "param=1¶m2=%E2%82%AC" @@ -289,7 +291,46 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest // test null query Assert.assertEquals( new URI("http://localhost/"), - AsyncQueryForwardingServlet.makeURI("localhost", "/", null) + AsyncQueryForwardingServlet.makeURI("http", "localhost", "/", null) ); } + + private static class TestServer implements io.druid.client.selector.Server + { + + private final String scheme; + private final String address; + private final int port; + + public TestServer(String scheme, String address, int port) + { + this.scheme = scheme; + this.address = address; + this.port = port; + } + + @Override + public String getScheme() + { + return scheme; + } + + @Override + public String getHost() + { + return address + ":" + port; + } + + @Override + public String getAddress() + { + return address; + } + + @Override + public int getPort() + { + return port; + } + } }