diff --git a/core/src/main/java/org/apache/druid/query/QueryException.java b/core/src/main/java/org/apache/druid/query/QueryException.java index 93f17b6cff9..3bb896cef85 100644 --- a/core/src/main/java/org/apache/druid/query/QueryException.java +++ b/core/src/main/java/org/apache/druid/query/QueryException.java @@ -152,7 +152,12 @@ public class QueryException extends RuntimeException implements SanitizableExcep protected QueryException(Throwable cause, String errorCode, String errorClass, String host) { - super(cause == null ? null : cause.getMessage(), cause); + this(cause, errorCode, cause == null ? null : cause.getMessage(), errorClass, host); + } + + protected QueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host) + { + super(errorMessage, cause); this.errorCode = errorCode; this.errorClass = errorClass; this.host = host; diff --git a/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java b/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java index 446e94b9667..8212ebc344f 100644 --- a/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java +++ b/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java @@ -95,6 +95,24 @@ public class QueryExceptionTest expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_QUERY_UNSUPPORTED_ERROR_CODE); } + /** + * This test exists primarily to get branch coverage of the null check on the QueryException constructor. + * The validations done in this test are not actually intended to be set-in-stone or anything. + */ + @Test + public void testCanConstructWithoutThrowable() + { + QueryException exception = new QueryException( + (Throwable) null, + QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, + "java.lang.Exception", + "test" + ); + + Assert.assertEquals(QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, exception.getErrorCode()); + Assert.assertNull(exception.getMessage()); + } + private void expectFailTypeForCode(FailType expected, String code) { QueryException exception = new QueryException(new Exception(), code, "java.lang.Exception", "test"); diff --git a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/db/cache/CoordinatorPollingBasicAuthorizerCacheManager.java b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/db/cache/CoordinatorPollingBasicAuthorizerCacheManager.java index 9c45d9b09f2..3d192e89d0d 100644 --- a/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/db/cache/CoordinatorPollingBasicAuthorizerCacheManager.java +++ b/extensions-core/druid-basic-security/src/main/java/org/apache/druid/security/basic/authorization/db/cache/CoordinatorPollingBasicAuthorizerCacheManager.java @@ -420,14 +420,20 @@ public class CoordinatorPollingBasicAuthorizerCacheManager implements BasicAutho new BytesFullResponseHandler() ); + final HttpResponseStatus status = responseHolder.getStatus(); + // cachedSerializedGroupMappingMap is a new endpoint introduced in Druid 0.17.0. For backwards compatibility, if we // get a 404 from the coordinator we stop retrying. This can happen during a rolling upgrade when a process // running 0.17.0+ tries to access this endpoint on an older coordinator. - if (responseHolder.getStatus().equals(HttpResponseStatus.NOT_FOUND)) { + if (HttpResponseStatus.NOT_FOUND.equals(status)) { LOG.warn("cachedSerializedGroupMappingMap is not available from the coordinator, skipping fetch of group mappings for now."); return null; } + if (!HttpResponseStatus.OK.equals(status)) { + LOG.warn("Got an unexpected response status[%s] when loading group mappings.", status); + } + byte[] groupRoleMapBytes = responseHolder.getContent(); GroupMappingAndRoleMap groupMappingAndRoleMap = objectMapper.readValue( diff --git a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/cluster/DruidClusterClient.java b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/cluster/DruidClusterClient.java index 31b2e203945..82d31718795 100644 --- a/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/cluster/DruidClusterClient.java +++ b/integration-tests-ex/cases/src/test/java/org/apache/druid/testsEx/cluster/DruidClusterClient.java @@ -41,7 +41,6 @@ import org.jboss.netty.handler.codec.http.HttpResponseStatus; import javax.inject.Inject; import javax.ws.rs.core.MediaType; - import java.io.IOException; import java.net.URL; import java.util.Map; @@ -336,7 +335,8 @@ public class DruidClusterClient */ public void validate() { - log.info("Starting cluster validation"); + RE exception = new RE("Just building for the stack trace"); + log.info(exception, "Starting cluster validation"); for (ResolvedDruidService service : config.requireDruid().values()) { for (ResolvedInstance instance : service.requireInstances()) { validateInstance(service, instance); @@ -348,28 +348,46 @@ public class DruidClusterClient /** * Validate an instance by waiting for it to report that it is healthy. */ + @SuppressWarnings("BusyWait") private void validateInstance(ResolvedDruidService service, ResolvedInstance instance) { int timeoutMs = config.readyTimeoutSec() * 1000; int pollMs = config.readyPollMs(); long startTime = System.currentTimeMillis(); long updateTime = startTime + 5000; - while (System.currentTimeMillis() - startTime < timeoutMs) { + while (true) { if (isHealthy(service, instance)) { log.info( - "Service %s, host %s is ready", + "Service[%s], host[%s], tag[%s] is ready", service.service(), - instance.clientHost()); + instance.clientHost(), + instance.tag() == null ? "" : instance.tag() + ); return; } long currentTime = System.currentTimeMillis(); if (currentTime > updateTime) { log.info( - "Service %s, host %s not ready, retrying", + "Service[%s], host[%s], tag[%s] not ready, retrying", service.service(), - instance.clientHost()); + instance.clientHost(), + instance.tag() == null ? "" : instance.tag() + ); updateTime = currentTime + 5000; } + final long elapsedTime = System.currentTimeMillis() - startTime; + if (elapsedTime > timeoutMs) { + final RE exception = new RE( + "Service[%s], host[%s], tag[%s] not ready after %,d ms.", + service.service(), + instance.clientHost(), + instance.tag() == null ? "" : instance.tag(), + elapsedTime + ); + // We log the exception here so that the logs include which thread is having the problem + log.error(exception.getMessage()); + throw exception; + } try { Thread.sleep(pollMs); } @@ -377,34 +395,30 @@ public class DruidClusterClient throw new RuntimeException("Interrupted during cluster validation"); } } - throw new RE( - StringUtils.format("Service %s, instance %s not ready after %d ms.", - service.service(), - instance.tag() == null ? "" : instance.tag(), - timeoutMs)); } /** * Wait for an instance to become ready given the URL and a description of * the service. */ + @SuppressWarnings("BusyWait") public void waitForNodeReady(String label, String url) { int timeoutMs = config.readyTimeoutSec() * 1000; int pollMs = config.readyPollMs(); long startTime = System.currentTimeMillis(); - while (System.currentTimeMillis() - startTime < timeoutMs) { + while (true) { if (isHealthy(url)) { - log.info( - "Service %s, url %s is ready", - label, - url); + log.info("Service[%s], url[%s] is ready", label, url); return; } - log.info( - "Service %s, url %s not ready, retrying", - label, - url); + final long elapsedTime = System.currentTimeMillis() - startTime; + if (elapsedTime > timeoutMs) { + final RE re = new RE("Service[%s], url[%s] not ready after %,d ms.", label, url, elapsedTime); + log.error(re.getMessage()); + throw re; + } + log.info("Service[%s], url[%s] not ready, retrying", label, url); try { Thread.sleep(pollMs); } @@ -412,11 +426,6 @@ public class DruidClusterClient throw new RuntimeException("Interrupted while waiting for note to be ready"); } } - throw new RE( - StringUtils.format("Service %s, url %s not ready after %d ms.", - label, - url, - timeoutMs)); } public String nodeUrl(DruidNode node) diff --git a/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java b/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java index 47ba3c90777..56f1a25be89 100644 --- a/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java +++ b/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java @@ -29,7 +29,7 @@ public class BadJsonQueryException extends BadQueryException public BadJsonQueryException(JsonParseException e) { - this(JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS); + this(e, JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS); } @JsonCreator @@ -39,6 +39,16 @@ public class BadJsonQueryException extends BadQueryException @JsonProperty("errorClass") String errorClass ) { - super(errorCode, errorMessage, errorClass); + this(null, errorCode, errorMessage, errorClass); + } + + private BadJsonQueryException( + Throwable cause, + String errorCode, + String errorMessage, + String errorClass + ) + { + super(cause, errorCode, errorMessage, errorClass, null); } } diff --git a/processing/src/main/java/org/apache/druid/query/BadQueryException.java b/processing/src/main/java/org/apache/druid/query/BadQueryException.java index b115cc1170c..e627c966ade 100644 --- a/processing/src/main/java/org/apache/druid/query/BadQueryException.java +++ b/processing/src/main/java/org/apache/druid/query/BadQueryException.java @@ -30,11 +30,16 @@ public abstract class BadQueryException extends QueryException protected BadQueryException(String errorCode, String errorMessage, String errorClass) { - super(errorCode, errorMessage, errorClass, null); + this(errorCode, errorMessage, errorClass, null); } protected BadQueryException(String errorCode, String errorMessage, String errorClass, String host) { - super(errorCode, errorMessage, errorClass, host); + this(null, errorCode, errorMessage, errorClass, host); + } + + protected BadQueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host) + { + super(cause, errorCode, errorMessage, errorClass, host); } } diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index 84e6acf24c1..f4a7ab3edb7 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -30,6 +30,7 @@ import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.inject.Inject; import org.apache.druid.client.DirectDruidClient; @@ -216,21 +217,8 @@ public class QueryResource implements QueryCountStatsProvider throw new ForbiddenException(authResult.toString()); } - // We use an async context not because we are actually going to run this async, but because we want to delay - // the decision of what the response code should be until we have gotten the first few data points to return. - // Returning a Response object from this point forward requires that object to know the status code, which we - // don't actually know until we are in the accumulator, but if we try to return a Response object from the - // accumulator, we cannot properly stream results back, because the accumulator won't release control of the - // Response until it has consumed the underlying Sequence. - final AsyncContext asyncContext = req.startAsync(); - - try { - new QueryResourceQueryResultPusher(req, queryLifecycle, io, (HttpServletResponse) asyncContext.getResponse()) - .push(); - } - finally { - asyncContext.complete(); - } + final QueryResourceQueryResultPusher pusher = new QueryResourceQueryResultPusher(req, queryLifecycle, io); + return pusher.push(); } catch (Exception e) { if (e instanceof ForbiddenException && !req.isAsyncStarted()) { @@ -258,6 +246,7 @@ public class QueryResource implements QueryCountStatsProvider out.write(jsonMapper.writeValueAsBytes(responseException)); } } + return null; } finally { asyncContext.complete(); @@ -266,7 +255,6 @@ public class QueryResource implements QueryCountStatsProvider finally { Thread.currentThread().setName(currThreadName); } - return null; } public interface QueryMetricCounter @@ -538,18 +526,18 @@ public class QueryResource implements QueryCountStatsProvider public QueryResourceQueryResultPusher( HttpServletRequest req, QueryLifecycle queryLifecycle, - ResourceIOReaderWriter io, - HttpServletResponse response + ResourceIOReaderWriter io ) { super( - response, + req, QueryResource.this.jsonMapper, QueryResource.this.responseContextConfig, QueryResource.this.selfNode, QueryResource.this.counter, queryLifecycle.getQueryId(), - MediaType.valueOf(io.getResponseWriter().getResponseType()) + MediaType.valueOf(io.getResponseWriter().getResponseType()), + ImmutableMap.of() ); this.req = req; this.queryLifecycle = queryLifecycle; @@ -561,20 +549,27 @@ public class QueryResource implements QueryCountStatsProvider { return new ResultsWriter() { + private QueryResponse queryResponse; + @Override - public QueryResponse start(HttpServletResponse response) + public Response.ResponseBuilder start() { - final QueryResponse queryResponse = queryLifecycle.execute(); + queryResponse = queryLifecycle.execute(); final ResponseContext responseContext = queryResponse.getResponseContext(); final String prevEtag = getPreviousEtag(req); if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) { queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1); counter.incrementSuccess(); - response.setStatus(HttpServletResponse.SC_NOT_MODIFIED); - return null; + return Response.status(Status.NOT_MODIFIED); } + return null; + } + + @Override + public QueryResponse getQueryResponse() + { return queryResponse; } diff --git a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java index 44e1f5d6c23..b25e31d363d 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java +++ b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java @@ -36,45 +36,54 @@ import org.apache.druid.query.context.ResponseContext; import org.apache.druid.server.security.ForbiddenException; import javax.annotation.Nullable; +import javax.servlet.AsyncContext; import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; +import javax.ws.rs.core.StreamingOutput; import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; +import java.util.Map; public abstract class QueryResultPusher { private static final Logger log = new Logger(QueryResultPusher.class); - private final HttpServletResponse response; + private final HttpServletRequest request; private final String queryId; private final ObjectMapper jsonMapper; private final ResponseContextConfig responseContextConfig; private final DruidNode selfNode; private final QueryResource.QueryMetricCounter counter; private final MediaType contentType; + private final Map extraHeaders; private StreamingHttpResponseAccumulator accumulator = null; + private AsyncContext asyncContext = null; + private HttpServletResponse response = null; public QueryResultPusher( - HttpServletResponse response, + HttpServletRequest request, ObjectMapper jsonMapper, ResponseContextConfig responseContextConfig, DruidNode selfNode, QueryResource.QueryMetricCounter counter, String queryId, - MediaType contentType + MediaType contentType, + Map extraHeaders ) { - this.response = response; + this.request = request; this.queryId = queryId; this.jsonMapper = jsonMapper; this.responseContextConfig = responseContextConfig; this.selfNode = selfNode; this.counter = counter; this.contentType = contentType; + this.extraHeaders = extraHeaders; } /** @@ -92,23 +101,45 @@ public abstract class QueryResultPusher public abstract void writeException(Exception e, OutputStream out) throws IOException; - public void push() + /** + * Pushes results out. Can sometimes return a JAXRS Response object instead of actually pushing to the output + * stream, primarily for error handling that occurs before switching the servlet to asynchronous mode. + * + * @return null if the response has already been handled and pushed out, or a non-null Response object if it expects + * the container to put the bytes on the wire. + */ + @Nullable + public Response push() { - response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); - ResultsWriter resultsWriter = null; try { resultsWriter = start(); - - final QueryResponse queryResponse = resultsWriter.start(response); - if (queryResponse == null) { - // It's already been handled... - return; + final Response.ResponseBuilder startResponse = resultsWriter.start(); + if (startResponse != null) { + startResponse.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); + for (Map.Entry entry : extraHeaders.entrySet()) { + startResponse.header(entry.getKey(), entry.getValue()); + } + return startResponse.build(); } + final QueryResponse queryResponse = resultsWriter.getQueryResponse(); final Sequence results = queryResponse.getResults(); + // We use an async context not because we are actually going to run this async, but because we want to delay + // the decision of what the response code should be until we have gotten the first few data points to return. + // Returning a Response object from this point forward requires that object to know the status code, which we + // don't actually know until we are in the accumulator, but if we try to return a Response object from the + // accumulator, we cannot properly stream results back, because the accumulator won't release control of the + // Response until it has consumed the underlying Sequence. + asyncContext = request.startAsync(); + response = (HttpServletResponse) asyncContext.getResponse(); + response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); + for (Map.Entry entry : extraHeaders.entrySet()) { + response.setHeader(entry.getKey(), entry.getValue()); + } + accumulator = new StreamingHttpResponseAccumulator(queryResponse.getResponseContext(), resultsWriter); results.accumulate(null, accumulator); @@ -119,8 +150,7 @@ public abstract class QueryResultPusher resultsWriter.recordSuccess(accumulator.getNumBytesSent()); } catch (QueryException e) { - handleQueryException(resultsWriter, e); - return; + return handleQueryException(resultsWriter, e); } catch (RuntimeException re) { if (re instanceof ForbiddenException) { @@ -128,17 +158,15 @@ public abstract class QueryResultPusher // has been committed because the response is committed after results are returned. And, if we started // returning results before a ForbiddenException gets thrown, that means that we've already leaked stuff // that should not have been leaked. I.e. it means, we haven't validated the authorization early enough. - if (response.isCommitted()) { + if (response != null && response.isCommitted()) { log.error(re, "Got a forbidden exception for query[%s] after the response was already committed.", queryId); } throw re; } - handleQueryException(resultsWriter, new QueryInterruptedException(re)); - return; + return handleQueryException(resultsWriter, new QueryInterruptedException(re)); } catch (IOException ioEx) { - handleQueryException(resultsWriter, new QueryInterruptedException(ioEx)); - return; + return handleQueryException(resultsWriter, new QueryInterruptedException(ioEx)); } finally { if (accumulator != null) { @@ -159,10 +187,15 @@ public abstract class QueryResultPusher log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId); } } + if (asyncContext != null) { + asyncContext.complete(); + } } + return null; } - private void handleQueryException(ResultsWriter resultsWriter, QueryException e) + @Nullable + private Response handleQueryException(ResultsWriter resultsWriter, QueryException e) { if (accumulator != null && accumulator.isInitialized()) { // We already started sending a response when we got the error message. In this case we just give up @@ -176,11 +209,7 @@ public abstract class QueryResultPusher // This case is always a failure because the error happened mid-stream of sending results back. Therefore, // we do not believe that the response stream was actually useable counter.incrementFailed(); - return; - } - - if (response.isCommitted()) { - QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?"); + return null; } final QueryException.FailType failType = e.getFailType(); @@ -206,40 +235,71 @@ public abstract class QueryResultPusher ); counter.incrementFailed(); } - final int responseStatus = failType.getExpectedStatus(); - - response.setStatus(responseStatus); - response.setHeader("Content-Type", contentType.toString()); - try (ServletOutputStream out = response.getOutputStream()) { - writeException(e, out); - } - catch (IOException ioException) { - log.warn( - ioException, - "Suppressing IOException thrown sending error response for query[%s]", - queryId - ); - } resultsWriter.recordFailure(e); + + final int responseStatus = failType.getExpectedStatus(); + + if (response == null) { + // No response object yet, so assume we haven't started the async context and is safe to return Response + final Response.ResponseBuilder bob = Response + .status(responseStatus) + .type(contentType) + .entity((StreamingOutput) output -> { + writeException(e, output); + output.close(); + }); + + bob.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); + for (Map.Entry entry : extraHeaders.entrySet()) { + bob.header(entry.getKey(), entry.getValue()); + } + + return bob.build(); + } else { + if (response.isCommitted()) { + QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?"); + } + + response.setStatus(responseStatus); + response.setHeader("Content-Type", contentType.toString()); + try (ServletOutputStream out = response.getOutputStream()) { + writeException(e, out); + } + catch (IOException ioException) { + log.warn( + ioException, + "Suppressing IOException thrown sending error response for query[%s]", + queryId + ); + } + return null; + } } public interface ResultsWriter extends Closeable { /** - * Runs the query and returns a ResultsWriter from running the query. + * Runs the query and prepares the QueryResponse to be returned *

* This also serves as a hook for any logic that runs on the metadata from a QueryResponse. If this method - * returns {@code null} then the Pusher believes that the response was already handled and skips the rest - * of its logic. As such, any implementation that returns null must make sure that the response has been set - * with a meaningful status, etc. + * returns {@code null} then the Pusher can continue with normal logic. If this method chooses to return + * a ResponseBuilder, then the Pusher will attach any extra metadata it has to the Response and return + * the response built from the Builder without attempting to process the results of the query. *

- * Even if this method returns null, close() should still be called on this object. + * In all cases, {@link #close()} should be called on this object. * * @return QueryResponse or null if no more work to do. */ @Nullable - QueryResponse start(HttpServletResponse response); + Response.ResponseBuilder start(); + + /** + * Gets the results of running the query. {@link #start} must be called before this method is called. + * + * @return the results of running the query as preparted by the {@link #start()} method + */ + QueryResponse getQueryResponse(); Writer makeWriter(OutputStream out) throws IOException; @@ -301,9 +361,7 @@ public abstract class QueryResultPusher * Initializes the response. This is done lazily so that we can put various metadata that we only get once * we have some of the response stream into the result. *

- * This is called once for each result object, but should only actually happen once. - * - * @return boolean if initialization occurred. False most of the team because initialization only happens once. + * It is okay for this to be called multiple times. */ public void initialize() { @@ -332,7 +390,7 @@ public abstract class QueryResultPusher ); } catch (JsonProcessingException e) { - QueryResource.log.info(e, "Problem serializing to JSON!?"); + log.info(e, "Problem serializing to JSON!?"); serializationResult = new ResponseContext.SerializationResult("Could not serialize", "Could not serialize"); } @@ -343,7 +401,7 @@ public abstract class QueryResultPusher serializationResult.getFullResult() ); if (responseContextConfig.shouldFailOnTruncatedResponseContext()) { - QueryResource.log.error(logToPrint); + log.error(logToPrint); throw new QueryInterruptedException( new TruncatedResponseContextException( "Serialized response context exceeds the max size[%s]", @@ -352,12 +410,12 @@ public abstract class QueryResultPusher selfNode.getHostAndPortToUse() ); } else { - QueryResource.log.warn(logToPrint); + log.warn(logToPrint); } } response.setHeader(QueryResource.HEADER_RESPONSE_CONTEXT, serializationResult.getResult()); - response.setHeader("Content-Type", contentType.toString()); + response.setContentType(contentType.toString()); try { out = new CountingOutputStream(response.getOutputStream()); @@ -379,6 +437,7 @@ public abstract class QueryResultPusher } @Override + @Nullable public Response accumulate(Response retVal, Object in) { if (!initialized) { diff --git a/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java b/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java index 5ec56983649..a87f21c1a2e 100644 --- a/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java +++ b/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java @@ -20,11 +20,12 @@ package org.apache.druid.server.security; import com.fasterxml.jackson.databind.ObjectMapper; -import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.server.DruidNode; +import org.apache.druid.server.QueryResource; import javax.servlet.Filter; import javax.servlet.FilterChain; @@ -83,12 +84,15 @@ public class PreResponseAuthorizationCheckFilter implements Filter filterChain.doFilter(servletRequest, servletResponse); Boolean authInfoChecked = (Boolean) servletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED); - if (authInfoChecked == null && statusIsSuccess(response.getStatus())) { + if (authInfoChecked == null && statusShouldBeHidden(response.getStatus())) { // Note: rather than throwing an exception here, it would be nice to blank out the original response // since the request didn't have any authorization checks performed. However, this breaks proxying // (e.g. OverlordServletProxy), so this is not implemented for now. handleAuthorizationCheckError( - "Request did not have an authorization check performed.", + StringUtils.format( + "Request did not have an authorization check performed, original response status[%s].", + response.getStatus() + ), request, response ); @@ -136,7 +140,6 @@ public class PreResponseAuthorizationCheckFilter implements Filter OutputStream out = response.getOutputStream(); sendJsonError(response, HttpServletResponse.SC_UNAUTHORIZED, jsonMapper.writeValueAsString(unauthorizedError), out); out.close(); - return; } private void handleAuthorizationCheckError( @@ -145,19 +148,21 @@ public class PreResponseAuthorizationCheckFilter implements Filter HttpServletResponse servletResponse ) { + final String queryId = servletResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER); + // Send out an alert so there's a centralized collection point for seeing errors of this nature log.makeAlert(errorMsg) .addData("uri", servletRequest.getRequestURI()) .addData("method", servletRequest.getMethod()) .addData("remoteAddr", servletRequest.getRemoteAddr()) .addData("remoteHost", servletRequest.getRemoteHost()) + .addData("queryId", queryId) .emit(); - if (servletResponse.isCommitted()) { - throw new ISE(errorMsg); - } else { + if (!servletResponse.isCommitted()) { try { - servletResponse.sendError(HttpServletResponse.SC_FORBIDDEN); + servletResponse.reset(); + servletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN); } catch (Exception e) { throw new RuntimeException(e); @@ -165,9 +170,23 @@ public class PreResponseAuthorizationCheckFilter implements Filter } } - private static boolean statusIsSuccess(int status) + private static boolean statusShouldBeHidden(int status) { - return 200 <= status && status < 300; + // We allow 404s (not found) to not be rewritten to forbidden because consistently returning 404s is a way to leak + // less information when something wasn't able to be done anyway. I.e. if we pretend that the thing didn't exist + // when the authorization fails, then there is no information about whether the thing existed. If we return + // a 403 when authorization fails and a 404 when authorization succeeds, but it doesn't exist, then we have + // leaked that it could maybe exist, if the authentication credentials were good. + // + // We also allow 307s (temporary redirect) to not be hidden as they are used to redirect to the leader. + switch (status) { + case HttpServletResponse.SC_FORBIDDEN: + case HttpServletResponse.SC_NOT_FOUND: + case HttpServletResponse.SC_TEMPORARY_REDIRECT: + return false; + default: + return true; + } } public static void sendJsonError(HttpServletResponse resp, int error, String errorJson, OutputStream outputStream) diff --git a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java index fc739008958..61b452f2cb6 100644 --- a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java @@ -26,7 +26,6 @@ import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.inject.Injector; @@ -88,7 +87,9 @@ import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; +import javax.ws.rs.core.StreamingOutput; import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collection; @@ -396,7 +397,7 @@ public class QueryResourceTest final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); //since accept header is null, the response content type should be same as the value of 'Content-Type' header - Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); + Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType()); } @Test @@ -409,7 +410,7 @@ public class QueryResourceTest Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); //since accept header is empty, the response content type should be same as the value of 'Content-Type' header - Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); + Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType()); } @Test @@ -424,10 +425,7 @@ public class QueryResourceTest Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); // Content-Type in response should be Smile - Assert.assertEquals( - SmileMediaTypes.APPLICATION_JACKSON_SMILE, - Iterables.getOnlyElement(response.headers.get("Content-Type")) - ); + Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType()); } @Test @@ -447,10 +445,7 @@ public class QueryResourceTest Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); // Content-Type in response should be Smile - Assert.assertEquals( - SmileMediaTypes.APPLICATION_JACKSON_SMILE, - Iterables.getOnlyElement(response.headers.get("Content-Type")) - ); + Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType()); } @Test @@ -469,10 +464,7 @@ public class QueryResourceTest Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); // Content-Type in response should default to Content-Type from request - Assert.assertEquals( - SmileMediaTypes.APPLICATION_JACKSON_SMILE, - Iterables.getOnlyElement(response.headers.get("Content-Type")) - ); + Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType()); } @Test @@ -643,13 +635,16 @@ public class QueryResourceTest ); expectPermissiveHappyPathAuth(); - final MockHttpServletResponse response = expectAsyncRequestFlow( + final Response response = expectSynchronousRequestFlow( testServletRequest, SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8), timeoutQueryResource ); Assert.assertEquals(QueryTimeoutException.STATUS_CODE, response.getStatus()); - QueryTimeoutException ex = jsonMapper.readValue(response.baos.toByteArray(), QueryTimeoutException.class); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ((StreamingOutput) response.getEntity()).write(baos); + QueryTimeoutException ex = jsonMapper.readValue(baos.toByteArray(), QueryTimeoutException.class); Assert.assertEquals("Query Timed Out!", ex.getMessage()); Assert.assertEquals(QueryException.QUERY_TIMEOUT_ERROR_CODE, ex.getErrorCode()); Assert.assertEquals(1, timeoutQueryResource.getTimedOutQueryCount()); @@ -892,25 +887,28 @@ public class QueryResourceTest ); createScheduledQueryResource(laningScheduler, Collections.emptyList(), ImmutableList.of(waitTwoScheduled)); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) ); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) ); waitTwoScheduled.await(); - assertResponseAndCountdownOrBlockForever( + assertSynchronousResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> { Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; + try { - ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ((StreamingOutput) response.getEntity()).write(baos); + ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); @@ -938,20 +936,22 @@ public class QueryResourceTest createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled)); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) ); waitOneScheduled.await(); - assertResponseAndCountdownOrBlockForever( + assertSynchronousResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY, waitAllFinished, response -> { Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; try { - ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ((StreamingOutput) response.getEntity()).write(baos); + ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); @@ -965,7 +965,7 @@ public class QueryResourceTest } ); waitTwoStarted.await(); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) @@ -990,20 +990,22 @@ public class QueryResourceTest createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled)); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) ); waitOneScheduled.await(); - assertResponseAndCountdownOrBlockForever( + assertSynchronousResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY, waitAllFinished, response -> { Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; try { - ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ((StreamingOutput) response.getEntity()).write(baos); + ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); @@ -1016,7 +1018,7 @@ public class QueryResourceTest } ); waitTwoStarted.await(); - assertResponseAndCountdownOrBlockForever( + assertAsyncResponseAndCountdownOrBlockForever( SIMPLE_TIMESERIES_QUERY_SMALLISH_INTERVAL, waitAllFinished, response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) @@ -1085,7 +1087,7 @@ public class QueryResourceTest ); } - private void assertResponseAndCountdownOrBlockForever( + private void assertAsyncResponseAndCountdownOrBlockForever( String query, CountDownLatch done, Consumer asserts @@ -1146,4 +1148,36 @@ public class QueryResourceTest )); return response; } + + private void assertSynchronousResponseAndCountdownOrBlockForever( + String query, + CountDownLatch done, + Consumer asserts + ) + { + Executors.newSingleThreadExecutor().submit(() -> { + try { + asserts.accept( + expectSynchronousRequestFlow( + testServletRequest.mimic(), + query.getBytes(StandardCharsets.UTF_8), + queryResource + ) + ); + } + catch (IOException e) { + throw new RuntimeException(e); + } + done.countDown(); + }); + } + + private Response expectSynchronousRequestFlow( + MockHttpServletRequest req, + byte[] bytes, + QueryResource queryResource + ) throws IOException + { + return queryResource.doPost(new ByteArrayInputStream(bytes), null, req); + } } diff --git a/server/src/test/java/org/apache/druid/server/http/security/PreResponseAuthorizationCheckFilterTest.java b/server/src/test/java/org/apache/druid/server/http/security/PreResponseAuthorizationCheckFilterTest.java index 92b08923c71..b8b2f21fc44 100644 --- a/server/src/test/java/org/apache/druid/server/http/security/PreResponseAuthorizationCheckFilterTest.java +++ b/server/src/test/java/org/apache/druid/server/http/security/PreResponseAuthorizationCheckFilterTest.java @@ -20,141 +20,155 @@ package org.apache.druid.server.http.security; import org.apache.druid.jackson.DefaultObjectMapper; -import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.emitter.EmittingLogger; -import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.apache.druid.server.mocks.MockHttpServletRequest; +import org.apache.druid.server.mocks.MockHttpServletResponse; import org.apache.druid.server.security.AllowAllAuthenticator; import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.Authenticator; import org.apache.druid.server.security.PreResponseAuthorizationCheckFilter; -import org.easymock.EasyMock; -import org.junit.Rule; +import org.junit.Assert; import org.junit.Test; -import org.junit.rules.ExpectedException; -import javax.servlet.FilterChain; -import javax.servlet.ServletOutputStream; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; +import javax.servlet.ServletException; +import java.io.IOException; import java.util.Collections; import java.util.List; public class PreResponseAuthorizationCheckFilterTest { - private static List authenticators = Collections.singletonList(new AllowAllAuthenticator()); - - @Rule - public ExpectedException expectedException = ExpectedException.none(); + private static final List authenticators = Collections.singletonList(new AllowAllAuthenticator()); @Test public void testValidRequest() throws Exception { AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); - HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); - HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); - FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class); - ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); + MockHttpServletRequest req = new MockHttpServletRequest(); + MockHttpServletResponse resp = new MockHttpServletResponse(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(true).once(); - EasyMock.replay(req, resp, filterChain, outputStream); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); + req.attributes.put(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( authenticators, new DefaultObjectMapper() ); - filter.doFilter(req, resp, filterChain); - EasyMock.verify(req, resp, filterChain, outputStream); + filter.doFilter(req, resp, (request, response) -> { + }); } @Test public void testAuthenticationFailedRequest() throws Exception { - HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); - HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); - FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class); - ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); - - EasyMock.expect(resp.getOutputStream()).andReturn(outputStream).once(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(null).once(); - resp.setStatus(401); - EasyMock.expectLastCall().once(); - resp.setContentType("application/json"); - EasyMock.expectLastCall().once(); - resp.setCharacterEncoding("UTF-8"); - EasyMock.expectLastCall().once(); - EasyMock.replay(req, resp, filterChain, outputStream); + MockHttpServletRequest req = new MockHttpServletRequest(); + MockHttpServletResponse resp = new MockHttpServletResponse(); PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( authenticators, new DefaultObjectMapper() ); - filter.doFilter(req, resp, filterChain); - EasyMock.verify(req, resp, filterChain, outputStream); + filter.doFilter(req, resp, (request, response) -> { + }); + + Assert.assertEquals(401, resp.getStatus()); + Assert.assertEquals("application/json", resp.getContentType()); + Assert.assertEquals("UTF-8", resp.getCharacterEncoding()); } @Test - public void testMissingAuthorizationCheck() throws Exception + public void testMissingAuthorizationCheckAndNotCommitted() throws ServletException, IOException { - EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class)); - - expectedException.expect(ISE.class); - expectedException.expectMessage("Request did not have an authorization check performed."); + EmittingLogger.registerEmitter(new NoopServiceEmitter()); AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); - HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); - HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); - FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class); - ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); + MockHttpServletRequest req = new MockHttpServletRequest(); + req.requestUri = "uri"; + req.method = "GET"; + req.remoteAddr = "1.2.3.4"; + req.remoteHost = "aHost"; - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once(); - EasyMock.expect(resp.getStatus()).andReturn(200).once(); - EasyMock.expect(req.getRequestURI()).andReturn("uri").once(); - EasyMock.expect(req.getMethod()).andReturn("GET").once(); - EasyMock.expect(req.getRemoteAddr()).andReturn("1.2.3.4").once(); - EasyMock.expect(req.getRemoteHost()).andReturn("ahostname").once(); - EasyMock.expect(resp.isCommitted()).andReturn(true).once(); + MockHttpServletResponse resp = new MockHttpServletResponse(); + resp.setStatus(200); + + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); + + PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( + authenticators, + new DefaultObjectMapper() + ); + filter.doFilter(req, resp, (request, response) -> { + }); + + Assert.assertEquals(403, resp.getStatus()); + } + + @Test + public void testMissingAuthorizationCheckWithForbidden() throws Exception + { + EmittingLogger.registerEmitter(new NoopServiceEmitter()); + AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); + + MockHttpServletRequest req = new MockHttpServletRequest(); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); + + MockHttpServletResponse resp = new MockHttpServletResponse(); resp.setStatus(403); - EasyMock.expectLastCall().once(); - resp.setContentType("application/json"); - EasyMock.expectLastCall().once(); - resp.setCharacterEncoding("UTF-8"); - EasyMock.expectLastCall().once(); - EasyMock.replay(req, resp, filterChain, outputStream); PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( authenticators, new DefaultObjectMapper() ); - filter.doFilter(req, resp, filterChain); - EasyMock.verify(req, resp, filterChain, outputStream); + filter.doFilter(req, resp, (request, response) -> { + }); + + Assert.assertEquals(403, resp.getStatus()); } @Test - public void testMissingAuthorizationCheckWithError() throws Exception + public void testMissingAuthorizationCheckWith404Keeps404() throws Exception { - EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class)); + EmittingLogger.registerEmitter(new NoopServiceEmitter()); AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); - HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); - HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); - FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class); - ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); + MockHttpServletRequest req = new MockHttpServletRequest(); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once(); - EasyMock.expect(resp.getStatus()).andReturn(404).once(); - EasyMock.replay(req, resp, filterChain, outputStream); + MockHttpServletResponse resp = new MockHttpServletResponse(); + resp.setStatus(404); PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( authenticators, new DefaultObjectMapper() ); - filter.doFilter(req, resp, filterChain); - EasyMock.verify(req, resp, filterChain, outputStream); + filter.doFilter(req, resp, (request, response) -> { + }); + + Assert.assertEquals(404, resp.getStatus()); + } + + @Test + public void testMissingAuthorizationCheckWith307Keeps307() throws Exception + { + EmittingLogger.registerEmitter(new NoopServiceEmitter()); + AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); + + MockHttpServletRequest req = new MockHttpServletRequest(); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); + + MockHttpServletResponse resp = new MockHttpServletResponse(); + resp.setStatus(307); + + PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( + authenticators, + new DefaultObjectMapper() + ); + filter.doFilter(req, resp, (request, response) -> { + }); + + Assert.assertEquals(307, resp.getStatus()); } } diff --git a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java index 85f1333244b..0582f5871b1 100644 --- a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java +++ b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java @@ -54,9 +54,11 @@ import java.util.function.Supplier; */ public class MockHttpServletRequest implements HttpServletRequest { + public String requestUri = null; public String method = null; public String contentType = null; public String remoteAddr = null; + public String remoteHost = null; public LinkedHashMap headers = new LinkedHashMap<>(); public LinkedHashMap attributes = new LinkedHashMap<>(); @@ -110,7 +112,7 @@ public class MockHttpServletRequest implements HttpServletRequest @Override public String getMethod() { - return method; + return unsupportedIfNull(method); } @Override @@ -164,7 +166,7 @@ public class MockHttpServletRequest implements HttpServletRequest @Override public String getRequestURI() { - throw new UnsupportedOperationException(); + return unsupportedIfNull(requestUri); } @Override @@ -296,7 +298,7 @@ public class MockHttpServletRequest implements HttpServletRequest @Override public String getContentType() { - return contentType; + return unsupportedIfNull(contentType); } @Override @@ -362,13 +364,13 @@ public class MockHttpServletRequest implements HttpServletRequest @Override public String getRemoteAddr() { - return remoteAddr; + return unsupportedIfNull(remoteAddr); } @Override public String getRemoteHost() { - throw new UnsupportedOperationException(); + return unsupportedIfNull(remoteHost); } @Override @@ -486,6 +488,9 @@ public class MockHttpServletRequest implements HttpServletRequest @Override public AsyncContext getAsyncContext() { + if (currAsyncContext == null) { + throw new IllegalStateException("Must be put into Async mode before async context can be gottendid"); + } return currAsyncContext; } @@ -514,4 +519,13 @@ public class MockHttpServletRequest implements HttpServletRequest return retVal; } + + private T unsupportedIfNull(T obj) + { + if (obj == null) { + throw new UnsupportedOperationException(); + } else { + return obj; + } + } } diff --git a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java index b484cb56afd..62e6ca4a266 100644 --- a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java +++ b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java @@ -22,12 +22,12 @@ package org.apache.druid.server.mocks; import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; -import javax.annotation.Nonnull; import javax.annotation.Nullable; import javax.servlet.ServletOutputStream; import javax.servlet.WriteListener; import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletResponse; +import javax.validation.constraints.NotNull; import java.io.ByteArrayOutputStream; import java.io.PrintWriter; import java.util.ArrayList; @@ -63,6 +63,21 @@ public class MockHttpServletResponse implements HttpServletResponse private int statusCode; private String contentType; + private String characterEncoding; + + @Override + public void reset() + { + if (isCommitted()) { + throw new IllegalStateException("Cannot reset a committed ServletResponse"); + } + + headers.clear(); + statusCode = 0; + contentType = null; + characterEncoding = null; + } + @Override public void addCookie(Cookie cookie) @@ -198,7 +213,7 @@ public class MockHttpServletResponse implements HttpServletResponse @Override public String getCharacterEncoding() { - throw new UnsupportedOperationException(); + return characterEncoding; } @Override @@ -231,13 +246,13 @@ public class MockHttpServletResponse implements HttpServletResponse } @Override - public void write(@Nonnull byte[] b) + public void write(@NotNull byte[] b) { baos.write(b, 0, b.length); } @Override - public void write(@Nonnull byte[] b, int off, int len) + public void write(@NotNull byte[] b, int off, int len) { baos.write(b, off, len); } @@ -253,7 +268,7 @@ public class MockHttpServletResponse implements HttpServletResponse @Override public void setCharacterEncoding(String charset) { - throw new UnsupportedOperationException(); + characterEncoding = charset; } @Override @@ -298,18 +313,19 @@ public class MockHttpServletResponse implements HttpServletResponse throw new UnsupportedOperationException(); } + public void forceCommitted() + { + if (!isCommitted()) { + baos.write(1234); + } + } + @Override public boolean isCommitted() { return baos.size() > 0; } - @Override - public void reset() - { - throw new UnsupportedOperationException(); - } - @Override public void setLocale(Locale loc) { diff --git a/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java b/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java index 60a2b89bf8f..41dcf6cc02f 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java @@ -61,22 +61,27 @@ public class SqlPlanningException extends BadQueryException public SqlPlanningException(SqlParseException e) { - this(PlanningError.SQL_PARSE_ERROR, e.getMessage()); + this(e, PlanningError.SQL_PARSE_ERROR, e.getMessage()); } public SqlPlanningException(ValidationException e) { - this(PlanningError.VALIDATION_ERROR, e.getMessage()); + this(e, PlanningError.VALIDATION_ERROR, e.getMessage()); } public SqlPlanningException(CalciteContextException e) { - this(PlanningError.VALIDATION_ERROR, e.getMessage()); + this(e, PlanningError.VALIDATION_ERROR, e.getMessage()); } public SqlPlanningException(PlanningError planningError, String errorMessage) { - this(planningError.errorCode, errorMessage, planningError.errorClass); + this(null, planningError, errorMessage); + } + + public SqlPlanningException(Throwable cause, PlanningError planningError, String errorMessage) + { + this(cause, planningError.errorCode, errorMessage, planningError.errorClass); } @JsonCreator @@ -86,6 +91,17 @@ public class SqlPlanningException extends BadQueryException @JsonProperty("errorClass") String errorClass ) { - super(errorCode, errorMessage, errorClass); + this(null, errorCode, errorMessage, errorClass); } + + private SqlPlanningException( + Throwable cause, + String errorCode, + String errorMessage, + String errorClass + ) + { + super(cause, errorCode, errorMessage, errorClass, null); + } + } diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index c98b091b9c7..dad391e6bf0 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -48,9 +48,7 @@ import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; import javax.annotation.Nullable; -import javax.servlet.AsyncContext; import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.POST; @@ -63,7 +61,9 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; import java.io.IOException; import java.io.OutputStream; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -121,22 +121,8 @@ public class SqlResource try { Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); - // We use an async context not because we are actually going to run this async, but because we want to delay - // the decision of what the response code should be until we have gotten the first few data points to return. - // Returning a Response object from this point forward requires that object to know the status code, which we - // don't actually know until we are in the accumulator, but if we try to return a Response object from the - // accumulator, we cannot properly stream results back, because the accumulator won't release control of the - // Response until it has consumed the underlying Sequence. - final AsyncContext asyncContext = req.startAsync(); - - try { - QueryResultPusher pusher = new SqlResourceQueryResultPusher(asyncContext, sqlQueryId, stmt, sqlQuery); - pusher.push(); - return null; - } - finally { - asyncContext.complete(); - } + QueryResultPusher pusher = makePusher(req, stmt, sqlQuery); + return pusher.push(); } finally { Thread.currentThread().setName(currThreadName); @@ -213,27 +199,43 @@ public class SqlResource } } + private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStatement stmt, SqlQuery sqlQuery) + { + final String sqlQueryId = stmt.sqlQueryId(); + Map headers = new LinkedHashMap<>(); + headers.put(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId); + + if (sqlQuery.includeHeader()) { + headers.put(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); + } + + return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, headers); + } + private class SqlResourceQueryResultPusher extends QueryResultPusher { + private final String sqlQueryId; private final HttpStatement stmt; private final SqlQuery sqlQuery; public SqlResourceQueryResultPusher( - AsyncContext asyncContext, + HttpServletRequest req, String sqlQueryId, HttpStatement stmt, - SqlQuery sqlQuery + SqlQuery sqlQuery, + Map headers ) { super( - (HttpServletResponse) asyncContext.getResponse(), + req, SqlResource.this.jsonMapper, SqlResource.this.responseContextConfig, SqlResource.this.selfNode, SqlResource.QUERY_METRIC_COUNTER, sqlQueryId, - MediaType.APPLICATION_JSON_TYPE + MediaType.APPLICATION_JSON_TYPE, + headers ); this.sqlQueryId = sqlQueryId; this.stmt = stmt; @@ -245,19 +247,17 @@ public class SqlResource { return new ResultsWriter() { + private QueryResponse queryResponse; private ResultSet thePlan; @Override @Nullable - @SuppressWarnings({"unchecked", "rawtypes"}) - public QueryResponse start(HttpServletResponse response) + public Response.ResponseBuilder start() { - response.setHeader(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId); - - final QueryResponse retVal; try { thePlan = stmt.plan(); - retVal = thePlan.run(); + queryResponse = thePlan.run(); + return null; } catch (RelOptPlanner.CannotPlanException e) { throw new SqlPlanningException( @@ -276,12 +276,13 @@ public class SqlResource // doesn't implement org.apache.druid.common.exception.SanitizableException. throw new QueryInterruptedException(e); } + } - if (sqlQuery.includeHeader()) { - response.setHeader(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); - } - - return (QueryResponse) retVal; + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public QueryResponse getQueryResponse() + { + return (QueryResponse) queryResponse; } @Override @@ -343,6 +344,11 @@ public class SqlResource @Override public void recordFailure(Exception e) { + if (sqlQuery.queryContext().isDebug()) { + log.warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); + } else { + log.noStackTrace().warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId); + } stmt.reporter().failed(e); } diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index cb9dd0e07aa..9baba52555e 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -22,7 +22,6 @@ package org.apache.druid.sql.http; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; -import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -39,6 +38,7 @@ import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.NonnullPair; import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.guava.LazySequence; @@ -61,6 +61,7 @@ import org.apache.druid.query.ResourceLimitExceededException; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.server.DruidNode; +import org.apache.druid.server.QueryResource; import org.apache.druid.server.QueryResponse; import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryStackTests; @@ -109,9 +110,14 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.HttpHeaders; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; +import javax.ws.rs.core.StreamingOutput; +import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.AbstractList; import java.util.ArrayList; @@ -128,6 +134,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; public class SqlResourceTest extends CalciteTestBase @@ -175,7 +182,7 @@ public class SqlResourceTest extends CalciteTestBase private final SettableSupplier responseContextSupplier = new SettableSupplier<>(); private Consumer onExecute = NULL_ACTION; - private boolean sleep; + private Supplier schedulerBaggage = () -> null; @Before public void setUp() throws Exception @@ -195,15 +202,8 @@ public class SqlResourceTest extends CalciteTestBase { return super.run( query, - new LazySequence(() -> { - if (sleep) { - try { - // pretend to be a query that is waiting on results - Thread.sleep(500); - } - catch (InterruptedException ignored) { - } - } + new LazySequence<>(() -> { + schedulerBaggage.get(); return resultSequence; }) ); @@ -329,7 +329,7 @@ public class SqlResourceTest extends CalciteTestBase public void testUnauthorized() { try { - postForResponse( + postForAsyncResponse( createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"), request() ); @@ -380,18 +380,17 @@ public class SqlResourceTest extends CalciteTestBase mockRespContext.put(ResponseContext.Keys.instance().keyOf("uncoveredIntervalsOverflowed"), "true"); responseContextSupplier.set(mockRespContext); - final MockHttpServletResponse response = postForResponse(sqlQuery, makeRegularUserReq()); + final MockHttpServletResponse response = postForAsyncResponse(sqlQuery, makeRegularUserReq()); - Map responseContext = JSON_MAPPER.readValue( - Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")), - Map.class - ); Assert.assertEquals( ImmutableMap.of( "uncoveredIntervals", "2030-01-01/78149827981274-01-01", "uncoveredIntervalsOverflowed", "true" ), - responseContext + JSON_MAPPER.readValue( + Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")), + Map.class + ) ); Object results = JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class); @@ -714,6 +713,7 @@ public class SqlResourceTest extends CalciteTestBase } @Test + @SuppressWarnings("rawtypes") public void testArrayResultFormatWithHeader() throws Exception { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; @@ -725,7 +725,7 @@ public class SqlResourceTest extends CalciteTestBase Arrays.asList("2000-01-02T00:00:00.000Z", "10.1", nullStr, "[\"b\",\"c\"]", 1, 2.0, 2.0, hllStr, nullStr) }; - MockHttpServletResponse response = postForResponse( + MockHttpServletResponse response = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), req.mimic() ); @@ -745,7 +745,7 @@ public class SqlResourceTest extends CalciteTestBase JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class) ); - MockHttpServletResponse responseNoSqlTypesHeader = postForResponse( + MockHttpServletResponse responseNoSqlTypesHeader = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, true, true, false, null, null), req.mimic() ); @@ -764,7 +764,7 @@ public class SqlResourceTest extends CalciteTestBase JSON_MAPPER.readValue(responseNoSqlTypesHeader.baos.toByteArray(), Object.class) ); - MockHttpServletResponse responseNoTypesHeader = postForResponse( + MockHttpServletResponse responseNoTypesHeader = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, true, false, true, null, null), req.mimic() ); @@ -783,7 +783,7 @@ public class SqlResourceTest extends CalciteTestBase JSON_MAPPER.readValue(responseNoTypesHeader.baos.toByteArray(), Object.class) ); - MockHttpServletResponse responseNoTypes = postForResponse( + MockHttpServletResponse responseNoTypes = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, true, false, false, null, null), req.mimic() ); @@ -801,7 +801,7 @@ public class SqlResourceTest extends CalciteTestBase JSON_MAPPER.readValue(responseNoTypes.baos.toByteArray(), Object.class) ); - MockHttpServletResponse responseNoHeader = postForResponse( + MockHttpServletResponse responseNoHeader = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, false, false, false, null, null), req.mimic() ); @@ -821,7 +821,7 @@ public class SqlResourceTest extends CalciteTestBase // Test a query that returns null header for some of the columns final String query = "SELECT (1, 2) FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"; - MockHttpServletResponse response = postForResponse( + MockHttpServletResponse response = postForAsyncResponse( new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), req ); @@ -971,12 +971,10 @@ public class SqlResourceTest extends CalciteTestBase { final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; - final Function, Map> transformer = m -> { - return Maps.transformEntries( - m, - (k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v - ); - }; + final Function, Map> transformer = m -> Maps.transformEntries( + m, + (k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v + ); Assert.assertEquals( ImmutableList.of( @@ -1336,9 +1334,7 @@ public class SqlResourceTest extends CalciteTestBase @Test public void testCannotParse() throws Exception { - final QueryException exception = doPost( - createSimpleQueryWithId("id", "FROM druid.foo") - ).lhs; + QueryException exception = postSyncForException("FROM druid.foo", Status.BAD_REQUEST.getStatusCode()); Assert.assertNotNull(exception); Assert.assertEquals(PlanningError.SQL_PARSE_ERROR.getErrorCode(), exception.getErrorCode()); @@ -1351,9 +1347,7 @@ public class SqlResourceTest extends CalciteTestBase @Test public void testCannotValidate() throws Exception { - final QueryException exception = doPost( - createSimpleQueryWithId("id", "SELECT dim4 FROM druid.foo") - ).lhs; + QueryException exception = postSyncForException("SELECT dim4 FROM druid.foo", Status.BAD_REQUEST.getStatusCode()); Assert.assertNotNull(exception); Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode()); @@ -1367,10 +1361,10 @@ public class SqlResourceTest extends CalciteTestBase public void testCannotConvert() throws Exception { // SELECT + ORDER unsupported - final QueryException exception = doPost( - createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1") - ).lhs; + final SqlQuery unsupportedQuery = createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1"); + QueryException exception = postSyncForException(unsupportedQuery, Status.BAD_REQUEST.getStatusCode()); + Assert.assertTrue((Boolean) req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)); Assert.assertNotNull(exception); Assert.assertEquals("SQL query is unsupported", exception.getErrorCode()); Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorClass(), exception.getErrorClass()); @@ -1392,9 +1386,10 @@ public class SqlResourceTest extends CalciteTestBase public void testCannotConvert_UnsupportedSQLQueryException() throws Exception { // max(string) unsupported - final QueryException exception = doPost( - createSimpleQueryWithId("id", "SELECT max(dim1) FROM druid.foo") - ).lhs; + QueryException exception = postSyncForException( + "SELECT max(dim1) FROM druid.foo", + Status.BAD_REQUEST.getStatusCode() + ); Assert.assertNotNull(exception); Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorCode(), exception.getErrorCode()); @@ -1442,7 +1437,7 @@ public class SqlResourceTest extends CalciteTestBase { String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final QueryException exception = doPost( + QueryException exception = postSyncForException( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1451,8 +1446,9 @@ public class SqlResourceTest extends CalciteTestBase false, ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"), null - ) - ).lhs; + ), + 501 + ); Assert.assertNotNull(exception); Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, exception.getErrorCode()); @@ -1466,7 +1462,7 @@ public class SqlResourceTest extends CalciteTestBase String queryId = "id123"; String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final MockHttpServletResponse response = postForResponse( + final Response response = postForSyncResponse( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1478,11 +1474,12 @@ public class SqlResourceTest extends CalciteTestBase ), req ); - Assert.assertNotEquals(200, response.getStatus()); - Assert.assertEquals( - queryId, - Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)) - ); + + // This is checked in the common method that returns the response, but checking it again just protects + // from changes there breaking the checks, so doesn't hurt. + assertStatusAndCommonHeaders(response, 501); + Assert.assertEquals(queryId, getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER)); + Assert.assertEquals(queryId, getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); } @Test @@ -1490,7 +1487,7 @@ public class SqlResourceTest extends CalciteTestBase { String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final MockHttpServletResponse response = postForResponse( + final Response response = postForSyncResponse( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1502,10 +1499,10 @@ public class SqlResourceTest extends CalciteTestBase ), req ); - Assert.assertNotEquals(200, response.getStatus()); - Assert.assertFalse( - Strings.isNullOrEmpty(Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER))) - ); + + // This is checked in the common method that returns the response, but checking it again just protects + // from changes there breaking the checks, so doesn't hurt. + assertStatusAndCommonHeaders(response, 501); } @Test @@ -1536,7 +1533,7 @@ public class SqlResourceTest extends CalciteTestBase String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final QueryException exception = doPost( + QueryException exception = postSyncForException( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1545,8 +1542,9 @@ public class SqlResourceTest extends CalciteTestBase false, ImmutableMap.of("sqlQueryId", "id"), null - ) - ).lhs; + ), + 501 + ); Assert.assertNotNull(exception); Assert.assertNull(exception.getMessage()); @@ -1587,7 +1585,7 @@ public class SqlResourceTest extends CalciteTestBase onExecute = s -> { throw new AssertionError(errorMessage); }; - final QueryException exception = doPost( + QueryException exception = postSyncForException( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1596,8 +1594,9 @@ public class SqlResourceTest extends CalciteTestBase false, ImmutableMap.of("sqlQueryId", "id"), null - ) - ).lhs; + ), + Status.INTERNAL_SERVER_ERROR.getStatusCode() + ); Assert.assertNotNull(exception); Assert.assertNull(exception.getMessage()); @@ -1610,15 +1609,28 @@ public class SqlResourceTest extends CalciteTestBase @Test public void testTooManyRequests() throws Exception { - sleep = true; final int numQueries = 3; + CountDownLatch queriesScheduledLatch = new CountDownLatch(numQueries - 1); + CountDownLatch runQueryLatch = new CountDownLatch(1); + + schedulerBaggage = () -> { + queriesScheduledLatch.countDown(); + try { + runQueryLatch.await(); + } + catch (InterruptedException e) { + throw new RE(e); + } + return null; + }; + final String sqlQueryId = "tooManyRequestsTest"; - List>>>> futures = new ArrayList<>(numQueries); - for (int i = 0; i < numQueries; i++) { + List> futures = new ArrayList<>(numQueries); + for (int i = 0; i < numQueries - 1; i++) { futures.add(executorService.submit(() -> { try { - return doPost( + return postForAsyncResponse( new SqlQuery( "SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", null, @@ -1637,22 +1649,51 @@ public class SqlResourceTest extends CalciteTestBase })); } + queriesScheduledLatch.await(); + schedulerBaggage = () -> null; + futures.add(executorService.submit(() -> { + try { + final Response retVal = postForSyncResponse( + new SqlQuery( + "SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", + null, + false, + false, + false, + ImmutableMap.of("priority", -5, BaseQuery.SQL_QUERY_ID, sqlQueryId), + null + ), + makeRegularUserReq() + ); + runQueryLatch.countDown(); + return retVal; + } + catch (Exception e) { + throw new RuntimeException(e); + } + })); + int success = 0; int limited = 0; for (int i = 0; i < numQueries; i++) { - Pair>> result = futures.get(i).get(); - List> rows = result.rhs; - if (rows != null) { - Assert.assertEquals(ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")), rows); - success++; - } else { - QueryException interruped = result.lhs; + if (i == 2) { + Response response = (Response) futures.get(i).get(); + assertStatusAndCommonHeaders(response, 429); + QueryException interruped = deserializeResponse(response, QueryException.class); Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, interruped.getErrorCode()); Assert.assertEquals( QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 2), interruped.getMessage() ); limited++; + } else { + MockHttpServletResponse response = (MockHttpServletResponse) futures.get(i).get(); + assertStatusAndCommonHeaders(response, 200); + Assert.assertEquals( + ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")), + deserializeResponse(response, Object.class) + ); + success++; } } Assert.assertEquals(2, success); @@ -1671,7 +1712,8 @@ public class SqlResourceTest extends CalciteTestBase BaseQuery.SQL_QUERY_ID, sqlQueryId ); - final QueryException timeoutException = doPost( + + QueryException exception = postSyncForException( new SqlQuery( "SELECT CAST(__time AS DATE), dim1, dim2, dim3 FROM druid.foo GROUP by __time, dim1, dim2, dim3 ORDER BY dim2 DESC", ResultFormat.OBJECT, @@ -1680,11 +1722,13 @@ public class SqlResourceTest extends CalciteTestBase false, queryContext, null - ) - ).lhs; - Assert.assertNotNull(timeoutException); - Assert.assertEquals(timeoutException.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE); - Assert.assertEquals(timeoutException.getErrorClass(), QueryTimeoutException.class.getName()); + ), + 504 + ); + + Assert.assertNotNull(exception); + Assert.assertEquals(exception.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE); + Assert.assertEquals(exception.getErrorClass(), QueryTimeoutException.class.getName()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); } @@ -1697,8 +1741,8 @@ public class SqlResourceTest extends CalciteTestBase validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true)); CountDownLatch planLatch = new CountDownLatch(1); planLatchSupplier.set(new NonnullPair<>(planLatch, false)); - Future future = executorService.submit( - () -> postForResponse( + Future future = executorService.submit( + () -> postForSyncResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) @@ -1711,9 +1755,10 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); - MockHttpServletResponse queryResponse = future.get(); - Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); - QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class); + Response queryResponse = future.get(); + assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode()); + + QueryException exception = deserializeResponse(queryResponse, QueryException.class); Assert.assertEquals("Query cancelled", exception.getErrorCode()); } @@ -1725,8 +1770,8 @@ public class SqlResourceTest extends CalciteTestBase planLatchSupplier.set(new NonnullPair<>(planLatch, true)); CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); - Future future = executorService.submit( - () -> postForResponse( + Future future = executorService.submit( + () -> postForSyncResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) @@ -1738,9 +1783,10 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); - MockHttpServletResponse queryResponse = future.get(); - Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); - QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class); + Response queryResponse = future.get(); + assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode()); + + QueryException exception = deserializeResponse(queryResponse, QueryException.class); Assert.assertEquals("Query cancelled", exception.getErrorCode()); } @@ -1753,7 +1799,7 @@ public class SqlResourceTest extends CalciteTestBase CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); Future future = executorService.submit( - () -> postForResponse( + () -> postForAsyncResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) @@ -1778,7 +1824,7 @@ public class SqlResourceTest extends CalciteTestBase CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); Future future = executorService.submit( - () -> postForResponse( + () -> postForAsyncResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM forbiddenDatasource"), makeSuperUserReq() ) @@ -1827,23 +1873,16 @@ public class SqlResourceTest extends CalciteTestBase public void testQueryContextKeyNotAllowed() throws Exception { Map queryContext = ImmutableMap.of(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, "all"); - final QueryException queryContextException = doPost( - new SqlQuery( - "SELECT 1337", - ResultFormat.OBJECT, - false, - false, - false, - queryContext, - null - ) - ).lhs; - Assert.assertNotNull(queryContextException); - Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), queryContextException.getErrorCode()); + QueryException exception = postSyncForException( + new SqlQuery("SELECT 1337", ResultFormat.OBJECT, false, false, false, queryContext, null), + Status.BAD_REQUEST.getStatusCode() + ); + + Assert.assertNotNull(exception); + Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode()); MatcherAssert.assertThat( - queryContextException.getMessage(), - CoreMatchers.containsString( - "Cannot execute query with context parameter [sqlInsertSegmentGranularity]") + exception.getMessage(), + CoreMatchers.containsString("Cannot execute query with context parameter [sqlInsertSegmentGranularity]") ); checkSqlRequestLog(false); } @@ -1922,7 +1961,7 @@ public class SqlResourceTest extends CalciteTestBase private Pair doPostRaw(final SqlQuery query, final MockHttpServletRequest req) throws Exception { - MockHttpServletResponse response = postForResponse(query, req); + MockHttpServletResponse response = postForAsyncResponse(query, req); if (response.getStatus() == 200) { return Pair.of(null, new String(response.baos.toByteArray(), StandardCharsets.UTF_8)); @@ -1932,14 +1971,122 @@ public class SqlResourceTest extends CalciteTestBase } @Nonnull - private MockHttpServletResponse postForResponse(SqlQuery query, MockHttpServletRequest req) + private MockHttpServletResponse postForAsyncResponse(SqlQuery query, MockHttpServletRequest req) { MockHttpServletResponse response = MockHttpServletResponse.forRequest(req); + final Object explicitQueryId = query.getContext().get("queryId"); + final Object explicitSqlQueryId = query.getContext().get("sqlQueryId"); Assert.assertNull(resource.doPost(query, req)); + + final Object actualQueryId = response.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER); + final Object actualSqlQueryId = response.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER); + validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId); + return response; } + private void assertStatusAndCommonHeaders(MockHttpServletResponse queryResponse, int statusCode) + { + Assert.assertEquals(statusCode, queryResponse.getStatus()); + Assert.assertEquals("application/json", queryResponse.getContentType()); + Assert.assertNotNull(queryResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER)); + Assert.assertNotNull(queryResponse.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); + } + + private T deserializeResponse(MockHttpServletResponse resp, Class clazz) throws IOException + { + return JSON_MAPPER.readValue(resp.baos.toByteArray(), clazz); + } + + private Response postForSyncResponse(SqlQuery query, MockHttpServletRequest req) + { + final Object explicitQueryId = query.getContext().get("queryId"); + final Object explicitSqlQueryId = query.getContext().get("sqlQueryId"); + + final Response response = resource.doPost(query, req); + + final Object actualQueryId = getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER); + final Object actualSqlQueryId = getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER); + + validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId); + + return response; + } + + private QueryException postSyncForException(String s, int expectedStatus) throws IOException + { + return postSyncForException(createSimpleQueryWithId("id", s), expectedStatus); + } + + private QueryException postSyncForException(SqlQuery query, int expectedStatus) throws IOException + { + final Response response = postForSyncResponse(query, req); + assertStatusAndCommonHeaders(response, expectedStatus); + return deserializeResponse(response, QueryException.class); + } + + private T deserializeResponse(Response resp, Class clazz) throws IOException + { + return JSON_MAPPER.readValue(responseToByteArray(resp), clazz); + } + + private byte[] responseToByteArray(Response resp) throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ((StreamingOutput) resp.getEntity()).write(baos); + return baos.toByteArray(); + } + + private String getContentType(Response resp) + { + return getHeader(resp, HttpHeaders.CONTENT_TYPE).toString(); + } + + @Nullable + private Object getHeader(Response resp, String header) + { + final List objects = resp.getMetadata().get(header); + if (objects == null) { + return null; + } + return Iterables.getOnlyElement(objects); + } + + private void assertStatusAndCommonHeaders(Response queryResponse, int statusCode) + { + Assert.assertEquals(statusCode, queryResponse.getStatus()); + Assert.assertEquals("application/json", getContentType(queryResponse)); + Assert.assertNotNull(getHeader(queryResponse, QueryResource.QUERY_ID_RESPONSE_HEADER)); + Assert.assertNotNull(getHeader(queryResponse, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); + } + + private void validateQueryIds( + Object explicitQueryId, + Object explicitSqlQueryId, + Object actualQueryId, + Object actualSqlQueryId + ) + { + if (explicitQueryId == null) { + if (null != explicitSqlQueryId) { + Assert.assertEquals(explicitSqlQueryId, actualQueryId); + Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId); + } else { + Assert.assertNotNull(actualQueryId); + Assert.assertNotNull(actualSqlQueryId); + } + } else { + if (explicitSqlQueryId == null) { + Assert.assertEquals(explicitQueryId, actualQueryId); + Assert.assertEquals(explicitQueryId, actualSqlQueryId); + } else { + Assert.assertEquals(explicitQueryId, actualQueryId); + Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId); + } + } + } + private MockHttpServletRequest makeSuperUserReq() { return makeExpectedReq(CalciteTests.SUPER_USER_AUTH_RESULT); @@ -1954,6 +2101,7 @@ public class SqlResourceTest extends CalciteTestBase { MockHttpServletRequest req = new MockHttpServletRequest(); req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); + req.remoteAddr = "1.2.3.4"; return req; }