diff --git a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java index 0d0615393b7..5379effff7a 100644 --- a/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java +++ b/processing/src/test/java/io/druid/query/groupby/GroupByQueryRunnerTest.java @@ -154,7 +154,7 @@ import java.util.concurrent.Executors; @RunWith(Parameterized.class) public class GroupByQueryRunnerTest { - public static final ObjectMapper DEFAULT_MAPPER = TestHelper.getSmileMapper(); + public static final ObjectMapper DEFAULT_MAPPER = TestHelper.makeSmileMapper(); public static final DruidProcessingConfig DEFAULT_PROCESSING_CONFIG = new DruidProcessingConfig() { @Override diff --git a/processing/src/test/java/io/druid/segment/TestHelper.java b/processing/src/test/java/io/druid/segment/TestHelper.java index 976aa9fecfd..160faee74bb 100644 --- a/processing/src/test/java/io/druid/segment/TestHelper.java +++ b/processing/src/test/java/io/druid/segment/TestHelper.java @@ -82,7 +82,7 @@ public class TestHelper return mapper; } - public static ObjectMapper getSmileMapper() + public static ObjectMapper makeSmileMapper() { final ObjectMapper mapper = new DefaultObjectMapper(); mapper.setInjectableValues( diff --git a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java index 689f3f7456d..927d9a3a214 100644 --- a/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java +++ b/server/src/main/java/io/druid/server/AsyncQueryForwardingServlet.java @@ -26,14 +26,15 @@ import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import com.google.inject.Provider; -import io.druid.java.util.emitter.EmittingLogger; -import io.druid.java.util.emitter.service.ServiceEmitter; import io.druid.client.selector.Server; import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Smile; import io.druid.guice.http.DruidHttpClientConfig; import io.druid.java.util.common.DateTimes; import io.druid.java.util.common.IAE; +import io.druid.java.util.common.jackson.JacksonUtils; +import io.druid.java.util.emitter.EmittingLogger; +import io.druid.java.util.emitter.service.ServiceEmitter; import io.druid.query.DruidMetrics; import io.druid.query.GenericQueryMetricsFactory; import io.druid.query.Query; @@ -174,31 +175,34 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu final ObjectMapper objectMapper = isSmile ? smileMapper : jsonMapper; request.setAttribute(OBJECTMAPPER_ATTRIBUTE, objectMapper); - final Server defaultServer = hostFinder.getDefaultServer(); - request.setAttribute(HOST_ATTRIBUTE, defaultServer.getHost()); - request.setAttribute(SCHEME_ATTRIBUTE, defaultServer.getScheme()); + final String requestURI = request.getRequestURI(); + final String method = request.getMethod(); + final Server targetServer; // The Router does not have the ability to look inside SQL queries and route them intelligently, so just treat // them as a generic request. - final boolean isQueryEndpoint = request.getRequestURI().startsWith("/druid/v2") - && !request.getRequestURI().startsWith("/druid/v2/sql"); + final boolean isQueryEndpoint = requestURI.startsWith("/druid/v2") + && !requestURI.startsWith("/druid/v2/sql"); - final boolean isAvatica = request.getRequestURI().startsWith("/druid/v2/sql/avatica"); + final boolean isAvatica = requestURI.startsWith("/druid/v2/sql/avatica"); if (isAvatica) { - Map requestMap = objectMapper.readValue(request.getInputStream(), Map.class); + Map requestMap = objectMapper.readValue( + request.getInputStream(), + JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT + ); String connectionId = getAvaticaConnectionId(requestMap); - Server targetServer = hostFinder.findServerAvatica(connectionId); + targetServer = hostFinder.findServerAvatica(connectionId); byte[] requestBytes = objectMapper.writeValueAsBytes(requestMap); - request.setAttribute(HOST_ATTRIBUTE, targetServer.getHost()); - request.setAttribute(SCHEME_ATTRIBUTE, targetServer.getScheme()); request.setAttribute(AVATICA_QUERY_ATTRIBUTE, requestBytes); - } else if (isQueryEndpoint && HttpMethod.DELETE.is(request.getMethod())) { + } else if (isQueryEndpoint && HttpMethod.DELETE.is(method)) { // query cancellation request + targetServer = hostFinder.pickDefaultServer(); + for (final Server server : hostFinder.getAllServers()) { // send query cancellation to all brokers this query may have gone to - // to keep the code simple, the proxy servlet will also send a request to one of the default brokers - if (!server.getHost().equals(defaultServer.getHost())) { + // to keep the code simple, the proxy servlet will also send a request to the default targetServer. + if (!server.getHost().equals(targetServer.getHost())) { // issue async requests Response.CompleteListener completeListener = result -> { if (result.isFailed()) { @@ -220,17 +224,17 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu } interruptedQueryCount.incrementAndGet(); } - } else if (isQueryEndpoint && HttpMethod.POST.is(request.getMethod())) { + } else if (isQueryEndpoint && HttpMethod.POST.is(method)) { // query request try { Query inputQuery = objectMapper.readValue(request.getInputStream(), Query.class); if (inputQuery != null) { - final Server server = hostFinder.getServer(inputQuery); - request.setAttribute(HOST_ATTRIBUTE, server.getHost()); - request.setAttribute(SCHEME_ATTRIBUTE, server.getScheme()); + targetServer = hostFinder.pickServer(inputQuery); if (inputQuery.getId() == null) { inputQuery = inputQuery.withId(UUID.randomUUID().toString()); } + } else { + targetServer = hostFinder.pickDefaultServer(); } request.setAttribute(QUERY_ATTRIBUTE, inputQuery); } @@ -258,8 +262,22 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu handleException(response, objectMapper, e); return; } + } else { + targetServer = hostFinder.pickDefaultServer(); } + request.setAttribute(HOST_ATTRIBUTE, targetServer.getHost()); + request.setAttribute(SCHEME_ATTRIBUTE, targetServer.getScheme()); + + doService(request, response); + } + + protected void doService( + HttpServletRequest request, + HttpServletResponse response + ) throws ServletException, IOException + { + // Just call the superclass service method. Overriden in tests. super.service(request, response); } @@ -318,7 +336,11 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu @Override protected String rewriteTarget(HttpServletRequest request) { - return rewriteURI(request, (String) request.getAttribute(SCHEME_ATTRIBUTE), (String) request.getAttribute(HOST_ATTRIBUTE)).toString(); + return rewriteURI( + request, + (String) request.getAttribute(SCHEME_ATTRIBUTE), + (String) request.getAttribute(HOST_ATTRIBUTE) + ).toString(); } protected URI rewriteURI(HttpServletRequest request, String scheme, String host) diff --git a/server/src/main/java/io/druid/server/router/QueryHostFinder.java b/server/src/main/java/io/druid/server/router/QueryHostFinder.java index 527fd6340d6..fa077ef2d54 100644 --- a/server/src/main/java/io/druid/server/router/QueryHostFinder.java +++ b/server/src/main/java/io/druid/server/router/QueryHostFinder.java @@ -20,10 +20,10 @@ package io.druid.server.router; import com.google.inject.Inject; -import io.druid.java.util.emitter.EmittingLogger; import io.druid.client.selector.Server; import io.druid.java.util.common.ISE; import io.druid.java.util.common.Pair; +import io.druid.java.util.emitter.EmittingLogger; import io.druid.query.Query; import java.util.Collection; @@ -90,7 +90,7 @@ public class QueryHostFinder return chosenServer; } - public Server getServer(Query query) + public Server pickServer(Query query) { Server server = findServer(query); @@ -107,7 +107,7 @@ public class QueryHostFinder return server; } - public Server getDefaultServer() + public Server pickDefaultServer() { Server server = findDefaultServer(); diff --git a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java index 4118e7ae4f9..c587ad8759e 100644 --- a/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java +++ b/server/src/test/java/io/druid/server/AsyncQueryForwardingServletTest.java @@ -39,21 +39,24 @@ import io.druid.guice.annotations.Self; import io.druid.guice.annotations.Smile; import io.druid.guice.http.DruidHttpClientConfig; import io.druid.initialization.Initialization; +import io.druid.java.util.common.granularity.Granularities; import io.druid.java.util.common.lifecycle.Lifecycle; import io.druid.query.DefaultGenericQueryMetricsFactory; +import io.druid.query.Druids; import io.druid.query.MapQueryToolChestWarehouse; import io.druid.query.Query; -import io.druid.query.QueryToolChest; +import io.druid.query.timeseries.TimeseriesQuery; +import io.druid.segment.TestHelper; import io.druid.server.initialization.BaseJettyTest; import io.druid.server.initialization.jetty.JettyServerInitUtils; import io.druid.server.initialization.jetty.JettyServerInitializer; -import io.druid.server.log.RequestLogger; import io.druid.server.metrics.NoopServiceEmitter; import io.druid.server.router.QueryHostFinder; import io.druid.server.router.RendezvousHashAvaticaConnectionBalancer; import io.druid.server.security.AllowAllAuthorizer; import io.druid.server.security.Authorizer; import io.druid.server.security.AuthorizerMapper; +import org.easymock.EasyMock; import org.eclipse.jetty.client.HttpClient; import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Server; @@ -66,16 +69,20 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.ReadListener; import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.net.HttpURLConnection; import java.net.URI; import java.net.URL; import java.util.Collection; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; public class AsyncQueryForwardingServletTest extends BaseJettyTest { @@ -116,7 +123,8 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ); binder.bind(JettyServerInitializer.class).to(ProxyJettyServerInit.class).in(LazySingleton.class); binder.bind(AuthorizerMapper.class).toInstance( - new AuthorizerMapper(null) { + new AuthorizerMapper(null) + { @Override public Authorizer getAuthorizer(String name) @@ -173,6 +181,97 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest latch.await(); } + @Test + public void testQueryProxy() throws Exception + { + final ObjectMapper jsonMapper = TestHelper.makeJsonMapper(); + final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder() + .dataSource("foo") + .intervals("2000/P1D") + .granularity(Granularities.ALL) + .context(ImmutableMap.of("queryId", "dummy")) + .build(); + + final QueryHostFinder hostFinder = EasyMock.createMock(QueryHostFinder.class); + EasyMock.expect(hostFinder.pickServer(query)).andReturn(new TestServer("http", "1.2.3.4", 9999)).once(); + EasyMock.replay(hostFinder); + + final HttpServletRequest requestMock = EasyMock.createMock(HttpServletRequest.class); + final ByteArrayInputStream inputStream = new ByteArrayInputStream(jsonMapper.writeValueAsBytes(query)); + final ServletInputStream servletInputStream = new ServletInputStream() + { + private boolean finished; + + @Override + public boolean isFinished() + { + return finished; + } + + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setReadListener(final ReadListener readListener) + { + // do nothing + } + + @Override + public int read() + { + final int b = inputStream.read(); + if (b < 0) { + finished = true; + } + return b; + } + }; + EasyMock.expect(requestMock.getContentType()).andReturn("application/json").times(2); + requestMock.setAttribute("io.druid.proxy.objectMapper", jsonMapper); + EasyMock.expectLastCall(); + EasyMock.expect(requestMock.getRequestURI()).andReturn("/druid/v2/"); + EasyMock.expect(requestMock.getMethod()).andReturn("POST"); + EasyMock.expect(requestMock.getInputStream()).andReturn(servletInputStream); + requestMock.setAttribute("io.druid.proxy.query", query); + requestMock.setAttribute("io.druid.proxy.to.host", "1.2.3.4:9999"); + requestMock.setAttribute("io.druid.proxy.to.host.scheme", "http"); + EasyMock.expectLastCall(); + EasyMock.replay(requestMock); + + final AtomicLong didService = new AtomicLong(); + final AsyncQueryForwardingServlet servlet = new AsyncQueryForwardingServlet( + new MapQueryToolChestWarehouse(ImmutableMap.of()), + jsonMapper, + TestHelper.makeSmileMapper(), + hostFinder, + null, + null, + new NoopServiceEmitter(), + requestLogLine -> { /* noop */ }, + new DefaultGenericQueryMetricsFactory(jsonMapper) + ) + { + @Override + protected void doService( + final HttpServletRequest request, + final HttpServletResponse response + ) + { + didService.incrementAndGet(); + } + }; + + servlet.service(requestMock, null); + + // This test is mostly about verifying that the servlet calls the right methods the right number of times. + EasyMock.verify(hostFinder, requestMock); + Assert.assertEquals(1, didService.get()); + } + private static Server makeTestDeleteServer(int port, final CountDownLatch latch) { Server server = new Server(port); @@ -211,13 +310,13 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest final QueryHostFinder hostFinder = new QueryHostFinder(null, new RendezvousHashAvaticaConnectionBalancer()) { @Override - public io.druid.client.selector.Server getServer(Query query) + public io.druid.client.selector.Server pickServer(Query query) { return new TestServer("http", "localhost", node.getPlaintextPort()); } @Override - public io.druid.client.selector.Server getDefaultServer() + public io.druid.client.selector.Server pickDefaultServer() { return new TestServer("http", "localhost", node.getPlaintextPort()); } @@ -236,21 +335,14 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ObjectMapper jsonMapper = injector.getInstance(ObjectMapper.class); ServletHolder holder = new ServletHolder( new AsyncQueryForwardingServlet( - new MapQueryToolChestWarehouse(ImmutableMap., QueryToolChest>of()), + new MapQueryToolChestWarehouse(ImmutableMap.of()), jsonMapper, injector.getInstance(Key.get(ObjectMapper.class, Smile.class)), hostFinder, injector.getProvider(HttpClient.class), injector.getInstance(DruidHttpClientConfig.class), new NoopServiceEmitter(), - new RequestLogger() - { - @Override - public void log(RequestLogLine requestLogLine) throws IOException - { - // noop - } - }, + requestLogLine -> { /* noop */ }, new DefaultGenericQueryMetricsFactory(jsonMapper) ) {