diff --git a/server/src/main/java/io/druid/curator/discovery/ServerDiscoverySelector.java b/server/src/main/java/io/druid/curator/discovery/ServerDiscoverySelector.java index c94cd6cb422..a52a4a2cdd9 100644 --- a/server/src/main/java/io/druid/curator/discovery/ServerDiscoverySelector.java +++ b/server/src/main/java/io/druid/curator/discovery/ServerDiscoverySelector.java @@ -17,6 +17,8 @@ package io.druid.curator.discovery; +import com.google.common.base.Function; +import com.google.common.collect.Collections2; import com.google.common.net.HostAndPort; import com.metamx.common.lifecycle.LifecycleStart; import com.metamx.common.lifecycle.LifecycleStop; @@ -27,6 +29,8 @@ import org.apache.curator.x.discovery.ServiceInstance; import org.apache.curator.x.discovery.ServiceProvider; import java.io.IOException; +import java.util.Collection; +import java.util.Collections; /** */ @@ -41,6 +45,40 @@ public class ServerDiscoverySelector implements DiscoverySelector this.serviceProvider = serviceProvider; } + private static final Function TO_SERVER = new Function() + { + @Override + public Server apply(final ServiceInstance instance) + { + return new Server() + { + @Override + public String getHost() + { + return HostAndPort.fromParts(getAddress(), getPort()).toString(); + } + + @Override + public String getAddress() + { + return instance.getAddress(); + } + + @Override + public int getPort() + { + return instance.getPort(); + } + + @Override + public String getScheme() + { + return "http"; + } + }; + } + }; + @Override public Server pick() { @@ -58,32 +96,18 @@ public class ServerDiscoverySelector implements DiscoverySelector return null; } - return new Server() - { - @Override - public String getHost() - { - return HostAndPort.fromParts(getAddress(), getPort()).toString(); - } + return TO_SERVER.apply(instance); + } - @Override - public String getAddress() - { - return instance.getAddress(); - } - - @Override - public int getPort() - { - return instance.getPort(); - } - - @Override - public String getScheme() - { - return "http"; - } - }; + public Collection getAll() + { + try { + return Collections2.transform(serviceProvider.getAllInstances(), TO_SERVER); + } + catch (Exception e) { + log.info(e, "Unable to get all instances"); + return Collections.emptyList(); + } } @LifecycleStart diff --git a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java index 0d1774cf753..d89cefd83c4 100644 --- a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java +++ b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java @@ -47,7 +47,10 @@ import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.MediaType; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.net.URI; +import java.net.URISyntaxException; +import java.net.URLDecoder; import java.util.UUID; import java.util.concurrent.TimeUnit; @@ -110,18 +113,46 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet @Override protected void service(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - final boolean isSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(request.getContentType()) || APPLICATION_SMILE.equals(request.getContentType()); + final boolean isSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(request.getContentType()) + || APPLICATION_SMILE.equals(request.getContentType()); final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper; request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper); - String host = hostFinder.getDefaultHost(); - request.setAttribute(HOST_ATTRIBUTE, host); + final String defaultHost = hostFinder.getDefaultHost(); + request.setAttribute(HOST_ATTRIBUTE, defaultHost); - boolean isQuery = request.getMethod().equals(HttpMethod.POST.asString()) && - request.getRequestURI().startsWith("/druid/v2"); + final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2"); - // queries only exist for POST - if (isQuery) { + if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) { + // query cancellation request + for (final String host : hostFinder.getAllHosts()) { + // send query cancellation to all brokers this query may have gone to + // to keep the code simple, the proxy servlet will also send a request to one of the default brokers + if (!host.equals(defaultHost)) { + // issue async requests + getHttpClient() + .newRequest(rewriteURI(request, host)) + .method(HttpMethod.DELETE) + .send( + new Response.CompleteListener() + { + @Override + public void onComplete(Result result) + { + if (result.isFailed()) { + log.warn( + result.getFailure(), + "Failed to forward cancellation request to [%s]", + host + ); + } + } + } + ); + } + } + } else if (isQueryEndpoint && HttpMethod.POST.is(request.getMethod())) { + // query request try { Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class); if (inputQuery != null) { @@ -172,7 +203,8 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet final ObjectMapper objectMapper = (ObjectMapper) request.getAttribute(OBJECTMAPPER_ATTRIBUTE); try { proxyRequest.content(new BytesContentProvider(objectMapper.writeValueAsBytes(query))); - } catch(JsonProcessingException e) { + } + catch (JsonProcessingException e) { Throwables.propagate(e); } } @@ -194,16 +226,29 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet @Override protected URI rewriteURI(HttpServletRequest request) { - final String host = (String) request.getAttribute(HOST_ATTRIBUTE); - final StringBuilder uri = new StringBuilder("http://"); + return rewriteURI(request, (String) request.getAttribute(HOST_ATTRIBUTE)); + } - uri.append(host); - uri.append(request.getRequestURI()); - final String queryString = request.getQueryString(); - if (queryString != null) { - uri.append("?").append(queryString); + protected URI rewriteURI(HttpServletRequest request, String host) + { + return makeURI(host, request.getRequestURI(), request.getQueryString()); + } + + protected static URI makeURI(String host, String requestURI, String rawQueryString) + { + try { + return new URI( + "http", + host, + requestURI, + rawQueryString == null ? null : URLDecoder.decode(rawQueryString, "UTF-8"), + null + ); + } + catch (UnsupportedEncodingException | URISyntaxException e) { + log.error(e, "Unable to rewrite URI [%s]", e.getMessage()); + throw Throwables.propagate(e); } - return URI.create(uri.toString()); } @Override @@ -261,7 +306,7 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet try { emitter.emit( DruidMetrics.makeQueryTimeMetric(jsonMapper, query, req.getRemoteAddr()) - .build("query/time", requestTime) + .build("query/time", requestTime) ); requestLogger.log( diff --git a/server/src/main/java/io/druid/server/router/QueryHostFinder.java b/server/src/main/java/io/druid/server/router/QueryHostFinder.java index 9c5e854bffb..cacbb76a7e3 100644 --- a/server/src/main/java/io/druid/server/router/QueryHostFinder.java +++ b/server/src/main/java/io/druid/server/router/QueryHostFinder.java @@ -17,6 +17,8 @@ package io.druid.server.router; +import com.google.common.base.Function; +import com.google.common.collect.FluentIterable; import com.google.inject.Inject; import com.metamx.common.ISE; import com.metamx.common.Pair; @@ -25,17 +27,18 @@ import io.druid.client.selector.Server; import io.druid.curator.discovery.ServerDiscoverySelector; import io.druid.query.Query; +import java.util.Collection; import java.util.concurrent.ConcurrentHashMap; /** */ -public class QueryHostFinder +public class QueryHostFinder { private static EmittingLogger log = new EmittingLogger(QueryHostFinder.class); private final TieredBrokerHostSelector hostSelector; - private final ConcurrentHashMap serverBackup = new ConcurrentHashMap(); + private final ConcurrentHashMap serverBackup = new ConcurrentHashMap<>(); @Inject public QueryHostFinder( @@ -45,7 +48,7 @@ public class QueryHostFinder this.hostSelector = hostSelector; } - public Server findServer(Query query) + public Server findServer(Query query) { final Pair selected = hostSelector.select(query); return findServerInner(selected); @@ -57,7 +60,30 @@ public class QueryHostFinder return findServerInner(selected); } - public String getHost(Query query) + public Collection getAllHosts() + { + return FluentIterable + .from((Collection) hostSelector.getAllBrokers().values()) + .transformAndConcat( + new Function>() + { + @Override + public Iterable apply(ServerDiscoverySelector input) + { + return input.getAll(); + } + } + ).transform(new Function() + { + @Override + public String apply(Server input) + { + return input.getHost(); + } + }).toList(); + } + + public String getHost(Query query) { Server server = findServer(query); @@ -69,9 +95,10 @@ public class QueryHostFinder throw new ISE("No server found for query[%s]", query); } - log.debug("Selected [%s]", server.getHost()); + final String host = server.getHost(); + log.debug("Selected [%s]", host); - return server.getHost(); + return host; } public String getDefaultHost() diff --git a/server/src/main/java/io/druid/server/router/TieredBrokerHostSelector.java b/server/src/main/java/io/druid/server/router/TieredBrokerHostSelector.java index f061b5c53ae..e8ceb90ace0 100644 --- a/server/src/main/java/io/druid/server/router/TieredBrokerHostSelector.java +++ b/server/src/main/java/io/druid/server/router/TieredBrokerHostSelector.java @@ -34,6 +34,7 @@ import io.druid.server.coordinator.rules.Rule; import org.joda.time.DateTime; import org.joda.time.Interval; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -201,4 +202,9 @@ public class TieredBrokerHostSelector implements HostSelector final ServerDiscoverySelector retVal = selectorMap.get(brokerServiceName); return new Pair<>(brokerServiceName, retVal); } + + public Map getAllBrokers() + { + return Collections.unmodifiableMap(selectorMap); + } } diff --git a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java index a9e3191efe5..23e41fd93a3 100644 --- a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java +++ b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java @@ -21,6 +21,7 @@ package io.druid.server; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; +import com.google.common.net.HostAndPort; import com.google.inject.Binder; import com.google.inject.Inject; import com.google.inject.Injector; @@ -49,16 +50,22 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.handler.HandlerList; import org.eclipse.jetty.servlet.DefaultServlet; import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.net.HttpURLConnection; import java.net.URI; import java.net.URL; +import java.util.Collection; +import java.util.concurrent.CountDownLatch; public class AsyncQueryForwardingServletTest extends BaseJettyTest { @@ -122,6 +129,40 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest Assert.assertNotEquals("gzip", postNoGzip.getContentEncoding()); } + @Test(timeout = 60_000) + public void testDeleteBroadcast() throws Exception + { + CountDownLatch latch = new CountDownLatch(2); + makeTestDeleteServer(port + 1, latch).start(); + makeTestDeleteServer(port + 2, latch).start(); + + final URL url = new URL("http://localhost:" + port + "/druid/v2/abc123"); + final HttpURLConnection post = (HttpURLConnection) url.openConnection(); + post.setRequestMethod("DELETE"); + int code = post.getResponseCode(); + Assert.assertEquals(200, code); + + latch.await(); + } + + private static Server makeTestDeleteServer(int port, final CountDownLatch latch) + { + Server server = new Server(port); + ServletHandler handler = new ServletHandler(); + handler.addServletWithMapping(new ServletHolder(new HttpServlet() + { + @Override + protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException + { + latch.countDown(); + resp.setStatus(200); + } + }), "/default/*"); + + server.setHandler(handler); + return server; + } + public static class ProxyJettyServerInit implements JettyServerInitializer { @@ -152,6 +193,16 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest { return "localhost:" + node.getPort(); } + + @Override + public Collection getAllHosts() + { + return ImmutableList.of( + "localhost:" + node.getPort(), + "localhost:" + (node.getPort() + 1), + "localhost:" + (node.getPort() + 2) + ); + } }; ServletHolder holder = new ServletHolder( @@ -173,15 +224,19 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ) { @Override - protected URI rewriteURI(HttpServletRequest request) + protected URI rewriteURI(HttpServletRequest request, String host) { - URI uri = super.rewriteURI(request); - return URI.create(uri.toString().replace("/proxy", "")); + String uri = super.rewriteURI(request, host).toString(); + if (uri.contains("/druid/v2")) { + return URI.create(uri.replace("/druid/v2", "/default")); + } + return URI.create(uri.replace("/proxy", "")); } }); //NOTE: explicit maxThreads to workaround https://tickets.puppetlabs.com/browse/TK-152 holder.setInitParameter("maxThreads", "256"); root.addServlet(holder, "/proxy/*"); + root.addServlet(holder, "/druid/v2/*"); JettyServerInitUtils.addExtensionFilters(root, injector); root.addFilter(JettyServerInitUtils.defaultAsyncGzipFilterHolder(), "/*", null); root.addFilter(GuiceFilter.class, "/slow/*", null); @@ -193,4 +248,32 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest server.setHandler(handlerList); } } + + @Test + public void testRewriteURI() throws Exception + { + + // test params + Assert.assertEquals( + new URI("http://localhost:1234/some/path?param=1"), + AsyncQueryForwardingServlet.makeURI("localhost:1234", "/some/path", "param=1") + ); + + // HttpServletRequest.getQueryString returns encoded form + // use ascii representation in case URI is using non-ascii characters + Assert.assertEquals( + "http://[2a00:1450:4007:805::1007]:1234/some/path?param=1¶m2=%E2%82%AC", + AsyncQueryForwardingServlet.makeURI( + HostAndPort.fromParts("2a00:1450:4007:805::1007", 1234).toString(), + "/some/path", + "param=1¶m2=%E2%82%AC" + ).toASCIIString() + ); + + // test null query + Assert.assertEquals( + new URI("http://localhost/"), + AsyncQueryForwardingServlet.makeURI("localhost", "/", null) + ); + } } diff --git a/server/src/test/java/io/druid/server/initialization/BaseJettyTest.java b/server/src/test/java/io/druid/server/initialization/BaseJettyTest.java index 9a5d6f4603b..fc4867a25af 100644 --- a/server/src/test/java/io/druid/server/initialization/BaseJettyTest.java +++ b/server/src/test/java/io/druid/server/initialization/BaseJettyTest.java @@ -64,6 +64,7 @@ import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.DELETE; import javax.ws.rs.GET; import javax.ws.rs.POST; import javax.ws.rs.Path; @@ -236,6 +237,14 @@ public class BaseJettyTest @Path("/default") public static class DefaultResource { + @DELETE + @Path("{resource}") + @Produces(MediaType.APPLICATION_JSON) + public Response delete() + { + return Response.ok("hello").build(); + } + @GET @Produces(MediaType.APPLICATION_JSON) public Response get()