Validate response headers and fix exception logging (#13609)

* Validate response headers and fix exception logging

A class of QueryException were throwing away their
causes making it really hard to determine what's
going wrong when something goes wrong in the SQL
planner specifically.  Fix that and adjust tests
 to do more validation of response headers as well.

We allow 404s and 307s to be returned even without 
authorization validated, but others get converted to 403
This commit is contained in:
imply-cheddar 2023-01-06 07:15:15 +09:00 committed by GitHub
parent 7a7874a952
commit a8ecc48ffe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 760 additions and 386 deletions

View File

@ -152,7 +152,12 @@ public class QueryException extends RuntimeException implements SanitizableExcep
protected QueryException(Throwable cause, String errorCode, String errorClass, String host)
{
super(cause == null ? null : cause.getMessage(), cause);
this(cause, errorCode, cause == null ? null : cause.getMessage(), errorClass, host);
}
protected QueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host)
{
super(errorMessage, cause);
this.errorCode = errorCode;
this.errorClass = errorClass;
this.host = host;

View File

@ -95,6 +95,24 @@ public class QueryExceptionTest
expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_QUERY_UNSUPPORTED_ERROR_CODE);
}
/**
* This test exists primarily to get branch coverage of the null check on the QueryException constructor.
* The validations done in this test are not actually intended to be set-in-stone or anything.
*/
@Test
public void testCanConstructWithoutThrowable()
{
QueryException exception = new QueryException(
(Throwable) null,
QueryException.UNKNOWN_EXCEPTION_ERROR_CODE,
"java.lang.Exception",
"test"
);
Assert.assertEquals(QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, exception.getErrorCode());
Assert.assertNull(exception.getMessage());
}
private void expectFailTypeForCode(FailType expected, String code)
{
QueryException exception = new QueryException(new Exception(), code, "java.lang.Exception", "test");

View File

@ -420,14 +420,20 @@ public class CoordinatorPollingBasicAuthorizerCacheManager implements BasicAutho
new BytesFullResponseHandler()
);
final HttpResponseStatus status = responseHolder.getStatus();
// cachedSerializedGroupMappingMap is a new endpoint introduced in Druid 0.17.0. For backwards compatibility, if we
// get a 404 from the coordinator we stop retrying. This can happen during a rolling upgrade when a process
// running 0.17.0+ tries to access this endpoint on an older coordinator.
if (responseHolder.getStatus().equals(HttpResponseStatus.NOT_FOUND)) {
if (HttpResponseStatus.NOT_FOUND.equals(status)) {
LOG.warn("cachedSerializedGroupMappingMap is not available from the coordinator, skipping fetch of group mappings for now.");
return null;
}
if (!HttpResponseStatus.OK.equals(status)) {
LOG.warn("Got an unexpected response status[%s] when loading group mappings.", status);
}
byte[] groupRoleMapBytes = responseHolder.getContent();
GroupMappingAndRoleMap groupMappingAndRoleMap = objectMapper.readValue(

View File

@ -41,7 +41,6 @@ import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import javax.inject.Inject;
import javax.ws.rs.core.MediaType;
import java.io.IOException;
import java.net.URL;
import java.util.Map;
@ -336,7 +335,8 @@ public class DruidClusterClient
*/
public void validate()
{
log.info("Starting cluster validation");
RE exception = new RE("Just building for the stack trace");
log.info(exception, "Starting cluster validation");
for (ResolvedDruidService service : config.requireDruid().values()) {
for (ResolvedInstance instance : service.requireInstances()) {
validateInstance(service, instance);
@ -348,28 +348,46 @@ public class DruidClusterClient
/**
* Validate an instance by waiting for it to report that it is healthy.
*/
@SuppressWarnings("BusyWait")
private void validateInstance(ResolvedDruidService service, ResolvedInstance instance)
{
int timeoutMs = config.readyTimeoutSec() * 1000;
int pollMs = config.readyPollMs();
long startTime = System.currentTimeMillis();
long updateTime = startTime + 5000;
while (System.currentTimeMillis() - startTime < timeoutMs) {
while (true) {
if (isHealthy(service, instance)) {
log.info(
"Service %s, host %s is ready",
"Service[%s], host[%s], tag[%s] is ready",
service.service(),
instance.clientHost());
instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag()
);
return;
}
long currentTime = System.currentTimeMillis();
if (currentTime > updateTime) {
log.info(
"Service %s, host %s not ready, retrying",
"Service[%s], host[%s], tag[%s] not ready, retrying",
service.service(),
instance.clientHost());
instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag()
);
updateTime = currentTime + 5000;
}
final long elapsedTime = System.currentTimeMillis() - startTime;
if (elapsedTime > timeoutMs) {
final RE exception = new RE(
"Service[%s], host[%s], tag[%s] not ready after %,d ms.",
service.service(),
instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag(),
elapsedTime
);
// We log the exception here so that the logs include which thread is having the problem
log.error(exception.getMessage());
throw exception;
}
try {
Thread.sleep(pollMs);
}
@ -377,34 +395,30 @@ public class DruidClusterClient
throw new RuntimeException("Interrupted during cluster validation");
}
}
throw new RE(
StringUtils.format("Service %s, instance %s not ready after %d ms.",
service.service(),
instance.tag() == null ? "<default>" : instance.tag(),
timeoutMs));
}
/**
* Wait for an instance to become ready given the URL and a description of
* the service.
*/
@SuppressWarnings("BusyWait")
public void waitForNodeReady(String label, String url)
{
int timeoutMs = config.readyTimeoutSec() * 1000;
int pollMs = config.readyPollMs();
long startTime = System.currentTimeMillis();
while (System.currentTimeMillis() - startTime < timeoutMs) {
while (true) {
if (isHealthy(url)) {
log.info(
"Service %s, url %s is ready",
label,
url);
log.info("Service[%s], url[%s] is ready", label, url);
return;
}
log.info(
"Service %s, url %s not ready, retrying",
label,
url);
final long elapsedTime = System.currentTimeMillis() - startTime;
if (elapsedTime > timeoutMs) {
final RE re = new RE("Service[%s], url[%s] not ready after %,d ms.", label, url, elapsedTime);
log.error(re.getMessage());
throw re;
}
log.info("Service[%s], url[%s] not ready, retrying", label, url);
try {
Thread.sleep(pollMs);
}
@ -412,11 +426,6 @@ public class DruidClusterClient
throw new RuntimeException("Interrupted while waiting for note to be ready");
}
}
throw new RE(
StringUtils.format("Service %s, url %s not ready after %d ms.",
label,
url,
timeoutMs));
}
public String nodeUrl(DruidNode node)

View File

@ -29,7 +29,7 @@ public class BadJsonQueryException extends BadQueryException
public BadJsonQueryException(JsonParseException e)
{
this(JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS);
this(e, JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS);
}
@JsonCreator
@ -39,6 +39,16 @@ public class BadJsonQueryException extends BadQueryException
@JsonProperty("errorClass") String errorClass
)
{
super(errorCode, errorMessage, errorClass);
this(null, errorCode, errorMessage, errorClass);
}
private BadJsonQueryException(
Throwable cause,
String errorCode,
String errorMessage,
String errorClass
)
{
super(cause, errorCode, errorMessage, errorClass, null);
}
}

View File

@ -30,11 +30,16 @@ public abstract class BadQueryException extends QueryException
protected BadQueryException(String errorCode, String errorMessage, String errorClass)
{
super(errorCode, errorMessage, errorClass, null);
this(errorCode, errorMessage, errorClass, null);
}
protected BadQueryException(String errorCode, String errorMessage, String errorClass, String host)
{
super(errorCode, errorMessage, errorClass, host);
this(null, errorCode, errorMessage, errorClass, host);
}
protected BadQueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host)
{
super(cause, errorCode, errorMessage, errorClass, host);
}
}

View File

@ -30,6 +30,7 @@ import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.inject.Inject;
import org.apache.druid.client.DirectDruidClient;
@ -216,21 +217,8 @@ public class QueryResource implements QueryCountStatsProvider
throw new ForbiddenException(authResult.toString());
}
// We use an async context not because we are actually going to run this async, but because we want to delay
// the decision of what the response code should be until we have gotten the first few data points to return.
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
final AsyncContext asyncContext = req.startAsync();
try {
new QueryResourceQueryResultPusher(req, queryLifecycle, io, (HttpServletResponse) asyncContext.getResponse())
.push();
}
finally {
asyncContext.complete();
}
final QueryResourceQueryResultPusher pusher = new QueryResourceQueryResultPusher(req, queryLifecycle, io);
return pusher.push();
}
catch (Exception e) {
if (e instanceof ForbiddenException && !req.isAsyncStarted()) {
@ -258,6 +246,7 @@ public class QueryResource implements QueryCountStatsProvider
out.write(jsonMapper.writeValueAsBytes(responseException));
}
}
return null;
}
finally {
asyncContext.complete();
@ -266,7 +255,6 @@ public class QueryResource implements QueryCountStatsProvider
finally {
Thread.currentThread().setName(currThreadName);
}
return null;
}
public interface QueryMetricCounter
@ -538,18 +526,18 @@ public class QueryResource implements QueryCountStatsProvider
public QueryResourceQueryResultPusher(
HttpServletRequest req,
QueryLifecycle queryLifecycle,
ResourceIOReaderWriter io,
HttpServletResponse response
ResourceIOReaderWriter io
)
{
super(
response,
req,
QueryResource.this.jsonMapper,
QueryResource.this.responseContextConfig,
QueryResource.this.selfNode,
QueryResource.this.counter,
queryLifecycle.getQueryId(),
MediaType.valueOf(io.getResponseWriter().getResponseType())
MediaType.valueOf(io.getResponseWriter().getResponseType()),
ImmutableMap.of()
);
this.req = req;
this.queryLifecycle = queryLifecycle;
@ -561,20 +549,27 @@ public class QueryResource implements QueryCountStatsProvider
{
return new ResultsWriter()
{
private QueryResponse<Object> queryResponse;
@Override
public QueryResponse<Object> start(HttpServletResponse response)
public Response.ResponseBuilder start()
{
final QueryResponse<Object> queryResponse = queryLifecycle.execute();
queryResponse = queryLifecycle.execute();
final ResponseContext responseContext = queryResponse.getResponseContext();
final String prevEtag = getPreviousEtag(req);
if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) {
queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1);
counter.incrementSuccess();
response.setStatus(HttpServletResponse.SC_NOT_MODIFIED);
return null;
return Response.status(Status.NOT_MODIFIED);
}
return null;
}
@Override
public QueryResponse<Object> getQueryResponse()
{
return queryResponse;
}

View File

@ -36,45 +36,54 @@ import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.server.security.ForbiddenException;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Map;
public abstract class QueryResultPusher
{
private static final Logger log = new Logger(QueryResultPusher.class);
private final HttpServletResponse response;
private final HttpServletRequest request;
private final String queryId;
private final ObjectMapper jsonMapper;
private final ResponseContextConfig responseContextConfig;
private final DruidNode selfNode;
private final QueryResource.QueryMetricCounter counter;
private final MediaType contentType;
private final Map<String, String> extraHeaders;
private StreamingHttpResponseAccumulator accumulator = null;
private AsyncContext asyncContext = null;
private HttpServletResponse response = null;
public QueryResultPusher(
HttpServletResponse response,
HttpServletRequest request,
ObjectMapper jsonMapper,
ResponseContextConfig responseContextConfig,
DruidNode selfNode,
QueryResource.QueryMetricCounter counter,
String queryId,
MediaType contentType
MediaType contentType,
Map<String, String> extraHeaders
)
{
this.response = response;
this.request = request;
this.queryId = queryId;
this.jsonMapper = jsonMapper;
this.responseContextConfig = responseContextConfig;
this.selfNode = selfNode;
this.counter = counter;
this.contentType = contentType;
this.extraHeaders = extraHeaders;
}
/**
@ -92,23 +101,45 @@ public abstract class QueryResultPusher
public abstract void writeException(Exception e, OutputStream out) throws IOException;
public void push()
/**
* Pushes results out. Can sometimes return a JAXRS Response object instead of actually pushing to the output
* stream, primarily for error handling that occurs before switching the servlet to asynchronous mode.
*
* @return null if the response has already been handled and pushed out, or a non-null Response object if it expects
* the container to put the bytes on the wire.
*/
@Nullable
public Response push()
{
response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
ResultsWriter resultsWriter = null;
try {
resultsWriter = start();
final QueryResponse<Object> queryResponse = resultsWriter.start(response);
if (queryResponse == null) {
// It's already been handled...
return;
final Response.ResponseBuilder startResponse = resultsWriter.start();
if (startResponse != null) {
startResponse.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
startResponse.header(entry.getKey(), entry.getValue());
}
return startResponse.build();
}
final QueryResponse<Object> queryResponse = resultsWriter.getQueryResponse();
final Sequence<Object> results = queryResponse.getResults();
// We use an async context not because we are actually going to run this async, but because we want to delay
// the decision of what the response code should be until we have gotten the first few data points to return.
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
asyncContext = request.startAsync();
response = (HttpServletResponse) asyncContext.getResponse();
response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
response.setHeader(entry.getKey(), entry.getValue());
}
accumulator = new StreamingHttpResponseAccumulator(queryResponse.getResponseContext(), resultsWriter);
results.accumulate(null, accumulator);
@ -119,8 +150,7 @@ public abstract class QueryResultPusher
resultsWriter.recordSuccess(accumulator.getNumBytesSent());
}
catch (QueryException e) {
handleQueryException(resultsWriter, e);
return;
return handleQueryException(resultsWriter, e);
}
catch (RuntimeException re) {
if (re instanceof ForbiddenException) {
@ -128,17 +158,15 @@ public abstract class QueryResultPusher
// has been committed because the response is committed after results are returned. And, if we started
// returning results before a ForbiddenException gets thrown, that means that we've already leaked stuff
// that should not have been leaked. I.e. it means, we haven't validated the authorization early enough.
if (response.isCommitted()) {
if (response != null && response.isCommitted()) {
log.error(re, "Got a forbidden exception for query[%s] after the response was already committed.", queryId);
}
throw re;
}
handleQueryException(resultsWriter, new QueryInterruptedException(re));
return;
return handleQueryException(resultsWriter, new QueryInterruptedException(re));
}
catch (IOException ioEx) {
handleQueryException(resultsWriter, new QueryInterruptedException(ioEx));
return;
return handleQueryException(resultsWriter, new QueryInterruptedException(ioEx));
}
finally {
if (accumulator != null) {
@ -159,10 +187,15 @@ public abstract class QueryResultPusher
log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId);
}
}
if (asyncContext != null) {
asyncContext.complete();
}
}
return null;
}
private void handleQueryException(ResultsWriter resultsWriter, QueryException e)
@Nullable
private Response handleQueryException(ResultsWriter resultsWriter, QueryException e)
{
if (accumulator != null && accumulator.isInitialized()) {
// We already started sending a response when we got the error message. In this case we just give up
@ -176,11 +209,7 @@ public abstract class QueryResultPusher
// This case is always a failure because the error happened mid-stream of sending results back. Therefore,
// we do not believe that the response stream was actually useable
counter.incrementFailed();
return;
}
if (response.isCommitted()) {
QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?");
return null;
}
final QueryException.FailType failType = e.getFailType();
@ -206,40 +235,71 @@ public abstract class QueryResultPusher
);
counter.incrementFailed();
}
final int responseStatus = failType.getExpectedStatus();
response.setStatus(responseStatus);
response.setHeader("Content-Type", contentType.toString());
try (ServletOutputStream out = response.getOutputStream()) {
writeException(e, out);
}
catch (IOException ioException) {
log.warn(
ioException,
"Suppressing IOException thrown sending error response for query[%s]",
queryId
);
}
resultsWriter.recordFailure(e);
final int responseStatus = failType.getExpectedStatus();
if (response == null) {
// No response object yet, so assume we haven't started the async context and is safe to return Response
final Response.ResponseBuilder bob = Response
.status(responseStatus)
.type(contentType)
.entity((StreamingOutput) output -> {
writeException(e, output);
output.close();
});
bob.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
bob.header(entry.getKey(), entry.getValue());
}
return bob.build();
} else {
if (response.isCommitted()) {
QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?");
}
response.setStatus(responseStatus);
response.setHeader("Content-Type", contentType.toString());
try (ServletOutputStream out = response.getOutputStream()) {
writeException(e, out);
}
catch (IOException ioException) {
log.warn(
ioException,
"Suppressing IOException thrown sending error response for query[%s]",
queryId
);
}
return null;
}
}
public interface ResultsWriter extends Closeable
{
/**
* Runs the query and returns a ResultsWriter from running the query.
* Runs the query and prepares the QueryResponse to be returned
* <p>
* This also serves as a hook for any logic that runs on the metadata from a QueryResponse. If this method
* returns {@code null} then the Pusher believes that the response was already handled and skips the rest
* of its logic. As such, any implementation that returns null must make sure that the response has been set
* with a meaningful status, etc.
* returns {@code null} then the Pusher can continue with normal logic. If this method chooses to return
* a ResponseBuilder, then the Pusher will attach any extra metadata it has to the Response and return
* the response built from the Builder without attempting to process the results of the query.
* <p>
* Even if this method returns null, close() should still be called on this object.
* In all cases, {@link #close()} should be called on this object.
*
* @return QueryResponse or null if no more work to do.
*/
@Nullable
QueryResponse<Object> start(HttpServletResponse response);
Response.ResponseBuilder start();
/**
* Gets the results of running the query. {@link #start} must be called before this method is called.
*
* @return the results of running the query as preparted by the {@link #start()} method
*/
QueryResponse<Object> getQueryResponse();
Writer makeWriter(OutputStream out) throws IOException;
@ -301,9 +361,7 @@ public abstract class QueryResultPusher
* Initializes the response. This is done lazily so that we can put various metadata that we only get once
* we have some of the response stream into the result.
* <p>
* This is called once for each result object, but should only actually happen once.
*
* @return boolean if initialization occurred. False most of the team because initialization only happens once.
* It is okay for this to be called multiple times.
*/
public void initialize()
{
@ -332,7 +390,7 @@ public abstract class QueryResultPusher
);
}
catch (JsonProcessingException e) {
QueryResource.log.info(e, "Problem serializing to JSON!?");
log.info(e, "Problem serializing to JSON!?");
serializationResult = new ResponseContext.SerializationResult("Could not serialize", "Could not serialize");
}
@ -343,7 +401,7 @@ public abstract class QueryResultPusher
serializationResult.getFullResult()
);
if (responseContextConfig.shouldFailOnTruncatedResponseContext()) {
QueryResource.log.error(logToPrint);
log.error(logToPrint);
throw new QueryInterruptedException(
new TruncatedResponseContextException(
"Serialized response context exceeds the max size[%s]",
@ -352,12 +410,12 @@ public abstract class QueryResultPusher
selfNode.getHostAndPortToUse()
);
} else {
QueryResource.log.warn(logToPrint);
log.warn(logToPrint);
}
}
response.setHeader(QueryResource.HEADER_RESPONSE_CONTEXT, serializationResult.getResult());
response.setHeader("Content-Type", contentType.toString());
response.setContentType(contentType.toString());
try {
out = new CountingOutputStream(response.getOutputStream());
@ -379,6 +437,7 @@ public abstract class QueryResultPusher
}
@Override
@Nullable
public Response accumulate(Response retVal, Object in)
{
if (!initialized) {

View File

@ -20,11 +20,12 @@
package org.apache.druid.server.security;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.server.DruidNode;
import org.apache.druid.server.QueryResource;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
@ -83,12 +84,15 @@ public class PreResponseAuthorizationCheckFilter implements Filter
filterChain.doFilter(servletRequest, servletResponse);
Boolean authInfoChecked = (Boolean) servletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED);
if (authInfoChecked == null && statusIsSuccess(response.getStatus())) {
if (authInfoChecked == null && statusShouldBeHidden(response.getStatus())) {
// Note: rather than throwing an exception here, it would be nice to blank out the original response
// since the request didn't have any authorization checks performed. However, this breaks proxying
// (e.g. OverlordServletProxy), so this is not implemented for now.
handleAuthorizationCheckError(
"Request did not have an authorization check performed.",
StringUtils.format(
"Request did not have an authorization check performed, original response status[%s].",
response.getStatus()
),
request,
response
);
@ -136,7 +140,6 @@ public class PreResponseAuthorizationCheckFilter implements Filter
OutputStream out = response.getOutputStream();
sendJsonError(response, HttpServletResponse.SC_UNAUTHORIZED, jsonMapper.writeValueAsString(unauthorizedError), out);
out.close();
return;
}
private void handleAuthorizationCheckError(
@ -145,19 +148,21 @@ public class PreResponseAuthorizationCheckFilter implements Filter
HttpServletResponse servletResponse
)
{
final String queryId = servletResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER);
// Send out an alert so there's a centralized collection point for seeing errors of this nature
log.makeAlert(errorMsg)
.addData("uri", servletRequest.getRequestURI())
.addData("method", servletRequest.getMethod())
.addData("remoteAddr", servletRequest.getRemoteAddr())
.addData("remoteHost", servletRequest.getRemoteHost())
.addData("queryId", queryId)
.emit();
if (servletResponse.isCommitted()) {
throw new ISE(errorMsg);
} else {
if (!servletResponse.isCommitted()) {
try {
servletResponse.sendError(HttpServletResponse.SC_FORBIDDEN);
servletResponse.reset();
servletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
}
catch (Exception e) {
throw new RuntimeException(e);
@ -165,9 +170,23 @@ public class PreResponseAuthorizationCheckFilter implements Filter
}
}
private static boolean statusIsSuccess(int status)
private static boolean statusShouldBeHidden(int status)
{
return 200 <= status && status < 300;
// We allow 404s (not found) to not be rewritten to forbidden because consistently returning 404s is a way to leak
// less information when something wasn't able to be done anyway. I.e. if we pretend that the thing didn't exist
// when the authorization fails, then there is no information about whether the thing existed. If we return
// a 403 when authorization fails and a 404 when authorization succeeds, but it doesn't exist, then we have
// leaked that it could maybe exist, if the authentication credentials were good.
//
// We also allow 307s (temporary redirect) to not be hidden as they are used to redirect to the leader.
switch (status) {
case HttpServletResponse.SC_FORBIDDEN:
case HttpServletResponse.SC_NOT_FOUND:
case HttpServletResponse.SC_TEMPORARY_REDIRECT:
return false;
default:
return true;
}
}
public static void sendJsonError(HttpServletResponse resp, int error, String errorJson, OutputStream outputStream)

View File

@ -26,7 +26,6 @@ import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.inject.Injector;
@ -88,7 +87,9 @@ import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
@ -396,7 +397,7 @@ public class QueryResourceTest
final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY);
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
//since accept header is null, the response content type should be same as the value of 'Content-Type' header
Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type")));
Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType());
}
@Test
@ -409,7 +410,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
//since accept header is empty, the response content type should be same as the value of 'Content-Type' header
Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type")));
Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType());
}
@Test
@ -424,10 +425,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should be Smile
Assert.assertEquals(
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
}
@Test
@ -447,10 +445,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should be Smile
Assert.assertEquals(
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
}
@Test
@ -469,10 +464,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should default to Content-Type from request
Assert.assertEquals(
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
}
@Test
@ -643,13 +635,16 @@ public class QueryResourceTest
);
expectPermissiveHappyPathAuth();
final MockHttpServletResponse response = expectAsyncRequestFlow(
final Response response = expectSynchronousRequestFlow(
testServletRequest,
SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8),
timeoutQueryResource
);
Assert.assertEquals(QueryTimeoutException.STATUS_CODE, response.getStatus());
QueryTimeoutException ex = jsonMapper.readValue(response.baos.toByteArray(), QueryTimeoutException.class);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
QueryTimeoutException ex = jsonMapper.readValue(baos.toByteArray(), QueryTimeoutException.class);
Assert.assertEquals("Query Timed Out!", ex.getMessage());
Assert.assertEquals(QueryException.QUERY_TIMEOUT_ERROR_CODE, ex.getErrorCode());
Assert.assertEquals(1, timeoutQueryResource.getTimedOutQueryCount());
@ -892,25 +887,28 @@ public class QueryResourceTest
);
createScheduledQueryResource(laningScheduler, Collections.emptyList(), ImmutableList.of(waitTwoScheduled));
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
);
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
);
waitTwoScheduled.await();
assertResponseAndCountdownOrBlockForever(
assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex;
try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
}
catch (IOException e) {
throw new RuntimeException(e);
@ -938,20 +936,22 @@ public class QueryResourceTest
createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled));
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
);
waitOneScheduled.await();
assertResponseAndCountdownOrBlockForever(
assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY,
waitAllFinished,
response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex;
try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
}
catch (IOException e) {
throw new RuntimeException(e);
@ -965,7 +965,7 @@ public class QueryResourceTest
}
);
waitTwoStarted.await();
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
@ -990,20 +990,22 @@ public class QueryResourceTest
createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled));
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
);
waitOneScheduled.await();
assertResponseAndCountdownOrBlockForever(
assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY,
waitAllFinished,
response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex;
try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
}
catch (IOException e) {
throw new RuntimeException(e);
@ -1016,7 +1018,7 @@ public class QueryResourceTest
}
);
waitTwoStarted.await();
assertResponseAndCountdownOrBlockForever(
assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_SMALLISH_INTERVAL,
waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
@ -1085,7 +1087,7 @@ public class QueryResourceTest
);
}
private void assertResponseAndCountdownOrBlockForever(
private void assertAsyncResponseAndCountdownOrBlockForever(
String query,
CountDownLatch done,
Consumer<MockHttpServletResponse> asserts
@ -1146,4 +1148,36 @@ public class QueryResourceTest
));
return response;
}
private void assertSynchronousResponseAndCountdownOrBlockForever(
String query,
CountDownLatch done,
Consumer<Response> asserts
)
{
Executors.newSingleThreadExecutor().submit(() -> {
try {
asserts.accept(
expectSynchronousRequestFlow(
testServletRequest.mimic(),
query.getBytes(StandardCharsets.UTF_8),
queryResource
)
);
}
catch (IOException e) {
throw new RuntimeException(e);
}
done.countDown();
});
}
private Response expectSynchronousRequestFlow(
MockHttpServletRequest req,
byte[] bytes,
QueryResource queryResource
) throws IOException
{
return queryResource.doPost(new ByteArrayInputStream(bytes), null, req);
}
}

View File

@ -20,141 +20,155 @@
package org.apache.druid.server.http.security;
import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.apache.druid.server.mocks.MockHttpServletRequest;
import org.apache.druid.server.mocks.MockHttpServletResponse;
import org.apache.druid.server.security.AllowAllAuthenticator;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.Authenticator;
import org.apache.druid.server.security.PreResponseAuthorizationCheckFilter;
import org.easymock.EasyMock;
import org.junit.Rule;
import org.junit.Assert;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import javax.servlet.FilterChain;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.ServletException;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
public class PreResponseAuthorizationCheckFilterTest
{
private static List<Authenticator> authenticators = Collections.singletonList(new AllowAllAuthenticator());
@Rule
public ExpectedException expectedException = ExpectedException.none();
private static final List<Authenticator> authenticators = Collections.singletonList(new AllowAllAuthenticator());
@Test
public void testValidRequest() throws Exception
{
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class);
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class);
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
MockHttpServletRequest req = new MockHttpServletRequest();
MockHttpServletResponse resp = new MockHttpServletResponse();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(true).once();
EasyMock.replay(req, resp, filterChain, outputStream);
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
req.attributes.put(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, filterChain);
EasyMock.verify(req, resp, filterChain, outputStream);
filter.doFilter(req, resp, (request, response) -> {
});
}
@Test
public void testAuthenticationFailedRequest() throws Exception
{
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class);
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class);
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
EasyMock.expect(resp.getOutputStream()).andReturn(outputStream).once();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(null).once();
resp.setStatus(401);
EasyMock.expectLastCall().once();
resp.setContentType("application/json");
EasyMock.expectLastCall().once();
resp.setCharacterEncoding("UTF-8");
EasyMock.expectLastCall().once();
EasyMock.replay(req, resp, filterChain, outputStream);
MockHttpServletRequest req = new MockHttpServletRequest();
MockHttpServletResponse resp = new MockHttpServletResponse();
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, filterChain);
EasyMock.verify(req, resp, filterChain, outputStream);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(401, resp.getStatus());
Assert.assertEquals("application/json", resp.getContentType());
Assert.assertEquals("UTF-8", resp.getCharacterEncoding());
}
@Test
public void testMissingAuthorizationCheck() throws Exception
public void testMissingAuthorizationCheckAndNotCommitted() throws ServletException, IOException
{
EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class));
expectedException.expect(ISE.class);
expectedException.expectMessage("Request did not have an authorization check performed.");
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class);
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class);
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
MockHttpServletRequest req = new MockHttpServletRequest();
req.requestUri = "uri";
req.method = "GET";
req.remoteAddr = "1.2.3.4";
req.remoteHost = "aHost";
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once();
EasyMock.expect(resp.getStatus()).andReturn(200).once();
EasyMock.expect(req.getRequestURI()).andReturn("uri").once();
EasyMock.expect(req.getMethod()).andReturn("GET").once();
EasyMock.expect(req.getRemoteAddr()).andReturn("1.2.3.4").once();
EasyMock.expect(req.getRemoteHost()).andReturn("ahostname").once();
EasyMock.expect(resp.isCommitted()).andReturn(true).once();
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(200);
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(403, resp.getStatus());
}
@Test
public void testMissingAuthorizationCheckWithForbidden() throws Exception
{
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(403);
EasyMock.expectLastCall().once();
resp.setContentType("application/json");
EasyMock.expectLastCall().once();
resp.setCharacterEncoding("UTF-8");
EasyMock.expectLastCall().once();
EasyMock.replay(req, resp, filterChain, outputStream);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, filterChain);
EasyMock.verify(req, resp, filterChain, outputStream);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(403, resp.getStatus());
}
@Test
public void testMissingAuthorizationCheckWithError() throws Exception
public void testMissingAuthorizationCheckWith404Keeps404() throws Exception
{
EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class));
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class);
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class);
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once();
EasyMock.expect(resp.getStatus()).andReturn(404).once();
EasyMock.replay(req, resp, filterChain, outputStream);
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(404);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, filterChain);
EasyMock.verify(req, resp, filterChain, outputStream);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(404, resp.getStatus());
}
@Test
public void testMissingAuthorizationCheckWith307Keeps307() throws Exception
{
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(307);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(307, resp.getStatus());
}
}

View File

@ -54,9 +54,11 @@ import java.util.function.Supplier;
*/
public class MockHttpServletRequest implements HttpServletRequest
{
public String requestUri = null;
public String method = null;
public String contentType = null;
public String remoteAddr = null;
public String remoteHost = null;
public LinkedHashMap<String, String> headers = new LinkedHashMap<>();
public LinkedHashMap<String, Object> attributes = new LinkedHashMap<>();
@ -110,7 +112,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override
public String getMethod()
{
return method;
return unsupportedIfNull(method);
}
@Override
@ -164,7 +166,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override
public String getRequestURI()
{
throw new UnsupportedOperationException();
return unsupportedIfNull(requestUri);
}
@Override
@ -296,7 +298,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override
public String getContentType()
{
return contentType;
return unsupportedIfNull(contentType);
}
@Override
@ -362,13 +364,13 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override
public String getRemoteAddr()
{
return remoteAddr;
return unsupportedIfNull(remoteAddr);
}
@Override
public String getRemoteHost()
{
throw new UnsupportedOperationException();
return unsupportedIfNull(remoteHost);
}
@Override
@ -486,6 +488,9 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override
public AsyncContext getAsyncContext()
{
if (currAsyncContext == null) {
throw new IllegalStateException("Must be put into Async mode before async context can be gottendid");
}
return currAsyncContext;
}
@ -514,4 +519,13 @@ public class MockHttpServletRequest implements HttpServletRequest
return retVal;
}
private <T> T unsupportedIfNull(T obj)
{
if (obj == null) {
throw new UnsupportedOperationException();
} else {
return obj;
}
}
}

View File

@ -22,12 +22,12 @@ package org.apache.druid.server.mocks;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotNull;
import java.io.ByteArrayOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
@ -63,6 +63,21 @@ public class MockHttpServletResponse implements HttpServletResponse
private int statusCode;
private String contentType;
private String characterEncoding;
@Override
public void reset()
{
if (isCommitted()) {
throw new IllegalStateException("Cannot reset a committed ServletResponse");
}
headers.clear();
statusCode = 0;
contentType = null;
characterEncoding = null;
}
@Override
public void addCookie(Cookie cookie)
@ -198,7 +213,7 @@ public class MockHttpServletResponse implements HttpServletResponse
@Override
public String getCharacterEncoding()
{
throw new UnsupportedOperationException();
return characterEncoding;
}
@Override
@ -231,13 +246,13 @@ public class MockHttpServletResponse implements HttpServletResponse
}
@Override
public void write(@Nonnull byte[] b)
public void write(@NotNull byte[] b)
{
baos.write(b, 0, b.length);
}
@Override
public void write(@Nonnull byte[] b, int off, int len)
public void write(@NotNull byte[] b, int off, int len)
{
baos.write(b, off, len);
}
@ -253,7 +268,7 @@ public class MockHttpServletResponse implements HttpServletResponse
@Override
public void setCharacterEncoding(String charset)
{
throw new UnsupportedOperationException();
characterEncoding = charset;
}
@Override
@ -298,18 +313,19 @@ public class MockHttpServletResponse implements HttpServletResponse
throw new UnsupportedOperationException();
}
public void forceCommitted()
{
if (!isCommitted()) {
baos.write(1234);
}
}
@Override
public boolean isCommitted()
{
return baos.size() > 0;
}
@Override
public void reset()
{
throw new UnsupportedOperationException();
}
@Override
public void setLocale(Locale loc)
{

View File

@ -61,22 +61,27 @@ public class SqlPlanningException extends BadQueryException
public SqlPlanningException(SqlParseException e)
{
this(PlanningError.SQL_PARSE_ERROR, e.getMessage());
this(e, PlanningError.SQL_PARSE_ERROR, e.getMessage());
}
public SqlPlanningException(ValidationException e)
{
this(PlanningError.VALIDATION_ERROR, e.getMessage());
this(e, PlanningError.VALIDATION_ERROR, e.getMessage());
}
public SqlPlanningException(CalciteContextException e)
{
this(PlanningError.VALIDATION_ERROR, e.getMessage());
this(e, PlanningError.VALIDATION_ERROR, e.getMessage());
}
public SqlPlanningException(PlanningError planningError, String errorMessage)
{
this(planningError.errorCode, errorMessage, planningError.errorClass);
this(null, planningError, errorMessage);
}
public SqlPlanningException(Throwable cause, PlanningError planningError, String errorMessage)
{
this(cause, planningError.errorCode, errorMessage, planningError.errorClass);
}
@JsonCreator
@ -86,6 +91,17 @@ public class SqlPlanningException extends BadQueryException
@JsonProperty("errorClass") String errorClass
)
{
super(errorCode, errorMessage, errorClass);
this(null, errorCode, errorMessage, errorClass);
}
private SqlPlanningException(
Throwable cause,
String errorCode,
String errorMessage,
String errorClass
)
{
super(cause, errorCode, errorMessage, errorClass, null);
}
}

View File

@ -48,9 +48,7 @@ import org.apache.druid.sql.SqlRowTransformer;
import org.apache.druid.sql.SqlStatementFactory;
import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.POST;
@ -63,7 +61,9 @@ import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import java.io.IOException;
import java.io.OutputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@ -121,22 +121,8 @@ public class SqlResource
try {
Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId));
// We use an async context not because we are actually going to run this async, but because we want to delay
// the decision of what the response code should be until we have gotten the first few data points to return.
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
final AsyncContext asyncContext = req.startAsync();
try {
QueryResultPusher pusher = new SqlResourceQueryResultPusher(asyncContext, sqlQueryId, stmt, sqlQuery);
pusher.push();
return null;
}
finally {
asyncContext.complete();
}
QueryResultPusher pusher = makePusher(req, stmt, sqlQuery);
return pusher.push();
}
finally {
Thread.currentThread().setName(currThreadName);
@ -213,27 +199,43 @@ public class SqlResource
}
}
private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStatement stmt, SqlQuery sqlQuery)
{
final String sqlQueryId = stmt.sqlQueryId();
Map<String, String> headers = new LinkedHashMap<>();
headers.put(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId);
if (sqlQuery.includeHeader()) {
headers.put(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE);
}
return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, headers);
}
private class SqlResourceQueryResultPusher extends QueryResultPusher
{
private final String sqlQueryId;
private final HttpStatement stmt;
private final SqlQuery sqlQuery;
public SqlResourceQueryResultPusher(
AsyncContext asyncContext,
HttpServletRequest req,
String sqlQueryId,
HttpStatement stmt,
SqlQuery sqlQuery
SqlQuery sqlQuery,
Map<String, String> headers
)
{
super(
(HttpServletResponse) asyncContext.getResponse(),
req,
SqlResource.this.jsonMapper,
SqlResource.this.responseContextConfig,
SqlResource.this.selfNode,
SqlResource.QUERY_METRIC_COUNTER,
sqlQueryId,
MediaType.APPLICATION_JSON_TYPE
MediaType.APPLICATION_JSON_TYPE,
headers
);
this.sqlQueryId = sqlQueryId;
this.stmt = stmt;
@ -245,19 +247,17 @@ public class SqlResource
{
return new ResultsWriter()
{
private QueryResponse<Object[]> queryResponse;
private ResultSet thePlan;
@Override
@Nullable
@SuppressWarnings({"unchecked", "rawtypes"})
public QueryResponse<Object> start(HttpServletResponse response)
public Response.ResponseBuilder start()
{
response.setHeader(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId);
final QueryResponse<Object[]> retVal;
try {
thePlan = stmt.plan();
retVal = thePlan.run();
queryResponse = thePlan.run();
return null;
}
catch (RelOptPlanner.CannotPlanException e) {
throw new SqlPlanningException(
@ -276,12 +276,13 @@ public class SqlResource
// doesn't implement org.apache.druid.common.exception.SanitizableException.
throw new QueryInterruptedException(e);
}
}
if (sqlQuery.includeHeader()) {
response.setHeader(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE);
}
return (QueryResponse) retVal;
@Override
@SuppressWarnings({"unchecked", "rawtypes"})
public QueryResponse<Object> getQueryResponse()
{
return (QueryResponse) queryResponse;
}
@Override
@ -343,6 +344,11 @@ public class SqlResource
@Override
public void recordFailure(Exception e)
{
if (sqlQuery.queryContext().isDebug()) {
log.warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId);
} else {
log.noStackTrace().warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId);
}
stmt.reporter().failed(e);
}

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.http;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
@ -39,6 +38,7 @@ import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.LazySequence;
@ -61,6 +61,7 @@ import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.groupby.GroupByQueryConfig;
import org.apache.druid.server.DruidNode;
import org.apache.druid.server.QueryResource;
import org.apache.druid.server.QueryResponse;
import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.QueryStackTests;
@ -109,9 +110,14 @@ import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.AbstractList;
import java.util.ArrayList;
@ -128,6 +134,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
public class SqlResourceTest extends CalciteTestBase
@ -175,7 +182,7 @@ public class SqlResourceTest extends CalciteTestBase
private final SettableSupplier<ResponseContext> responseContextSupplier = new SettableSupplier<>();
private Consumer<DirectStatement> onExecute = NULL_ACTION;
private boolean sleep;
private Supplier<Void> schedulerBaggage = () -> null;
@Before
public void setUp() throws Exception
@ -195,15 +202,8 @@ public class SqlResourceTest extends CalciteTestBase
{
return super.run(
query,
new LazySequence<T>(() -> {
if (sleep) {
try {
// pretend to be a query that is waiting on results
Thread.sleep(500);
}
catch (InterruptedException ignored) {
}
}
new LazySequence<>(() -> {
schedulerBaggage.get();
return resultSequence;
})
);
@ -329,7 +329,7 @@ public class SqlResourceTest extends CalciteTestBase
public void testUnauthorized()
{
try {
postForResponse(
postForAsyncResponse(
createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"),
request()
);
@ -380,18 +380,17 @@ public class SqlResourceTest extends CalciteTestBase
mockRespContext.put(ResponseContext.Keys.instance().keyOf("uncoveredIntervalsOverflowed"), "true");
responseContextSupplier.set(mockRespContext);
final MockHttpServletResponse response = postForResponse(sqlQuery, makeRegularUserReq());
final MockHttpServletResponse response = postForAsyncResponse(sqlQuery, makeRegularUserReq());
Map responseContext = JSON_MAPPER.readValue(
Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")),
Map.class
);
Assert.assertEquals(
ImmutableMap.of(
"uncoveredIntervals", "2030-01-01/78149827981274-01-01",
"uncoveredIntervalsOverflowed", "true"
),
responseContext
JSON_MAPPER.readValue(
Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")),
Map.class
)
);
Object results = JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class);
@ -714,6 +713,7 @@ public class SqlResourceTest extends CalciteTestBase
}
@Test
@SuppressWarnings("rawtypes")
public void testArrayResultFormatWithHeader() throws Exception
{
final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2";
@ -725,7 +725,7 @@ public class SqlResourceTest extends CalciteTestBase
Arrays.asList("2000-01-02T00:00:00.000Z", "10.1", nullStr, "[\"b\",\"c\"]", 1, 2.0, 2.0, hllStr, nullStr)
};
MockHttpServletResponse response = postForResponse(
MockHttpServletResponse response = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null),
req.mimic()
);
@ -745,7 +745,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class)
);
MockHttpServletResponse responseNoSqlTypesHeader = postForResponse(
MockHttpServletResponse responseNoSqlTypesHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, false, null, null),
req.mimic()
);
@ -764,7 +764,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoSqlTypesHeader.baos.toByteArray(), Object.class)
);
MockHttpServletResponse responseNoTypesHeader = postForResponse(
MockHttpServletResponse responseNoTypesHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, false, true, null, null),
req.mimic()
);
@ -783,7 +783,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoTypesHeader.baos.toByteArray(), Object.class)
);
MockHttpServletResponse responseNoTypes = postForResponse(
MockHttpServletResponse responseNoTypes = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, false, false, null, null),
req.mimic()
);
@ -801,7 +801,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoTypes.baos.toByteArray(), Object.class)
);
MockHttpServletResponse responseNoHeader = postForResponse(
MockHttpServletResponse responseNoHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, false, false, false, null, null),
req.mimic()
);
@ -821,7 +821,7 @@ public class SqlResourceTest extends CalciteTestBase
// Test a query that returns null header for some of the columns
final String query = "SELECT (1, 2) FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1";
MockHttpServletResponse response = postForResponse(
MockHttpServletResponse response = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null),
req
);
@ -971,12 +971,10 @@ public class SqlResourceTest extends CalciteTestBase
{
final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2";
final String nullStr = NullHandling.replaceWithDefault() ? "" : null;
final Function<Map<String, Object>, Map<String, Object>> transformer = m -> {
return Maps.transformEntries(
m,
(k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v
);
};
final Function<Map<String, Object>, Map<String, Object>> transformer = m -> Maps.transformEntries(
m,
(k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v
);
Assert.assertEquals(
ImmutableList.of(
@ -1336,9 +1334,7 @@ public class SqlResourceTest extends CalciteTestBase
@Test
public void testCannotParse() throws Exception
{
final QueryException exception = doPost(
createSimpleQueryWithId("id", "FROM druid.foo")
).lhs;
QueryException exception = postSyncForException("FROM druid.foo", Status.BAD_REQUEST.getStatusCode());
Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.SQL_PARSE_ERROR.getErrorCode(), exception.getErrorCode());
@ -1351,9 +1347,7 @@ public class SqlResourceTest extends CalciteTestBase
@Test
public void testCannotValidate() throws Exception
{
final QueryException exception = doPost(
createSimpleQueryWithId("id", "SELECT dim4 FROM druid.foo")
).lhs;
QueryException exception = postSyncForException("SELECT dim4 FROM druid.foo", Status.BAD_REQUEST.getStatusCode());
Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode());
@ -1367,10 +1361,10 @@ public class SqlResourceTest extends CalciteTestBase
public void testCannotConvert() throws Exception
{
// SELECT + ORDER unsupported
final QueryException exception = doPost(
createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1")
).lhs;
final SqlQuery unsupportedQuery = createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1");
QueryException exception = postSyncForException(unsupportedQuery, Status.BAD_REQUEST.getStatusCode());
Assert.assertTrue((Boolean) req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED));
Assert.assertNotNull(exception);
Assert.assertEquals("SQL query is unsupported", exception.getErrorCode());
Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorClass(), exception.getErrorClass());
@ -1392,9 +1386,10 @@ public class SqlResourceTest extends CalciteTestBase
public void testCannotConvert_UnsupportedSQLQueryException() throws Exception
{
// max(string) unsupported
final QueryException exception = doPost(
createSimpleQueryWithId("id", "SELECT max(dim1) FROM druid.foo")
).lhs;
QueryException exception = postSyncForException(
"SELECT max(dim1) FROM druid.foo",
Status.BAD_REQUEST.getStatusCode()
);
Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorCode(), exception.getErrorCode());
@ -1442,7 +1437,7 @@ public class SqlResourceTest extends CalciteTestBase
{
String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage);
final QueryException exception = doPost(
QueryException exception = postSyncForException(
new SqlQuery(
"SELECT ANSWER TO LIFE",
ResultFormat.OBJECT,
@ -1451,8 +1446,9 @@ public class SqlResourceTest extends CalciteTestBase
false,
ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"),
null
)
).lhs;
),
501
);
Assert.assertNotNull(exception);
Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, exception.getErrorCode());
@ -1466,7 +1462,7 @@ public class SqlResourceTest extends CalciteTestBase
String queryId = "id123";
String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage);
final MockHttpServletResponse response = postForResponse(
final Response response = postForSyncResponse(
new SqlQuery(
"SELECT ANSWER TO LIFE",
ResultFormat.OBJECT,
@ -1478,11 +1474,12 @@ public class SqlResourceTest extends CalciteTestBase
),
req
);
Assert.assertNotEquals(200, response.getStatus());
Assert.assertEquals(
queryId,
Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER))
);
// This is checked in the common method that returns the response, but checking it again just protects
// from changes there breaking the checks, so doesn't hurt.
assertStatusAndCommonHeaders(response, 501);
Assert.assertEquals(queryId, getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertEquals(queryId, getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
}
@Test
@ -1490,7 +1487,7 @@ public class SqlResourceTest extends CalciteTestBase
{
String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage);
final MockHttpServletResponse response = postForResponse(
final Response response = postForSyncResponse(
new SqlQuery(
"SELECT ANSWER TO LIFE",
ResultFormat.OBJECT,
@ -1502,10 +1499,10 @@ public class SqlResourceTest extends CalciteTestBase
),
req
);
Assert.assertNotEquals(200, response.getStatus());
Assert.assertFalse(
Strings.isNullOrEmpty(Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)))
);
// This is checked in the common method that returns the response, but checking it again just protects
// from changes there breaking the checks, so doesn't hurt.
assertStatusAndCommonHeaders(response, 501);
}
@Test
@ -1536,7 +1533,7 @@ public class SqlResourceTest extends CalciteTestBase
String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage);
final QueryException exception = doPost(
QueryException exception = postSyncForException(
new SqlQuery(
"SELECT ANSWER TO LIFE",
ResultFormat.OBJECT,
@ -1545,8 +1542,9 @@ public class SqlResourceTest extends CalciteTestBase
false,
ImmutableMap.of("sqlQueryId", "id"),
null
)
).lhs;
),
501
);
Assert.assertNotNull(exception);
Assert.assertNull(exception.getMessage());
@ -1587,7 +1585,7 @@ public class SqlResourceTest extends CalciteTestBase
onExecute = s -> {
throw new AssertionError(errorMessage);
};
final QueryException exception = doPost(
QueryException exception = postSyncForException(
new SqlQuery(
"SELECT ANSWER TO LIFE",
ResultFormat.OBJECT,
@ -1596,8 +1594,9 @@ public class SqlResourceTest extends CalciteTestBase
false,
ImmutableMap.of("sqlQueryId", "id"),
null
)
).lhs;
),
Status.INTERNAL_SERVER_ERROR.getStatusCode()
);
Assert.assertNotNull(exception);
Assert.assertNull(exception.getMessage());
@ -1610,15 +1609,28 @@ public class SqlResourceTest extends CalciteTestBase
@Test
public void testTooManyRequests() throws Exception
{
sleep = true;
final int numQueries = 3;
CountDownLatch queriesScheduledLatch = new CountDownLatch(numQueries - 1);
CountDownLatch runQueryLatch = new CountDownLatch(1);
schedulerBaggage = () -> {
queriesScheduledLatch.countDown();
try {
runQueryLatch.await();
}
catch (InterruptedException e) {
throw new RE(e);
}
return null;
};
final String sqlQueryId = "tooManyRequestsTest";
List<Future<Pair<QueryException, List<Map<String, Object>>>>> futures = new ArrayList<>(numQueries);
for (int i = 0; i < numQueries; i++) {
List<Future<Object>> futures = new ArrayList<>(numQueries);
for (int i = 0; i < numQueries - 1; i++) {
futures.add(executorService.submit(() -> {
try {
return doPost(
return postForAsyncResponse(
new SqlQuery(
"SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo",
null,
@ -1637,22 +1649,51 @@ public class SqlResourceTest extends CalciteTestBase
}));
}
queriesScheduledLatch.await();
schedulerBaggage = () -> null;
futures.add(executorService.submit(() -> {
try {
final Response retVal = postForSyncResponse(
new SqlQuery(
"SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo",
null,
false,
false,
false,
ImmutableMap.of("priority", -5, BaseQuery.SQL_QUERY_ID, sqlQueryId),
null
),
makeRegularUserReq()
);
runQueryLatch.countDown();
return retVal;
}
catch (Exception e) {
throw new RuntimeException(e);
}
}));
int success = 0;
int limited = 0;
for (int i = 0; i < numQueries; i++) {
Pair<QueryException, List<Map<String, Object>>> result = futures.get(i).get();
List<Map<String, Object>> rows = result.rhs;
if (rows != null) {
Assert.assertEquals(ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")), rows);
success++;
} else {
QueryException interruped = result.lhs;
if (i == 2) {
Response response = (Response) futures.get(i).get();
assertStatusAndCommonHeaders(response, 429);
QueryException interruped = deserializeResponse(response, QueryException.class);
Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, interruped.getErrorCode());
Assert.assertEquals(
QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 2),
interruped.getMessage()
);
limited++;
} else {
MockHttpServletResponse response = (MockHttpServletResponse) futures.get(i).get();
assertStatusAndCommonHeaders(response, 200);
Assert.assertEquals(
ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")),
deserializeResponse(response, Object.class)
);
success++;
}
}
Assert.assertEquals(2, success);
@ -1671,7 +1712,8 @@ public class SqlResourceTest extends CalciteTestBase
BaseQuery.SQL_QUERY_ID,
sqlQueryId
);
final QueryException timeoutException = doPost(
QueryException exception = postSyncForException(
new SqlQuery(
"SELECT CAST(__time AS DATE), dim1, dim2, dim3 FROM druid.foo GROUP by __time, dim1, dim2, dim3 ORDER BY dim2 DESC",
ResultFormat.OBJECT,
@ -1680,11 +1722,13 @@ public class SqlResourceTest extends CalciteTestBase
false,
queryContext,
null
)
).lhs;
Assert.assertNotNull(timeoutException);
Assert.assertEquals(timeoutException.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE);
Assert.assertEquals(timeoutException.getErrorClass(), QueryTimeoutException.class.getName());
),
504
);
Assert.assertNotNull(exception);
Assert.assertEquals(exception.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE);
Assert.assertEquals(exception.getErrorClass(), QueryTimeoutException.class.getName());
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
}
@ -1697,8 +1741,8 @@ public class SqlResourceTest extends CalciteTestBase
validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true));
CountDownLatch planLatch = new CountDownLatch(1);
planLatchSupplier.set(new NonnullPair<>(planLatch, false));
Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse(
Future<Response> future = executorService.submit(
() -> postForSyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq()
)
@ -1711,9 +1755,10 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
MockHttpServletResponse queryResponse = future.get();
Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus());
QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class);
Response queryResponse = future.get();
assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode());
QueryException exception = deserializeResponse(queryResponse, QueryException.class);
Assert.assertEquals("Query cancelled", exception.getErrorCode());
}
@ -1725,8 +1770,8 @@ public class SqlResourceTest extends CalciteTestBase
planLatchSupplier.set(new NonnullPair<>(planLatch, true));
CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse(
Future<Response> future = executorService.submit(
() -> postForSyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq()
)
@ -1738,9 +1783,10 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
MockHttpServletResponse queryResponse = future.get();
Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus());
QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class);
Response queryResponse = future.get();
assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode());
QueryException exception = deserializeResponse(queryResponse, QueryException.class);
Assert.assertEquals("Query cancelled", exception.getErrorCode());
}
@ -1753,7 +1799,7 @@ public class SqlResourceTest extends CalciteTestBase
CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse(
() -> postForAsyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq()
)
@ -1778,7 +1824,7 @@ public class SqlResourceTest extends CalciteTestBase
CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse(
() -> postForAsyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM forbiddenDatasource"),
makeSuperUserReq()
)
@ -1827,23 +1873,16 @@ public class SqlResourceTest extends CalciteTestBase
public void testQueryContextKeyNotAllowed() throws Exception
{
Map<String, Object> queryContext = ImmutableMap.of(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, "all");
final QueryException queryContextException = doPost(
new SqlQuery(
"SELECT 1337",
ResultFormat.OBJECT,
false,
false,
false,
queryContext,
null
)
).lhs;
Assert.assertNotNull(queryContextException);
Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), queryContextException.getErrorCode());
QueryException exception = postSyncForException(
new SqlQuery("SELECT 1337", ResultFormat.OBJECT, false, false, false, queryContext, null),
Status.BAD_REQUEST.getStatusCode()
);
Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode());
MatcherAssert.assertThat(
queryContextException.getMessage(),
CoreMatchers.containsString(
"Cannot execute query with context parameter [sqlInsertSegmentGranularity]")
exception.getMessage(),
CoreMatchers.containsString("Cannot execute query with context parameter [sqlInsertSegmentGranularity]")
);
checkSqlRequestLog(false);
}
@ -1922,7 +1961,7 @@ public class SqlResourceTest extends CalciteTestBase
private Pair<QueryException, String> doPostRaw(final SqlQuery query, final MockHttpServletRequest req)
throws Exception
{
MockHttpServletResponse response = postForResponse(query, req);
MockHttpServletResponse response = postForAsyncResponse(query, req);
if (response.getStatus() == 200) {
return Pair.of(null, new String(response.baos.toByteArray(), StandardCharsets.UTF_8));
@ -1932,14 +1971,122 @@ public class SqlResourceTest extends CalciteTestBase
}
@Nonnull
private MockHttpServletResponse postForResponse(SqlQuery query, MockHttpServletRequest req)
private MockHttpServletResponse postForAsyncResponse(SqlQuery query, MockHttpServletRequest req)
{
MockHttpServletResponse response = MockHttpServletResponse.forRequest(req);
final Object explicitQueryId = query.getContext().get("queryId");
final Object explicitSqlQueryId = query.getContext().get("sqlQueryId");
Assert.assertNull(resource.doPost(query, req));
final Object actualQueryId = response.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER);
final Object actualSqlQueryId = response.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER);
validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId);
return response;
}
private void assertStatusAndCommonHeaders(MockHttpServletResponse queryResponse, int statusCode)
{
Assert.assertEquals(statusCode, queryResponse.getStatus());
Assert.assertEquals("application/json", queryResponse.getContentType());
Assert.assertNotNull(queryResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertNotNull(queryResponse.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
}
private <T> T deserializeResponse(MockHttpServletResponse resp, Class<T> clazz) throws IOException
{
return JSON_MAPPER.readValue(resp.baos.toByteArray(), clazz);
}
private Response postForSyncResponse(SqlQuery query, MockHttpServletRequest req)
{
final Object explicitQueryId = query.getContext().get("queryId");
final Object explicitSqlQueryId = query.getContext().get("sqlQueryId");
final Response response = resource.doPost(query, req);
final Object actualQueryId = getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER);
final Object actualSqlQueryId = getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER);
validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId);
return response;
}
private QueryException postSyncForException(String s, int expectedStatus) throws IOException
{
return postSyncForException(createSimpleQueryWithId("id", s), expectedStatus);
}
private QueryException postSyncForException(SqlQuery query, int expectedStatus) throws IOException
{
final Response response = postForSyncResponse(query, req);
assertStatusAndCommonHeaders(response, expectedStatus);
return deserializeResponse(response, QueryException.class);
}
private <T> T deserializeResponse(Response resp, Class<T> clazz) throws IOException
{
return JSON_MAPPER.readValue(responseToByteArray(resp), clazz);
}
private byte[] responseToByteArray(Response resp) throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) resp.getEntity()).write(baos);
return baos.toByteArray();
}
private String getContentType(Response resp)
{
return getHeader(resp, HttpHeaders.CONTENT_TYPE).toString();
}
@Nullable
private Object getHeader(Response resp, String header)
{
final List<Object> objects = resp.getMetadata().get(header);
if (objects == null) {
return null;
}
return Iterables.getOnlyElement(objects);
}
private void assertStatusAndCommonHeaders(Response queryResponse, int statusCode)
{
Assert.assertEquals(statusCode, queryResponse.getStatus());
Assert.assertEquals("application/json", getContentType(queryResponse));
Assert.assertNotNull(getHeader(queryResponse, QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertNotNull(getHeader(queryResponse, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
}
private void validateQueryIds(
Object explicitQueryId,
Object explicitSqlQueryId,
Object actualQueryId,
Object actualSqlQueryId
)
{
if (explicitQueryId == null) {
if (null != explicitSqlQueryId) {
Assert.assertEquals(explicitSqlQueryId, actualQueryId);
Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId);
} else {
Assert.assertNotNull(actualQueryId);
Assert.assertNotNull(actualSqlQueryId);
}
} else {
if (explicitSqlQueryId == null) {
Assert.assertEquals(explicitQueryId, actualQueryId);
Assert.assertEquals(explicitQueryId, actualSqlQueryId);
} else {
Assert.assertEquals(explicitQueryId, actualQueryId);
Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId);
}
}
}
private MockHttpServletRequest makeSuperUserReq()
{
return makeExpectedReq(CalciteTests.SUPER_USER_AUTH_RESULT);
@ -1954,6 +2101,7 @@ public class SqlResourceTest extends CalciteTestBase
{
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
req.remoteAddr = "1.2.3.4";
return req;
}