Validate response headers and fix exception logging (#13609)

* Validate response headers and fix exception logging

A class of QueryException were throwing away their
causes making it really hard to determine what's
going wrong when something goes wrong in the SQL
planner specifically.  Fix that and adjust tests
 to do more validation of response headers as well.

We allow 404s and 307s to be returned even without 
authorization validated, but others get converted to 403
This commit is contained in:
imply-cheddar 2023-01-06 07:15:15 +09:00 committed by GitHub
parent 7a7874a952
commit a8ecc48ffe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 760 additions and 386 deletions

View File

@ -152,7 +152,12 @@ public class QueryException extends RuntimeException implements SanitizableExcep
protected QueryException(Throwable cause, String errorCode, String errorClass, String host) protected QueryException(Throwable cause, String errorCode, String errorClass, String host)
{ {
super(cause == null ? null : cause.getMessage(), cause); this(cause, errorCode, cause == null ? null : cause.getMessage(), errorClass, host);
}
protected QueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host)
{
super(errorMessage, cause);
this.errorCode = errorCode; this.errorCode = errorCode;
this.errorClass = errorClass; this.errorClass = errorClass;
this.host = host; this.host = host;

View File

@ -95,6 +95,24 @@ public class QueryExceptionTest
expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_QUERY_UNSUPPORTED_ERROR_CODE); expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_QUERY_UNSUPPORTED_ERROR_CODE);
} }
/**
* This test exists primarily to get branch coverage of the null check on the QueryException constructor.
* The validations done in this test are not actually intended to be set-in-stone or anything.
*/
@Test
public void testCanConstructWithoutThrowable()
{
QueryException exception = new QueryException(
(Throwable) null,
QueryException.UNKNOWN_EXCEPTION_ERROR_CODE,
"java.lang.Exception",
"test"
);
Assert.assertEquals(QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, exception.getErrorCode());
Assert.assertNull(exception.getMessage());
}
private void expectFailTypeForCode(FailType expected, String code) private void expectFailTypeForCode(FailType expected, String code)
{ {
QueryException exception = new QueryException(new Exception(), code, "java.lang.Exception", "test"); QueryException exception = new QueryException(new Exception(), code, "java.lang.Exception", "test");

View File

@ -420,14 +420,20 @@ public class CoordinatorPollingBasicAuthorizerCacheManager implements BasicAutho
new BytesFullResponseHandler() new BytesFullResponseHandler()
); );
final HttpResponseStatus status = responseHolder.getStatus();
// cachedSerializedGroupMappingMap is a new endpoint introduced in Druid 0.17.0. For backwards compatibility, if we // cachedSerializedGroupMappingMap is a new endpoint introduced in Druid 0.17.0. For backwards compatibility, if we
// get a 404 from the coordinator we stop retrying. This can happen during a rolling upgrade when a process // get a 404 from the coordinator we stop retrying. This can happen during a rolling upgrade when a process
// running 0.17.0+ tries to access this endpoint on an older coordinator. // running 0.17.0+ tries to access this endpoint on an older coordinator.
if (responseHolder.getStatus().equals(HttpResponseStatus.NOT_FOUND)) { if (HttpResponseStatus.NOT_FOUND.equals(status)) {
LOG.warn("cachedSerializedGroupMappingMap is not available from the coordinator, skipping fetch of group mappings for now."); LOG.warn("cachedSerializedGroupMappingMap is not available from the coordinator, skipping fetch of group mappings for now.");
return null; return null;
} }
if (!HttpResponseStatus.OK.equals(status)) {
LOG.warn("Got an unexpected response status[%s] when loading group mappings.", status);
}
byte[] groupRoleMapBytes = responseHolder.getContent(); byte[] groupRoleMapBytes = responseHolder.getContent();
GroupMappingAndRoleMap groupMappingAndRoleMap = objectMapper.readValue( GroupMappingAndRoleMap groupMappingAndRoleMap = objectMapper.readValue(

View File

@ -41,7 +41,6 @@ import org.jboss.netty.handler.codec.http.HttpResponseStatus;
import javax.inject.Inject; import javax.inject.Inject;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import java.io.IOException; import java.io.IOException;
import java.net.URL; import java.net.URL;
import java.util.Map; import java.util.Map;
@ -336,7 +335,8 @@ public class DruidClusterClient
*/ */
public void validate() public void validate()
{ {
log.info("Starting cluster validation"); RE exception = new RE("Just building for the stack trace");
log.info(exception, "Starting cluster validation");
for (ResolvedDruidService service : config.requireDruid().values()) { for (ResolvedDruidService service : config.requireDruid().values()) {
for (ResolvedInstance instance : service.requireInstances()) { for (ResolvedInstance instance : service.requireInstances()) {
validateInstance(service, instance); validateInstance(service, instance);
@ -348,28 +348,46 @@ public class DruidClusterClient
/** /**
* Validate an instance by waiting for it to report that it is healthy. * Validate an instance by waiting for it to report that it is healthy.
*/ */
@SuppressWarnings("BusyWait")
private void validateInstance(ResolvedDruidService service, ResolvedInstance instance) private void validateInstance(ResolvedDruidService service, ResolvedInstance instance)
{ {
int timeoutMs = config.readyTimeoutSec() * 1000; int timeoutMs = config.readyTimeoutSec() * 1000;
int pollMs = config.readyPollMs(); int pollMs = config.readyPollMs();
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
long updateTime = startTime + 5000; long updateTime = startTime + 5000;
while (System.currentTimeMillis() - startTime < timeoutMs) { while (true) {
if (isHealthy(service, instance)) { if (isHealthy(service, instance)) {
log.info( log.info(
"Service %s, host %s is ready", "Service[%s], host[%s], tag[%s] is ready",
service.service(), service.service(),
instance.clientHost()); instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag()
);
return; return;
} }
long currentTime = System.currentTimeMillis(); long currentTime = System.currentTimeMillis();
if (currentTime > updateTime) { if (currentTime > updateTime) {
log.info( log.info(
"Service %s, host %s not ready, retrying", "Service[%s], host[%s], tag[%s] not ready, retrying",
service.service(), service.service(),
instance.clientHost()); instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag()
);
updateTime = currentTime + 5000; updateTime = currentTime + 5000;
} }
final long elapsedTime = System.currentTimeMillis() - startTime;
if (elapsedTime > timeoutMs) {
final RE exception = new RE(
"Service[%s], host[%s], tag[%s] not ready after %,d ms.",
service.service(),
instance.clientHost(),
instance.tag() == null ? "<default>" : instance.tag(),
elapsedTime
);
// We log the exception here so that the logs include which thread is having the problem
log.error(exception.getMessage());
throw exception;
}
try { try {
Thread.sleep(pollMs); Thread.sleep(pollMs);
} }
@ -377,34 +395,30 @@ public class DruidClusterClient
throw new RuntimeException("Interrupted during cluster validation"); throw new RuntimeException("Interrupted during cluster validation");
} }
} }
throw new RE(
StringUtils.format("Service %s, instance %s not ready after %d ms.",
service.service(),
instance.tag() == null ? "<default>" : instance.tag(),
timeoutMs));
} }
/** /**
* Wait for an instance to become ready given the URL and a description of * Wait for an instance to become ready given the URL and a description of
* the service. * the service.
*/ */
@SuppressWarnings("BusyWait")
public void waitForNodeReady(String label, String url) public void waitForNodeReady(String label, String url)
{ {
int timeoutMs = config.readyTimeoutSec() * 1000; int timeoutMs = config.readyTimeoutSec() * 1000;
int pollMs = config.readyPollMs(); int pollMs = config.readyPollMs();
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
while (System.currentTimeMillis() - startTime < timeoutMs) { while (true) {
if (isHealthy(url)) { if (isHealthy(url)) {
log.info( log.info("Service[%s], url[%s] is ready", label, url);
"Service %s, url %s is ready",
label,
url);
return; return;
} }
log.info( final long elapsedTime = System.currentTimeMillis() - startTime;
"Service %s, url %s not ready, retrying", if (elapsedTime > timeoutMs) {
label, final RE re = new RE("Service[%s], url[%s] not ready after %,d ms.", label, url, elapsedTime);
url); log.error(re.getMessage());
throw re;
}
log.info("Service[%s], url[%s] not ready, retrying", label, url);
try { try {
Thread.sleep(pollMs); Thread.sleep(pollMs);
} }
@ -412,11 +426,6 @@ public class DruidClusterClient
throw new RuntimeException("Interrupted while waiting for note to be ready"); throw new RuntimeException("Interrupted while waiting for note to be ready");
} }
} }
throw new RE(
StringUtils.format("Service %s, url %s not ready after %d ms.",
label,
url,
timeoutMs));
} }
public String nodeUrl(DruidNode node) public String nodeUrl(DruidNode node)

View File

@ -29,7 +29,7 @@ public class BadJsonQueryException extends BadQueryException
public BadJsonQueryException(JsonParseException e) public BadJsonQueryException(JsonParseException e)
{ {
this(JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS); this(e, JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS);
} }
@JsonCreator @JsonCreator
@ -39,6 +39,16 @@ public class BadJsonQueryException extends BadQueryException
@JsonProperty("errorClass") String errorClass @JsonProperty("errorClass") String errorClass
) )
{ {
super(errorCode, errorMessage, errorClass); this(null, errorCode, errorMessage, errorClass);
}
private BadJsonQueryException(
Throwable cause,
String errorCode,
String errorMessage,
String errorClass
)
{
super(cause, errorCode, errorMessage, errorClass, null);
} }
} }

View File

@ -30,11 +30,16 @@ public abstract class BadQueryException extends QueryException
protected BadQueryException(String errorCode, String errorMessage, String errorClass) protected BadQueryException(String errorCode, String errorMessage, String errorClass)
{ {
super(errorCode, errorMessage, errorClass, null); this(errorCode, errorMessage, errorClass, null);
} }
protected BadQueryException(String errorCode, String errorMessage, String errorClass, String host) protected BadQueryException(String errorCode, String errorMessage, String errorClass, String host)
{ {
super(errorCode, errorMessage, errorClass, host); this(null, errorCode, errorMessage, errorClass, host);
}
protected BadQueryException(Throwable cause, String errorCode, String errorMessage, String errorClass, String host)
{
super(cause, errorCode, errorMessage, errorClass, host);
} }
} }

View File

@ -30,6 +30,7 @@ import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables; import com.google.common.collect.Iterables;
import com.google.inject.Inject; import com.google.inject.Inject;
import org.apache.druid.client.DirectDruidClient; import org.apache.druid.client.DirectDruidClient;
@ -216,21 +217,8 @@ public class QueryResource implements QueryCountStatsProvider
throw new ForbiddenException(authResult.toString()); throw new ForbiddenException(authResult.toString());
} }
// We use an async context not because we are actually going to run this async, but because we want to delay final QueryResourceQueryResultPusher pusher = new QueryResourceQueryResultPusher(req, queryLifecycle, io);
// the decision of what the response code should be until we have gotten the first few data points to return. return pusher.push();
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
final AsyncContext asyncContext = req.startAsync();
try {
new QueryResourceQueryResultPusher(req, queryLifecycle, io, (HttpServletResponse) asyncContext.getResponse())
.push();
}
finally {
asyncContext.complete();
}
} }
catch (Exception e) { catch (Exception e) {
if (e instanceof ForbiddenException && !req.isAsyncStarted()) { if (e instanceof ForbiddenException && !req.isAsyncStarted()) {
@ -258,6 +246,7 @@ public class QueryResource implements QueryCountStatsProvider
out.write(jsonMapper.writeValueAsBytes(responseException)); out.write(jsonMapper.writeValueAsBytes(responseException));
} }
} }
return null;
} }
finally { finally {
asyncContext.complete(); asyncContext.complete();
@ -266,7 +255,6 @@ public class QueryResource implements QueryCountStatsProvider
finally { finally {
Thread.currentThread().setName(currThreadName); Thread.currentThread().setName(currThreadName);
} }
return null;
} }
public interface QueryMetricCounter public interface QueryMetricCounter
@ -538,18 +526,18 @@ public class QueryResource implements QueryCountStatsProvider
public QueryResourceQueryResultPusher( public QueryResourceQueryResultPusher(
HttpServletRequest req, HttpServletRequest req,
QueryLifecycle queryLifecycle, QueryLifecycle queryLifecycle,
ResourceIOReaderWriter io, ResourceIOReaderWriter io
HttpServletResponse response
) )
{ {
super( super(
response, req,
QueryResource.this.jsonMapper, QueryResource.this.jsonMapper,
QueryResource.this.responseContextConfig, QueryResource.this.responseContextConfig,
QueryResource.this.selfNode, QueryResource.this.selfNode,
QueryResource.this.counter, QueryResource.this.counter,
queryLifecycle.getQueryId(), queryLifecycle.getQueryId(),
MediaType.valueOf(io.getResponseWriter().getResponseType()) MediaType.valueOf(io.getResponseWriter().getResponseType()),
ImmutableMap.of()
); );
this.req = req; this.req = req;
this.queryLifecycle = queryLifecycle; this.queryLifecycle = queryLifecycle;
@ -561,20 +549,27 @@ public class QueryResource implements QueryCountStatsProvider
{ {
return new ResultsWriter() return new ResultsWriter()
{ {
private QueryResponse<Object> queryResponse;
@Override @Override
public QueryResponse<Object> start(HttpServletResponse response) public Response.ResponseBuilder start()
{ {
final QueryResponse<Object> queryResponse = queryLifecycle.execute(); queryResponse = queryLifecycle.execute();
final ResponseContext responseContext = queryResponse.getResponseContext(); final ResponseContext responseContext = queryResponse.getResponseContext();
final String prevEtag = getPreviousEtag(req); final String prevEtag = getPreviousEtag(req);
if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) { if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) {
queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1); queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1);
counter.incrementSuccess(); counter.incrementSuccess();
response.setStatus(HttpServletResponse.SC_NOT_MODIFIED); return Response.status(Status.NOT_MODIFIED);
return null;
} }
return null;
}
@Override
public QueryResponse<Object> getQueryResponse()
{
return queryResponse; return queryResponse;
} }

View File

@ -36,45 +36,54 @@ import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.server.security.ForbiddenException;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.ServletOutputStream; import javax.servlet.ServletOutputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.StreamingOutput;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.Map;
public abstract class QueryResultPusher public abstract class QueryResultPusher
{ {
private static final Logger log = new Logger(QueryResultPusher.class); private static final Logger log = new Logger(QueryResultPusher.class);
private final HttpServletResponse response; private final HttpServletRequest request;
private final String queryId; private final String queryId;
private final ObjectMapper jsonMapper; private final ObjectMapper jsonMapper;
private final ResponseContextConfig responseContextConfig; private final ResponseContextConfig responseContextConfig;
private final DruidNode selfNode; private final DruidNode selfNode;
private final QueryResource.QueryMetricCounter counter; private final QueryResource.QueryMetricCounter counter;
private final MediaType contentType; private final MediaType contentType;
private final Map<String, String> extraHeaders;
private StreamingHttpResponseAccumulator accumulator = null; private StreamingHttpResponseAccumulator accumulator = null;
private AsyncContext asyncContext = null;
private HttpServletResponse response = null;
public QueryResultPusher( public QueryResultPusher(
HttpServletResponse response, HttpServletRequest request,
ObjectMapper jsonMapper, ObjectMapper jsonMapper,
ResponseContextConfig responseContextConfig, ResponseContextConfig responseContextConfig,
DruidNode selfNode, DruidNode selfNode,
QueryResource.QueryMetricCounter counter, QueryResource.QueryMetricCounter counter,
String queryId, String queryId,
MediaType contentType MediaType contentType,
Map<String, String> extraHeaders
) )
{ {
this.response = response; this.request = request;
this.queryId = queryId; this.queryId = queryId;
this.jsonMapper = jsonMapper; this.jsonMapper = jsonMapper;
this.responseContextConfig = responseContextConfig; this.responseContextConfig = responseContextConfig;
this.selfNode = selfNode; this.selfNode = selfNode;
this.counter = counter; this.counter = counter;
this.contentType = contentType; this.contentType = contentType;
this.extraHeaders = extraHeaders;
} }
/** /**
@ -92,23 +101,45 @@ public abstract class QueryResultPusher
public abstract void writeException(Exception e, OutputStream out) throws IOException; public abstract void writeException(Exception e, OutputStream out) throws IOException;
public void push() /**
* Pushes results out. Can sometimes return a JAXRS Response object instead of actually pushing to the output
* stream, primarily for error handling that occurs before switching the servlet to asynchronous mode.
*
* @return null if the response has already been handled and pushed out, or a non-null Response object if it expects
* the container to put the bytes on the wire.
*/
@Nullable
public Response push()
{ {
response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
ResultsWriter resultsWriter = null; ResultsWriter resultsWriter = null;
try { try {
resultsWriter = start(); resultsWriter = start();
final Response.ResponseBuilder startResponse = resultsWriter.start();
final QueryResponse<Object> queryResponse = resultsWriter.start(response); if (startResponse != null) {
if (queryResponse == null) { startResponse.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
// It's already been handled... for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
return; startResponse.header(entry.getKey(), entry.getValue());
}
return startResponse.build();
} }
final QueryResponse<Object> queryResponse = resultsWriter.getQueryResponse();
final Sequence<Object> results = queryResponse.getResults(); final Sequence<Object> results = queryResponse.getResults();
// We use an async context not because we are actually going to run this async, but because we want to delay
// the decision of what the response code should be until we have gotten the first few data points to return.
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
asyncContext = request.startAsync();
response = (HttpServletResponse) asyncContext.getResponse();
response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
response.setHeader(entry.getKey(), entry.getValue());
}
accumulator = new StreamingHttpResponseAccumulator(queryResponse.getResponseContext(), resultsWriter); accumulator = new StreamingHttpResponseAccumulator(queryResponse.getResponseContext(), resultsWriter);
results.accumulate(null, accumulator); results.accumulate(null, accumulator);
@ -119,8 +150,7 @@ public abstract class QueryResultPusher
resultsWriter.recordSuccess(accumulator.getNumBytesSent()); resultsWriter.recordSuccess(accumulator.getNumBytesSent());
} }
catch (QueryException e) { catch (QueryException e) {
handleQueryException(resultsWriter, e); return handleQueryException(resultsWriter, e);
return;
} }
catch (RuntimeException re) { catch (RuntimeException re) {
if (re instanceof ForbiddenException) { if (re instanceof ForbiddenException) {
@ -128,17 +158,15 @@ public abstract class QueryResultPusher
// has been committed because the response is committed after results are returned. And, if we started // has been committed because the response is committed after results are returned. And, if we started
// returning results before a ForbiddenException gets thrown, that means that we've already leaked stuff // returning results before a ForbiddenException gets thrown, that means that we've already leaked stuff
// that should not have been leaked. I.e. it means, we haven't validated the authorization early enough. // that should not have been leaked. I.e. it means, we haven't validated the authorization early enough.
if (response.isCommitted()) { if (response != null && response.isCommitted()) {
log.error(re, "Got a forbidden exception for query[%s] after the response was already committed.", queryId); log.error(re, "Got a forbidden exception for query[%s] after the response was already committed.", queryId);
} }
throw re; throw re;
} }
handleQueryException(resultsWriter, new QueryInterruptedException(re)); return handleQueryException(resultsWriter, new QueryInterruptedException(re));
return;
} }
catch (IOException ioEx) { catch (IOException ioEx) {
handleQueryException(resultsWriter, new QueryInterruptedException(ioEx)); return handleQueryException(resultsWriter, new QueryInterruptedException(ioEx));
return;
} }
finally { finally {
if (accumulator != null) { if (accumulator != null) {
@ -159,10 +187,15 @@ public abstract class QueryResultPusher
log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId); log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId);
} }
} }
if (asyncContext != null) {
asyncContext.complete();
}
} }
return null;
} }
private void handleQueryException(ResultsWriter resultsWriter, QueryException e) @Nullable
private Response handleQueryException(ResultsWriter resultsWriter, QueryException e)
{ {
if (accumulator != null && accumulator.isInitialized()) { if (accumulator != null && accumulator.isInitialized()) {
// We already started sending a response when we got the error message. In this case we just give up // We already started sending a response when we got the error message. In this case we just give up
@ -176,11 +209,7 @@ public abstract class QueryResultPusher
// This case is always a failure because the error happened mid-stream of sending results back. Therefore, // This case is always a failure because the error happened mid-stream of sending results back. Therefore,
// we do not believe that the response stream was actually useable // we do not believe that the response stream was actually useable
counter.incrementFailed(); counter.incrementFailed();
return; return null;
}
if (response.isCommitted()) {
QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?");
} }
final QueryException.FailType failType = e.getFailType(); final QueryException.FailType failType = e.getFailType();
@ -206,40 +235,71 @@ public abstract class QueryResultPusher
); );
counter.incrementFailed(); counter.incrementFailed();
} }
final int responseStatus = failType.getExpectedStatus();
response.setStatus(responseStatus);
response.setHeader("Content-Type", contentType.toString());
try (ServletOutputStream out = response.getOutputStream()) {
writeException(e, out);
}
catch (IOException ioException) {
log.warn(
ioException,
"Suppressing IOException thrown sending error response for query[%s]",
queryId
);
}
resultsWriter.recordFailure(e); resultsWriter.recordFailure(e);
final int responseStatus = failType.getExpectedStatus();
if (response == null) {
// No response object yet, so assume we haven't started the async context and is safe to return Response
final Response.ResponseBuilder bob = Response
.status(responseStatus)
.type(contentType)
.entity((StreamingOutput) output -> {
writeException(e, output);
output.close();
});
bob.header(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId);
for (Map.Entry<String, String> entry : extraHeaders.entrySet()) {
bob.header(entry.getKey(), entry.getValue());
}
return bob.build();
} else {
if (response.isCommitted()) {
QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?");
}
response.setStatus(responseStatus);
response.setHeader("Content-Type", contentType.toString());
try (ServletOutputStream out = response.getOutputStream()) {
writeException(e, out);
}
catch (IOException ioException) {
log.warn(
ioException,
"Suppressing IOException thrown sending error response for query[%s]",
queryId
);
}
return null;
}
} }
public interface ResultsWriter extends Closeable public interface ResultsWriter extends Closeable
{ {
/** /**
* Runs the query and returns a ResultsWriter from running the query. * Runs the query and prepares the QueryResponse to be returned
* <p> * <p>
* This also serves as a hook for any logic that runs on the metadata from a QueryResponse. If this method * This also serves as a hook for any logic that runs on the metadata from a QueryResponse. If this method
* returns {@code null} then the Pusher believes that the response was already handled and skips the rest * returns {@code null} then the Pusher can continue with normal logic. If this method chooses to return
* of its logic. As such, any implementation that returns null must make sure that the response has been set * a ResponseBuilder, then the Pusher will attach any extra metadata it has to the Response and return
* with a meaningful status, etc. * the response built from the Builder without attempting to process the results of the query.
* <p> * <p>
* Even if this method returns null, close() should still be called on this object. * In all cases, {@link #close()} should be called on this object.
* *
* @return QueryResponse or null if no more work to do. * @return QueryResponse or null if no more work to do.
*/ */
@Nullable @Nullable
QueryResponse<Object> start(HttpServletResponse response); Response.ResponseBuilder start();
/**
* Gets the results of running the query. {@link #start} must be called before this method is called.
*
* @return the results of running the query as preparted by the {@link #start()} method
*/
QueryResponse<Object> getQueryResponse();
Writer makeWriter(OutputStream out) throws IOException; Writer makeWriter(OutputStream out) throws IOException;
@ -301,9 +361,7 @@ public abstract class QueryResultPusher
* Initializes the response. This is done lazily so that we can put various metadata that we only get once * Initializes the response. This is done lazily so that we can put various metadata that we only get once
* we have some of the response stream into the result. * we have some of the response stream into the result.
* <p> * <p>
* This is called once for each result object, but should only actually happen once. * It is okay for this to be called multiple times.
*
* @return boolean if initialization occurred. False most of the team because initialization only happens once.
*/ */
public void initialize() public void initialize()
{ {
@ -332,7 +390,7 @@ public abstract class QueryResultPusher
); );
} }
catch (JsonProcessingException e) { catch (JsonProcessingException e) {
QueryResource.log.info(e, "Problem serializing to JSON!?"); log.info(e, "Problem serializing to JSON!?");
serializationResult = new ResponseContext.SerializationResult("Could not serialize", "Could not serialize"); serializationResult = new ResponseContext.SerializationResult("Could not serialize", "Could not serialize");
} }
@ -343,7 +401,7 @@ public abstract class QueryResultPusher
serializationResult.getFullResult() serializationResult.getFullResult()
); );
if (responseContextConfig.shouldFailOnTruncatedResponseContext()) { if (responseContextConfig.shouldFailOnTruncatedResponseContext()) {
QueryResource.log.error(logToPrint); log.error(logToPrint);
throw new QueryInterruptedException( throw new QueryInterruptedException(
new TruncatedResponseContextException( new TruncatedResponseContextException(
"Serialized response context exceeds the max size[%s]", "Serialized response context exceeds the max size[%s]",
@ -352,12 +410,12 @@ public abstract class QueryResultPusher
selfNode.getHostAndPortToUse() selfNode.getHostAndPortToUse()
); );
} else { } else {
QueryResource.log.warn(logToPrint); log.warn(logToPrint);
} }
} }
response.setHeader(QueryResource.HEADER_RESPONSE_CONTEXT, serializationResult.getResult()); response.setHeader(QueryResource.HEADER_RESPONSE_CONTEXT, serializationResult.getResult());
response.setHeader("Content-Type", contentType.toString()); response.setContentType(contentType.toString());
try { try {
out = new CountingOutputStream(response.getOutputStream()); out = new CountingOutputStream(response.getOutputStream());
@ -379,6 +437,7 @@ public abstract class QueryResultPusher
} }
@Override @Override
@Nullable
public Response accumulate(Response retVal, Object in) public Response accumulate(Response retVal, Object in)
{ {
if (!initialized) { if (!initialized) {

View File

@ -20,11 +20,12 @@
package org.apache.druid.server.security; package org.apache.druid.server.security;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.server.DruidNode; import org.apache.druid.server.DruidNode;
import org.apache.druid.server.QueryResource;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterChain; import javax.servlet.FilterChain;
@ -83,12 +84,15 @@ public class PreResponseAuthorizationCheckFilter implements Filter
filterChain.doFilter(servletRequest, servletResponse); filterChain.doFilter(servletRequest, servletResponse);
Boolean authInfoChecked = (Boolean) servletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED); Boolean authInfoChecked = (Boolean) servletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED);
if (authInfoChecked == null && statusIsSuccess(response.getStatus())) { if (authInfoChecked == null && statusShouldBeHidden(response.getStatus())) {
// Note: rather than throwing an exception here, it would be nice to blank out the original response // Note: rather than throwing an exception here, it would be nice to blank out the original response
// since the request didn't have any authorization checks performed. However, this breaks proxying // since the request didn't have any authorization checks performed. However, this breaks proxying
// (e.g. OverlordServletProxy), so this is not implemented for now. // (e.g. OverlordServletProxy), so this is not implemented for now.
handleAuthorizationCheckError( handleAuthorizationCheckError(
"Request did not have an authorization check performed.", StringUtils.format(
"Request did not have an authorization check performed, original response status[%s].",
response.getStatus()
),
request, request,
response response
); );
@ -136,7 +140,6 @@ public class PreResponseAuthorizationCheckFilter implements Filter
OutputStream out = response.getOutputStream(); OutputStream out = response.getOutputStream();
sendJsonError(response, HttpServletResponse.SC_UNAUTHORIZED, jsonMapper.writeValueAsString(unauthorizedError), out); sendJsonError(response, HttpServletResponse.SC_UNAUTHORIZED, jsonMapper.writeValueAsString(unauthorizedError), out);
out.close(); out.close();
return;
} }
private void handleAuthorizationCheckError( private void handleAuthorizationCheckError(
@ -145,19 +148,21 @@ public class PreResponseAuthorizationCheckFilter implements Filter
HttpServletResponse servletResponse HttpServletResponse servletResponse
) )
{ {
final String queryId = servletResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER);
// Send out an alert so there's a centralized collection point for seeing errors of this nature // Send out an alert so there's a centralized collection point for seeing errors of this nature
log.makeAlert(errorMsg) log.makeAlert(errorMsg)
.addData("uri", servletRequest.getRequestURI()) .addData("uri", servletRequest.getRequestURI())
.addData("method", servletRequest.getMethod()) .addData("method", servletRequest.getMethod())
.addData("remoteAddr", servletRequest.getRemoteAddr()) .addData("remoteAddr", servletRequest.getRemoteAddr())
.addData("remoteHost", servletRequest.getRemoteHost()) .addData("remoteHost", servletRequest.getRemoteHost())
.addData("queryId", queryId)
.emit(); .emit();
if (servletResponse.isCommitted()) { if (!servletResponse.isCommitted()) {
throw new ISE(errorMsg);
} else {
try { try {
servletResponse.sendError(HttpServletResponse.SC_FORBIDDEN); servletResponse.reset();
servletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
} }
catch (Exception e) { catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -165,9 +170,23 @@ public class PreResponseAuthorizationCheckFilter implements Filter
} }
} }
private static boolean statusIsSuccess(int status) private static boolean statusShouldBeHidden(int status)
{ {
return 200 <= status && status < 300; // We allow 404s (not found) to not be rewritten to forbidden because consistently returning 404s is a way to leak
// less information when something wasn't able to be done anyway. I.e. if we pretend that the thing didn't exist
// when the authorization fails, then there is no information about whether the thing existed. If we return
// a 403 when authorization fails and a 404 when authorization succeeds, but it doesn't exist, then we have
// leaked that it could maybe exist, if the authentication credentials were good.
//
// We also allow 307s (temporary redirect) to not be hidden as they are used to redirect to the leader.
switch (status) {
case HttpServletResponse.SC_FORBIDDEN:
case HttpServletResponse.SC_NOT_FOUND:
case HttpServletResponse.SC_TEMPORARY_REDIRECT:
return false;
default:
return true;
}
} }
public static void sendJsonError(HttpServletResponse resp, int error, String errorJson, OutputStream outputStream) public static void sendJsonError(HttpServletResponse resp, int error, String errorJson, OutputStream outputStream)

View File

@ -26,7 +26,6 @@ import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import com.google.inject.Injector; import com.google.inject.Injector;
@ -88,7 +87,9 @@ import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.core.MediaType; import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.Collection; import java.util.Collection;
@ -396,7 +397,7 @@ public class QueryResourceTest
final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY);
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
//since accept header is null, the response content type should be same as the value of 'Content-Type' header //since accept header is null, the response content type should be same as the value of 'Content-Type' header
Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType());
} }
@Test @Test
@ -409,7 +410,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
//since accept header is empty, the response content type should be same as the value of 'Content-Type' header //since accept header is empty, the response content type should be same as the value of 'Content-Type' header
Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); Assert.assertEquals(MediaType.APPLICATION_JSON, response.getContentType());
} }
@Test @Test
@ -424,10 +425,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should be Smile // Content-Type in response should be Smile
Assert.assertEquals( Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
} }
@Test @Test
@ -447,10 +445,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should be Smile // Content-Type in response should be Smile
Assert.assertEquals( Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
} }
@Test @Test
@ -469,10 +464,7 @@ public class QueryResourceTest
Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus());
// Content-Type in response should default to Content-Type from request // Content-Type in response should default to Content-Type from request
Assert.assertEquals( Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, response.getContentType());
SmileMediaTypes.APPLICATION_JACKSON_SMILE,
Iterables.getOnlyElement(response.headers.get("Content-Type"))
);
} }
@Test @Test
@ -643,13 +635,16 @@ public class QueryResourceTest
); );
expectPermissiveHappyPathAuth(); expectPermissiveHappyPathAuth();
final MockHttpServletResponse response = expectAsyncRequestFlow( final Response response = expectSynchronousRequestFlow(
testServletRequest, testServletRequest,
SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8), SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8),
timeoutQueryResource timeoutQueryResource
); );
Assert.assertEquals(QueryTimeoutException.STATUS_CODE, response.getStatus()); Assert.assertEquals(QueryTimeoutException.STATUS_CODE, response.getStatus());
QueryTimeoutException ex = jsonMapper.readValue(response.baos.toByteArray(), QueryTimeoutException.class);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
QueryTimeoutException ex = jsonMapper.readValue(baos.toByteArray(), QueryTimeoutException.class);
Assert.assertEquals("Query Timed Out!", ex.getMessage()); Assert.assertEquals("Query Timed Out!", ex.getMessage());
Assert.assertEquals(QueryException.QUERY_TIMEOUT_ERROR_CODE, ex.getErrorCode()); Assert.assertEquals(QueryException.QUERY_TIMEOUT_ERROR_CODE, ex.getErrorCode());
Assert.assertEquals(1, timeoutQueryResource.getTimedOutQueryCount()); Assert.assertEquals(1, timeoutQueryResource.getTimedOutQueryCount());
@ -892,25 +887,28 @@ public class QueryResourceTest
); );
createScheduledQueryResource(laningScheduler, Collections.emptyList(), ImmutableList.of(waitTwoScheduled)); createScheduledQueryResource(laningScheduler, Collections.emptyList(), ImmutableList.of(waitTwoScheduled));
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
); );
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
); );
waitTwoScheduled.await(); waitTwoScheduled.await();
assertResponseAndCountdownOrBlockForever( assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> { response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex; QueryCapacityExceededException ex;
try { try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
} }
catch (IOException e) { catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -938,20 +936,22 @@ public class QueryResourceTest
createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled)); createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled));
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY, SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
); );
waitOneScheduled.await(); waitOneScheduled.await();
assertResponseAndCountdownOrBlockForever( assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY, SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY,
waitAllFinished, waitAllFinished,
response -> { response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex; QueryCapacityExceededException ex;
try { try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
} }
catch (IOException e) { catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -965,7 +965,7 @@ public class QueryResourceTest
} }
); );
waitTwoStarted.await(); waitTwoStarted.await();
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
@ -990,20 +990,22 @@ public class QueryResourceTest
createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled)); createScheduledQueryResource(scheduler, ImmutableList.of(waitTwoStarted), ImmutableList.of(waitOneScheduled));
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
); );
waitOneScheduled.await(); waitOneScheduled.await();
assertResponseAndCountdownOrBlockForever( assertSynchronousResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY, SIMPLE_TIMESERIES_QUERY,
waitAllFinished, waitAllFinished,
response -> { response -> {
Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus());
QueryCapacityExceededException ex; QueryCapacityExceededException ex;
try { try {
ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) response.getEntity()).write(baos);
ex = jsonMapper.readValue(baos.toByteArray(), QueryCapacityExceededException.class);
} }
catch (IOException e) { catch (IOException e) {
throw new RuntimeException(e); throw new RuntimeException(e);
@ -1016,7 +1018,7 @@ public class QueryResourceTest
} }
); );
waitTwoStarted.await(); waitTwoStarted.await();
assertResponseAndCountdownOrBlockForever( assertAsyncResponseAndCountdownOrBlockForever(
SIMPLE_TIMESERIES_QUERY_SMALLISH_INTERVAL, SIMPLE_TIMESERIES_QUERY_SMALLISH_INTERVAL,
waitAllFinished, waitAllFinished,
response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()) response -> Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus())
@ -1085,7 +1087,7 @@ public class QueryResourceTest
); );
} }
private void assertResponseAndCountdownOrBlockForever( private void assertAsyncResponseAndCountdownOrBlockForever(
String query, String query,
CountDownLatch done, CountDownLatch done,
Consumer<MockHttpServletResponse> asserts Consumer<MockHttpServletResponse> asserts
@ -1146,4 +1148,36 @@ public class QueryResourceTest
)); ));
return response; return response;
} }
private void assertSynchronousResponseAndCountdownOrBlockForever(
String query,
CountDownLatch done,
Consumer<Response> asserts
)
{
Executors.newSingleThreadExecutor().submit(() -> {
try {
asserts.accept(
expectSynchronousRequestFlow(
testServletRequest.mimic(),
query.getBytes(StandardCharsets.UTF_8),
queryResource
)
);
}
catch (IOException e) {
throw new RuntimeException(e);
}
done.countDown();
});
}
private Response expectSynchronousRequestFlow(
MockHttpServletRequest req,
byte[] bytes,
QueryResource queryResource
) throws IOException
{
return queryResource.doPost(new ByteArrayInputStream(bytes), null, req);
}
} }

View File

@ -20,141 +20,155 @@
package org.apache.druid.server.http.security; package org.apache.druid.server.http.security;
import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.server.metrics.NoopServiceEmitter;
import org.apache.druid.server.mocks.MockHttpServletRequest;
import org.apache.druid.server.mocks.MockHttpServletResponse;
import org.apache.druid.server.security.AllowAllAuthenticator; import org.apache.druid.server.security.AllowAllAuthenticator;
import org.apache.druid.server.security.AuthConfig; import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.Authenticator; import org.apache.druid.server.security.Authenticator;
import org.apache.druid.server.security.PreResponseAuthorizationCheckFilter; import org.apache.druid.server.security.PreResponseAuthorizationCheckFilter;
import org.easymock.EasyMock; import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import javax.servlet.FilterChain; import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream; import java.io.IOException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
public class PreResponseAuthorizationCheckFilterTest public class PreResponseAuthorizationCheckFilterTest
{ {
private static List<Authenticator> authenticators = Collections.singletonList(new AllowAllAuthenticator()); private static final List<Authenticator> authenticators = Collections.singletonList(new AllowAllAuthenticator());
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Test @Test
public void testValidRequest() throws Exception public void testValidRequest() throws Exception
{ {
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); MockHttpServletRequest req = new MockHttpServletRequest();
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); MockHttpServletResponse resp = new MockHttpServletResponse();
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(true).once(); req.attributes.put(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true);
EasyMock.replay(req, resp, filterChain, outputStream);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators, authenticators,
new DefaultObjectMapper() new DefaultObjectMapper()
); );
filter.doFilter(req, resp, filterChain); filter.doFilter(req, resp, (request, response) -> {
EasyMock.verify(req, resp, filterChain, outputStream); });
} }
@Test @Test
public void testAuthenticationFailedRequest() throws Exception public void testAuthenticationFailedRequest() throws Exception
{ {
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); MockHttpServletRequest req = new MockHttpServletRequest();
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); MockHttpServletResponse resp = new MockHttpServletResponse();
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
EasyMock.expect(resp.getOutputStream()).andReturn(outputStream).once();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(null).once();
resp.setStatus(401);
EasyMock.expectLastCall().once();
resp.setContentType("application/json");
EasyMock.expectLastCall().once();
resp.setCharacterEncoding("UTF-8");
EasyMock.expectLastCall().once();
EasyMock.replay(req, resp, filterChain, outputStream);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators, authenticators,
new DefaultObjectMapper() new DefaultObjectMapper()
); );
filter.doFilter(req, resp, filterChain); filter.doFilter(req, resp, (request, response) -> {
EasyMock.verify(req, resp, filterChain, outputStream); });
Assert.assertEquals(401, resp.getStatus());
Assert.assertEquals("application/json", resp.getContentType());
Assert.assertEquals("UTF-8", resp.getCharacterEncoding());
} }
@Test @Test
public void testMissingAuthorizationCheck() throws Exception public void testMissingAuthorizationCheckAndNotCommitted() throws ServletException, IOException
{ {
EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class)); EmittingLogger.registerEmitter(new NoopServiceEmitter());
expectedException.expect(ISE.class);
expectedException.expectMessage("Request did not have an authorization check performed.");
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); MockHttpServletRequest req = new MockHttpServletRequest();
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); req.requestUri = "uri";
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class); req.method = "GET";
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class); req.remoteAddr = "1.2.3.4";
req.remoteHost = "aHost";
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); MockHttpServletResponse resp = new MockHttpServletResponse();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once(); resp.setStatus(200);
EasyMock.expect(resp.getStatus()).andReturn(200).once();
EasyMock.expect(req.getRequestURI()).andReturn("uri").once(); req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
EasyMock.expect(req.getMethod()).andReturn("GET").once();
EasyMock.expect(req.getRemoteAddr()).andReturn("1.2.3.4").once(); PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
EasyMock.expect(req.getRemoteHost()).andReturn("ahostname").once(); authenticators,
EasyMock.expect(resp.isCommitted()).andReturn(true).once(); new DefaultObjectMapper()
);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(403, resp.getStatus());
}
@Test
public void testMissingAuthorizationCheckWithForbidden() throws Exception
{
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(403); resp.setStatus(403);
EasyMock.expectLastCall().once();
resp.setContentType("application/json");
EasyMock.expectLastCall().once();
resp.setCharacterEncoding("UTF-8");
EasyMock.expectLastCall().once();
EasyMock.replay(req, resp, filterChain, outputStream);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators, authenticators,
new DefaultObjectMapper() new DefaultObjectMapper()
); );
filter.doFilter(req, resp, filterChain); filter.doFilter(req, resp, (request, response) -> {
EasyMock.verify(req, resp, filterChain, outputStream); });
Assert.assertEquals(403, resp.getStatus());
} }
@Test @Test
public void testMissingAuthorizationCheckWithError() throws Exception public void testMissingAuthorizationCheckWith404Keeps404() throws Exception
{ {
EmittingLogger.registerEmitter(EasyMock.createNiceMock(ServiceEmitter.class)); EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null); AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); MockHttpServletRequest req = new MockHttpServletRequest();
HttpServletResponse resp = EasyMock.createStrictMock(HttpServletResponse.class); req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
FilterChain filterChain = EasyMock.createNiceMock(FilterChain.class);
ServletOutputStream outputStream = EasyMock.createNiceMock(ServletOutputStream.class);
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)).andReturn(authenticationResult).once(); MockHttpServletResponse resp = new MockHttpServletResponse();
EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)).andReturn(null).once(); resp.setStatus(404);
EasyMock.expect(resp.getStatus()).andReturn(404).once();
EasyMock.replay(req, resp, filterChain, outputStream);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter( PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators, authenticators,
new DefaultObjectMapper() new DefaultObjectMapper()
); );
filter.doFilter(req, resp, filterChain); filter.doFilter(req, resp, (request, response) -> {
EasyMock.verify(req, resp, filterChain, outputStream); });
Assert.assertEquals(404, resp.getStatus());
}
@Test
public void testMissingAuthorizationCheckWith307Keeps307() throws Exception
{
EmittingLogger.registerEmitter(new NoopServiceEmitter());
AuthenticationResult authenticationResult = new AuthenticationResult("so-very-valid", "so-very-valid", null, null);
MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
MockHttpServletResponse resp = new MockHttpServletResponse();
resp.setStatus(307);
PreResponseAuthorizationCheckFilter filter = new PreResponseAuthorizationCheckFilter(
authenticators,
new DefaultObjectMapper()
);
filter.doFilter(req, resp, (request, response) -> {
});
Assert.assertEquals(307, resp.getStatus());
} }
} }

View File

@ -54,9 +54,11 @@ import java.util.function.Supplier;
*/ */
public class MockHttpServletRequest implements HttpServletRequest public class MockHttpServletRequest implements HttpServletRequest
{ {
public String requestUri = null;
public String method = null; public String method = null;
public String contentType = null; public String contentType = null;
public String remoteAddr = null; public String remoteAddr = null;
public String remoteHost = null;
public LinkedHashMap<String, String> headers = new LinkedHashMap<>(); public LinkedHashMap<String, String> headers = new LinkedHashMap<>();
public LinkedHashMap<String, Object> attributes = new LinkedHashMap<>(); public LinkedHashMap<String, Object> attributes = new LinkedHashMap<>();
@ -110,7 +112,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override @Override
public String getMethod() public String getMethod()
{ {
return method; return unsupportedIfNull(method);
} }
@Override @Override
@ -164,7 +166,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override @Override
public String getRequestURI() public String getRequestURI()
{ {
throw new UnsupportedOperationException(); return unsupportedIfNull(requestUri);
} }
@Override @Override
@ -296,7 +298,7 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override @Override
public String getContentType() public String getContentType()
{ {
return contentType; return unsupportedIfNull(contentType);
} }
@Override @Override
@ -362,13 +364,13 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override @Override
public String getRemoteAddr() public String getRemoteAddr()
{ {
return remoteAddr; return unsupportedIfNull(remoteAddr);
} }
@Override @Override
public String getRemoteHost() public String getRemoteHost()
{ {
throw new UnsupportedOperationException(); return unsupportedIfNull(remoteHost);
} }
@Override @Override
@ -486,6 +488,9 @@ public class MockHttpServletRequest implements HttpServletRequest
@Override @Override
public AsyncContext getAsyncContext() public AsyncContext getAsyncContext()
{ {
if (currAsyncContext == null) {
throw new IllegalStateException("Must be put into Async mode before async context can be gottendid");
}
return currAsyncContext; return currAsyncContext;
} }
@ -514,4 +519,13 @@ public class MockHttpServletRequest implements HttpServletRequest
return retVal; return retVal;
} }
private <T> T unsupportedIfNull(T obj)
{
if (obj == null) {
throw new UnsupportedOperationException();
} else {
return obj;
}
}
} }

View File

@ -22,12 +22,12 @@ package org.apache.druid.server.mocks;
import com.google.common.collect.Multimap; import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps; import com.google.common.collect.Multimaps;
import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.servlet.ServletOutputStream; import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener; import javax.servlet.WriteListener;
import javax.servlet.http.Cookie; import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import javax.validation.constraints.NotNull;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.PrintWriter; import java.io.PrintWriter;
import java.util.ArrayList; import java.util.ArrayList;
@ -63,6 +63,21 @@ public class MockHttpServletResponse implements HttpServletResponse
private int statusCode; private int statusCode;
private String contentType; private String contentType;
private String characterEncoding;
@Override
public void reset()
{
if (isCommitted()) {
throw new IllegalStateException("Cannot reset a committed ServletResponse");
}
headers.clear();
statusCode = 0;
contentType = null;
characterEncoding = null;
}
@Override @Override
public void addCookie(Cookie cookie) public void addCookie(Cookie cookie)
@ -198,7 +213,7 @@ public class MockHttpServletResponse implements HttpServletResponse
@Override @Override
public String getCharacterEncoding() public String getCharacterEncoding()
{ {
throw new UnsupportedOperationException(); return characterEncoding;
} }
@Override @Override
@ -231,13 +246,13 @@ public class MockHttpServletResponse implements HttpServletResponse
} }
@Override @Override
public void write(@Nonnull byte[] b) public void write(@NotNull byte[] b)
{ {
baos.write(b, 0, b.length); baos.write(b, 0, b.length);
} }
@Override @Override
public void write(@Nonnull byte[] b, int off, int len) public void write(@NotNull byte[] b, int off, int len)
{ {
baos.write(b, off, len); baos.write(b, off, len);
} }
@ -253,7 +268,7 @@ public class MockHttpServletResponse implements HttpServletResponse
@Override @Override
public void setCharacterEncoding(String charset) public void setCharacterEncoding(String charset)
{ {
throw new UnsupportedOperationException(); characterEncoding = charset;
} }
@Override @Override
@ -298,18 +313,19 @@ public class MockHttpServletResponse implements HttpServletResponse
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
public void forceCommitted()
{
if (!isCommitted()) {
baos.write(1234);
}
}
@Override @Override
public boolean isCommitted() public boolean isCommitted()
{ {
return baos.size() > 0; return baos.size() > 0;
} }
@Override
public void reset()
{
throw new UnsupportedOperationException();
}
@Override @Override
public void setLocale(Locale loc) public void setLocale(Locale loc)
{ {

View File

@ -61,22 +61,27 @@ public class SqlPlanningException extends BadQueryException
public SqlPlanningException(SqlParseException e) public SqlPlanningException(SqlParseException e)
{ {
this(PlanningError.SQL_PARSE_ERROR, e.getMessage()); this(e, PlanningError.SQL_PARSE_ERROR, e.getMessage());
} }
public SqlPlanningException(ValidationException e) public SqlPlanningException(ValidationException e)
{ {
this(PlanningError.VALIDATION_ERROR, e.getMessage()); this(e, PlanningError.VALIDATION_ERROR, e.getMessage());
} }
public SqlPlanningException(CalciteContextException e) public SqlPlanningException(CalciteContextException e)
{ {
this(PlanningError.VALIDATION_ERROR, e.getMessage()); this(e, PlanningError.VALIDATION_ERROR, e.getMessage());
} }
public SqlPlanningException(PlanningError planningError, String errorMessage) public SqlPlanningException(PlanningError planningError, String errorMessage)
{ {
this(planningError.errorCode, errorMessage, planningError.errorClass); this(null, planningError, errorMessage);
}
public SqlPlanningException(Throwable cause, PlanningError planningError, String errorMessage)
{
this(cause, planningError.errorCode, errorMessage, planningError.errorClass);
} }
@JsonCreator @JsonCreator
@ -86,6 +91,17 @@ public class SqlPlanningException extends BadQueryException
@JsonProperty("errorClass") String errorClass @JsonProperty("errorClass") String errorClass
) )
{ {
super(errorCode, errorMessage, errorClass); this(null, errorCode, errorMessage, errorClass);
} }
private SqlPlanningException(
Throwable cause,
String errorCode,
String errorMessage,
String errorClass
)
{
super(cause, errorCode, errorMessage, errorClass, null);
}
} }

View File

@ -48,9 +48,7 @@ import org.apache.druid.sql.SqlRowTransformer;
import org.apache.druid.sql.SqlStatementFactory; import org.apache.druid.sql.SqlStatementFactory;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.servlet.AsyncContext;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.Consumes; import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE; import javax.ws.rs.DELETE;
import javax.ws.rs.POST; import javax.ws.rs.POST;
@ -63,7 +61,9 @@ import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -121,22 +121,8 @@ public class SqlResource
try { try {
Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId));
// We use an async context not because we are actually going to run this async, but because we want to delay QueryResultPusher pusher = makePusher(req, stmt, sqlQuery);
// the decision of what the response code should be until we have gotten the first few data points to return. return pusher.push();
// Returning a Response object from this point forward requires that object to know the status code, which we
// don't actually know until we are in the accumulator, but if we try to return a Response object from the
// accumulator, we cannot properly stream results back, because the accumulator won't release control of the
// Response until it has consumed the underlying Sequence.
final AsyncContext asyncContext = req.startAsync();
try {
QueryResultPusher pusher = new SqlResourceQueryResultPusher(asyncContext, sqlQueryId, stmt, sqlQuery);
pusher.push();
return null;
}
finally {
asyncContext.complete();
}
} }
finally { finally {
Thread.currentThread().setName(currThreadName); Thread.currentThread().setName(currThreadName);
@ -213,27 +199,43 @@ public class SqlResource
} }
} }
private SqlResourceQueryResultPusher makePusher(HttpServletRequest req, HttpStatement stmt, SqlQuery sqlQuery)
{
final String sqlQueryId = stmt.sqlQueryId();
Map<String, String> headers = new LinkedHashMap<>();
headers.put(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId);
if (sqlQuery.includeHeader()) {
headers.put(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE);
}
return new SqlResourceQueryResultPusher(req, sqlQueryId, stmt, sqlQuery, headers);
}
private class SqlResourceQueryResultPusher extends QueryResultPusher private class SqlResourceQueryResultPusher extends QueryResultPusher
{ {
private final String sqlQueryId; private final String sqlQueryId;
private final HttpStatement stmt; private final HttpStatement stmt;
private final SqlQuery sqlQuery; private final SqlQuery sqlQuery;
public SqlResourceQueryResultPusher( public SqlResourceQueryResultPusher(
AsyncContext asyncContext, HttpServletRequest req,
String sqlQueryId, String sqlQueryId,
HttpStatement stmt, HttpStatement stmt,
SqlQuery sqlQuery SqlQuery sqlQuery,
Map<String, String> headers
) )
{ {
super( super(
(HttpServletResponse) asyncContext.getResponse(), req,
SqlResource.this.jsonMapper, SqlResource.this.jsonMapper,
SqlResource.this.responseContextConfig, SqlResource.this.responseContextConfig,
SqlResource.this.selfNode, SqlResource.this.selfNode,
SqlResource.QUERY_METRIC_COUNTER, SqlResource.QUERY_METRIC_COUNTER,
sqlQueryId, sqlQueryId,
MediaType.APPLICATION_JSON_TYPE MediaType.APPLICATION_JSON_TYPE,
headers
); );
this.sqlQueryId = sqlQueryId; this.sqlQueryId = sqlQueryId;
this.stmt = stmt; this.stmt = stmt;
@ -245,19 +247,17 @@ public class SqlResource
{ {
return new ResultsWriter() return new ResultsWriter()
{ {
private QueryResponse<Object[]> queryResponse;
private ResultSet thePlan; private ResultSet thePlan;
@Override @Override
@Nullable @Nullable
@SuppressWarnings({"unchecked", "rawtypes"}) public Response.ResponseBuilder start()
public QueryResponse<Object> start(HttpServletResponse response)
{ {
response.setHeader(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId);
final QueryResponse<Object[]> retVal;
try { try {
thePlan = stmt.plan(); thePlan = stmt.plan();
retVal = thePlan.run(); queryResponse = thePlan.run();
return null;
} }
catch (RelOptPlanner.CannotPlanException e) { catch (RelOptPlanner.CannotPlanException e) {
throw new SqlPlanningException( throw new SqlPlanningException(
@ -276,12 +276,13 @@ public class SqlResource
// doesn't implement org.apache.druid.common.exception.SanitizableException. // doesn't implement org.apache.druid.common.exception.SanitizableException.
throw new QueryInterruptedException(e); throw new QueryInterruptedException(e);
} }
}
if (sqlQuery.includeHeader()) { @Override
response.setHeader(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); @SuppressWarnings({"unchecked", "rawtypes"})
} public QueryResponse<Object> getQueryResponse()
{
return (QueryResponse) retVal; return (QueryResponse) queryResponse;
} }
@Override @Override
@ -343,6 +344,11 @@ public class SqlResource
@Override @Override
public void recordFailure(Exception e) public void recordFailure(Exception e)
{ {
if (sqlQuery.queryContext().isDebug()) {
log.warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId);
} else {
log.noStackTrace().warn(e, "Exception while processing sqlQueryId[%s]", sqlQueryId);
}
stmt.reporter().failed(e); stmt.reporter().failed(e);
} }

View File

@ -22,7 +22,6 @@ package org.apache.druid.sql.http;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter; import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
@ -39,6 +38,7 @@ import org.apache.druid.jackson.DefaultObjectMapper;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.NonnullPair; import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.LazySequence; import org.apache.druid.java.util.common.guava.LazySequence;
@ -61,6 +61,7 @@ import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.groupby.GroupByQueryConfig; import org.apache.druid.query.groupby.GroupByQueryConfig;
import org.apache.druid.server.DruidNode; import org.apache.druid.server.DruidNode;
import org.apache.druid.server.QueryResource;
import org.apache.druid.server.QueryResponse; import org.apache.druid.server.QueryResponse;
import org.apache.druid.server.QueryScheduler; import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.QueryStackTests;
@ -109,9 +110,14 @@ import org.junit.Test;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.Response; import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status; import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.AbstractList; import java.util.AbstractList;
import java.util.ArrayList; import java.util.ArrayList;
@ -128,6 +134,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class SqlResourceTest extends CalciteTestBase public class SqlResourceTest extends CalciteTestBase
@ -175,7 +182,7 @@ public class SqlResourceTest extends CalciteTestBase
private final SettableSupplier<ResponseContext> responseContextSupplier = new SettableSupplier<>(); private final SettableSupplier<ResponseContext> responseContextSupplier = new SettableSupplier<>();
private Consumer<DirectStatement> onExecute = NULL_ACTION; private Consumer<DirectStatement> onExecute = NULL_ACTION;
private boolean sleep; private Supplier<Void> schedulerBaggage = () -> null;
@Before @Before
public void setUp() throws Exception public void setUp() throws Exception
@ -195,15 +202,8 @@ public class SqlResourceTest extends CalciteTestBase
{ {
return super.run( return super.run(
query, query,
new LazySequence<T>(() -> { new LazySequence<>(() -> {
if (sleep) { schedulerBaggage.get();
try {
// pretend to be a query that is waiting on results
Thread.sleep(500);
}
catch (InterruptedException ignored) {
}
}
return resultSequence; return resultSequence;
}) })
); );
@ -329,7 +329,7 @@ public class SqlResourceTest extends CalciteTestBase
public void testUnauthorized() public void testUnauthorized()
{ {
try { try {
postForResponse( postForAsyncResponse(
createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"), createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"),
request() request()
); );
@ -380,18 +380,17 @@ public class SqlResourceTest extends CalciteTestBase
mockRespContext.put(ResponseContext.Keys.instance().keyOf("uncoveredIntervalsOverflowed"), "true"); mockRespContext.put(ResponseContext.Keys.instance().keyOf("uncoveredIntervalsOverflowed"), "true");
responseContextSupplier.set(mockRespContext); responseContextSupplier.set(mockRespContext);
final MockHttpServletResponse response = postForResponse(sqlQuery, makeRegularUserReq()); final MockHttpServletResponse response = postForAsyncResponse(sqlQuery, makeRegularUserReq());
Map responseContext = JSON_MAPPER.readValue(
Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")),
Map.class
);
Assert.assertEquals( Assert.assertEquals(
ImmutableMap.of( ImmutableMap.of(
"uncoveredIntervals", "2030-01-01/78149827981274-01-01", "uncoveredIntervals", "2030-01-01/78149827981274-01-01",
"uncoveredIntervalsOverflowed", "true" "uncoveredIntervalsOverflowed", "true"
), ),
responseContext JSON_MAPPER.readValue(
Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")),
Map.class
)
); );
Object results = JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class); Object results = JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class);
@ -714,6 +713,7 @@ public class SqlResourceTest extends CalciteTestBase
} }
@Test @Test
@SuppressWarnings("rawtypes")
public void testArrayResultFormatWithHeader() throws Exception public void testArrayResultFormatWithHeader() throws Exception
{ {
final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2";
@ -725,7 +725,7 @@ public class SqlResourceTest extends CalciteTestBase
Arrays.asList("2000-01-02T00:00:00.000Z", "10.1", nullStr, "[\"b\",\"c\"]", 1, 2.0, 2.0, hllStr, nullStr) Arrays.asList("2000-01-02T00:00:00.000Z", "10.1", nullStr, "[\"b\",\"c\"]", 1, 2.0, 2.0, hllStr, nullStr)
}; };
MockHttpServletResponse response = postForResponse( MockHttpServletResponse response = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null),
req.mimic() req.mimic()
); );
@ -745,7 +745,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class) JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class)
); );
MockHttpServletResponse responseNoSqlTypesHeader = postForResponse( MockHttpServletResponse responseNoSqlTypesHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, false, null, null), new SqlQuery(query, ResultFormat.ARRAY, true, true, false, null, null),
req.mimic() req.mimic()
); );
@ -764,7 +764,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoSqlTypesHeader.baos.toByteArray(), Object.class) JSON_MAPPER.readValue(responseNoSqlTypesHeader.baos.toByteArray(), Object.class)
); );
MockHttpServletResponse responseNoTypesHeader = postForResponse( MockHttpServletResponse responseNoTypesHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, false, true, null, null), new SqlQuery(query, ResultFormat.ARRAY, true, false, true, null, null),
req.mimic() req.mimic()
); );
@ -783,7 +783,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoTypesHeader.baos.toByteArray(), Object.class) JSON_MAPPER.readValue(responseNoTypesHeader.baos.toByteArray(), Object.class)
); );
MockHttpServletResponse responseNoTypes = postForResponse( MockHttpServletResponse responseNoTypes = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, false, false, null, null), new SqlQuery(query, ResultFormat.ARRAY, true, false, false, null, null),
req.mimic() req.mimic()
); );
@ -801,7 +801,7 @@ public class SqlResourceTest extends CalciteTestBase
JSON_MAPPER.readValue(responseNoTypes.baos.toByteArray(), Object.class) JSON_MAPPER.readValue(responseNoTypes.baos.toByteArray(), Object.class)
); );
MockHttpServletResponse responseNoHeader = postForResponse( MockHttpServletResponse responseNoHeader = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, false, false, false, null, null), new SqlQuery(query, ResultFormat.ARRAY, false, false, false, null, null),
req.mimic() req.mimic()
); );
@ -821,7 +821,7 @@ public class SqlResourceTest extends CalciteTestBase
// Test a query that returns null header for some of the columns // Test a query that returns null header for some of the columns
final String query = "SELECT (1, 2) FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"; final String query = "SELECT (1, 2) FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1";
MockHttpServletResponse response = postForResponse( MockHttpServletResponse response = postForAsyncResponse(
new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null),
req req
); );
@ -971,12 +971,10 @@ public class SqlResourceTest extends CalciteTestBase
{ {
final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2";
final String nullStr = NullHandling.replaceWithDefault() ? "" : null; final String nullStr = NullHandling.replaceWithDefault() ? "" : null;
final Function<Map<String, Object>, Map<String, Object>> transformer = m -> { final Function<Map<String, Object>, Map<String, Object>> transformer = m -> Maps.transformEntries(
return Maps.transformEntries( m,
m, (k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v
(k, v) -> "EXPR$8".equals(k) || ("dim2".equals(k) && v.toString().isEmpty()) ? nullStr : v );
);
};
Assert.assertEquals( Assert.assertEquals(
ImmutableList.of( ImmutableList.of(
@ -1336,9 +1334,7 @@ public class SqlResourceTest extends CalciteTestBase
@Test @Test
public void testCannotParse() throws Exception public void testCannotParse() throws Exception
{ {
final QueryException exception = doPost( QueryException exception = postSyncForException("FROM druid.foo", Status.BAD_REQUEST.getStatusCode());
createSimpleQueryWithId("id", "FROM druid.foo")
).lhs;
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.SQL_PARSE_ERROR.getErrorCode(), exception.getErrorCode()); Assert.assertEquals(PlanningError.SQL_PARSE_ERROR.getErrorCode(), exception.getErrorCode());
@ -1351,9 +1347,7 @@ public class SqlResourceTest extends CalciteTestBase
@Test @Test
public void testCannotValidate() throws Exception public void testCannotValidate() throws Exception
{ {
final QueryException exception = doPost( QueryException exception = postSyncForException("SELECT dim4 FROM druid.foo", Status.BAD_REQUEST.getStatusCode());
createSimpleQueryWithId("id", "SELECT dim4 FROM druid.foo")
).lhs;
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode()); Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode());
@ -1367,10 +1361,10 @@ public class SqlResourceTest extends CalciteTestBase
public void testCannotConvert() throws Exception public void testCannotConvert() throws Exception
{ {
// SELECT + ORDER unsupported // SELECT + ORDER unsupported
final QueryException exception = doPost( final SqlQuery unsupportedQuery = createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1");
createSimpleQueryWithId("id", "SELECT dim1 FROM druid.foo ORDER BY dim1") QueryException exception = postSyncForException(unsupportedQuery, Status.BAD_REQUEST.getStatusCode());
).lhs;
Assert.assertTrue((Boolean) req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED));
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertEquals("SQL query is unsupported", exception.getErrorCode()); Assert.assertEquals("SQL query is unsupported", exception.getErrorCode());
Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorClass(), exception.getErrorClass()); Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorClass(), exception.getErrorClass());
@ -1392,9 +1386,10 @@ public class SqlResourceTest extends CalciteTestBase
public void testCannotConvert_UnsupportedSQLQueryException() throws Exception public void testCannotConvert_UnsupportedSQLQueryException() throws Exception
{ {
// max(string) unsupported // max(string) unsupported
final QueryException exception = doPost( QueryException exception = postSyncForException(
createSimpleQueryWithId("id", "SELECT max(dim1) FROM druid.foo") "SELECT max(dim1) FROM druid.foo",
).lhs; Status.BAD_REQUEST.getStatusCode()
);
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorCode(), exception.getErrorCode()); Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorCode(), exception.getErrorCode());
@ -1442,7 +1437,7 @@ public class SqlResourceTest extends CalciteTestBase
{ {
String errorMessage = "This will be supported in Druid 9999"; String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage); failOnExecute(errorMessage);
final QueryException exception = doPost( QueryException exception = postSyncForException(
new SqlQuery( new SqlQuery(
"SELECT ANSWER TO LIFE", "SELECT ANSWER TO LIFE",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1451,8 +1446,9 @@ public class SqlResourceTest extends CalciteTestBase
false, false,
ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"), ImmutableMap.of(BaseQuery.SQL_QUERY_ID, "id"),
null null
) ),
).lhs; 501
);
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, exception.getErrorCode()); Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, exception.getErrorCode());
@ -1466,7 +1462,7 @@ public class SqlResourceTest extends CalciteTestBase
String queryId = "id123"; String queryId = "id123";
String errorMessage = "This will be supported in Druid 9999"; String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage); failOnExecute(errorMessage);
final MockHttpServletResponse response = postForResponse( final Response response = postForSyncResponse(
new SqlQuery( new SqlQuery(
"SELECT ANSWER TO LIFE", "SELECT ANSWER TO LIFE",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1478,11 +1474,12 @@ public class SqlResourceTest extends CalciteTestBase
), ),
req req
); );
Assert.assertNotEquals(200, response.getStatus());
Assert.assertEquals( // This is checked in the common method that returns the response, but checking it again just protects
queryId, // from changes there breaking the checks, so doesn't hurt.
Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)) assertStatusAndCommonHeaders(response, 501);
); Assert.assertEquals(queryId, getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertEquals(queryId, getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
} }
@Test @Test
@ -1490,7 +1487,7 @@ public class SqlResourceTest extends CalciteTestBase
{ {
String errorMessage = "This will be supported in Druid 9999"; String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage); failOnExecute(errorMessage);
final MockHttpServletResponse response = postForResponse( final Response response = postForSyncResponse(
new SqlQuery( new SqlQuery(
"SELECT ANSWER TO LIFE", "SELECT ANSWER TO LIFE",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1502,10 +1499,10 @@ public class SqlResourceTest extends CalciteTestBase
), ),
req req
); );
Assert.assertNotEquals(200, response.getStatus());
Assert.assertFalse( // This is checked in the common method that returns the response, but checking it again just protects
Strings.isNullOrEmpty(Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER))) // from changes there breaking the checks, so doesn't hurt.
); assertStatusAndCommonHeaders(response, 501);
} }
@Test @Test
@ -1536,7 +1533,7 @@ public class SqlResourceTest extends CalciteTestBase
String errorMessage = "This will be supported in Druid 9999"; String errorMessage = "This will be supported in Druid 9999";
failOnExecute(errorMessage); failOnExecute(errorMessage);
final QueryException exception = doPost( QueryException exception = postSyncForException(
new SqlQuery( new SqlQuery(
"SELECT ANSWER TO LIFE", "SELECT ANSWER TO LIFE",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1545,8 +1542,9 @@ public class SqlResourceTest extends CalciteTestBase
false, false,
ImmutableMap.of("sqlQueryId", "id"), ImmutableMap.of("sqlQueryId", "id"),
null null
) ),
).lhs; 501
);
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertNull(exception.getMessage()); Assert.assertNull(exception.getMessage());
@ -1587,7 +1585,7 @@ public class SqlResourceTest extends CalciteTestBase
onExecute = s -> { onExecute = s -> {
throw new AssertionError(errorMessage); throw new AssertionError(errorMessage);
}; };
final QueryException exception = doPost( QueryException exception = postSyncForException(
new SqlQuery( new SqlQuery(
"SELECT ANSWER TO LIFE", "SELECT ANSWER TO LIFE",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1596,8 +1594,9 @@ public class SqlResourceTest extends CalciteTestBase
false, false,
ImmutableMap.of("sqlQueryId", "id"), ImmutableMap.of("sqlQueryId", "id"),
null null
) ),
).lhs; Status.INTERNAL_SERVER_ERROR.getStatusCode()
);
Assert.assertNotNull(exception); Assert.assertNotNull(exception);
Assert.assertNull(exception.getMessage()); Assert.assertNull(exception.getMessage());
@ -1610,15 +1609,28 @@ public class SqlResourceTest extends CalciteTestBase
@Test @Test
public void testTooManyRequests() throws Exception public void testTooManyRequests() throws Exception
{ {
sleep = true;
final int numQueries = 3; final int numQueries = 3;
CountDownLatch queriesScheduledLatch = new CountDownLatch(numQueries - 1);
CountDownLatch runQueryLatch = new CountDownLatch(1);
schedulerBaggage = () -> {
queriesScheduledLatch.countDown();
try {
runQueryLatch.await();
}
catch (InterruptedException e) {
throw new RE(e);
}
return null;
};
final String sqlQueryId = "tooManyRequestsTest"; final String sqlQueryId = "tooManyRequestsTest";
List<Future<Pair<QueryException, List<Map<String, Object>>>>> futures = new ArrayList<>(numQueries); List<Future<Object>> futures = new ArrayList<>(numQueries);
for (int i = 0; i < numQueries; i++) { for (int i = 0; i < numQueries - 1; i++) {
futures.add(executorService.submit(() -> { futures.add(executorService.submit(() -> {
try { try {
return doPost( return postForAsyncResponse(
new SqlQuery( new SqlQuery(
"SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo", "SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo",
null, null,
@ -1637,22 +1649,51 @@ public class SqlResourceTest extends CalciteTestBase
})); }));
} }
queriesScheduledLatch.await();
schedulerBaggage = () -> null;
futures.add(executorService.submit(() -> {
try {
final Response retVal = postForSyncResponse(
new SqlQuery(
"SELECT COUNT(*) AS cnt, 'foo' AS TheFoo FROM druid.foo",
null,
false,
false,
false,
ImmutableMap.of("priority", -5, BaseQuery.SQL_QUERY_ID, sqlQueryId),
null
),
makeRegularUserReq()
);
runQueryLatch.countDown();
return retVal;
}
catch (Exception e) {
throw new RuntimeException(e);
}
}));
int success = 0; int success = 0;
int limited = 0; int limited = 0;
for (int i = 0; i < numQueries; i++) { for (int i = 0; i < numQueries; i++) {
Pair<QueryException, List<Map<String, Object>>> result = futures.get(i).get(); if (i == 2) {
List<Map<String, Object>> rows = result.rhs; Response response = (Response) futures.get(i).get();
if (rows != null) { assertStatusAndCommonHeaders(response, 429);
Assert.assertEquals(ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")), rows); QueryException interruped = deserializeResponse(response, QueryException.class);
success++;
} else {
QueryException interruped = result.lhs;
Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, interruped.getErrorCode()); Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, interruped.getErrorCode());
Assert.assertEquals( Assert.assertEquals(
QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 2), QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 2),
interruped.getMessage() interruped.getMessage()
); );
limited++; limited++;
} else {
MockHttpServletResponse response = (MockHttpServletResponse) futures.get(i).get();
assertStatusAndCommonHeaders(response, 200);
Assert.assertEquals(
ImmutableList.of(ImmutableMap.of("cnt", 6, "TheFoo", "foo")),
deserializeResponse(response, Object.class)
);
success++;
} }
} }
Assert.assertEquals(2, success); Assert.assertEquals(2, success);
@ -1671,7 +1712,8 @@ public class SqlResourceTest extends CalciteTestBase
BaseQuery.SQL_QUERY_ID, BaseQuery.SQL_QUERY_ID,
sqlQueryId sqlQueryId
); );
final QueryException timeoutException = doPost(
QueryException exception = postSyncForException(
new SqlQuery( new SqlQuery(
"SELECT CAST(__time AS DATE), dim1, dim2, dim3 FROM druid.foo GROUP by __time, dim1, dim2, dim3 ORDER BY dim2 DESC", "SELECT CAST(__time AS DATE), dim1, dim2, dim3 FROM druid.foo GROUP by __time, dim1, dim2, dim3 ORDER BY dim2 DESC",
ResultFormat.OBJECT, ResultFormat.OBJECT,
@ -1680,11 +1722,13 @@ public class SqlResourceTest extends CalciteTestBase
false, false,
queryContext, queryContext,
null null
) ),
).lhs; 504
Assert.assertNotNull(timeoutException); );
Assert.assertEquals(timeoutException.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE);
Assert.assertEquals(timeoutException.getErrorClass(), QueryTimeoutException.class.getName()); Assert.assertNotNull(exception);
Assert.assertEquals(exception.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE);
Assert.assertEquals(exception.getErrorClass(), QueryTimeoutException.class.getName());
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
} }
@ -1697,8 +1741,8 @@ public class SqlResourceTest extends CalciteTestBase
validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true)); validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true));
CountDownLatch planLatch = new CountDownLatch(1); CountDownLatch planLatch = new CountDownLatch(1);
planLatchSupplier.set(new NonnullPair<>(planLatch, false)); planLatchSupplier.set(new NonnullPair<>(planLatch, false));
Future<MockHttpServletResponse> future = executorService.submit( Future<Response> future = executorService.submit(
() -> postForResponse( () -> postForSyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq() makeRegularUserReq()
) )
@ -1711,9 +1755,10 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
MockHttpServletResponse queryResponse = future.get(); Response queryResponse = future.get();
Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode());
QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class);
QueryException exception = deserializeResponse(queryResponse, QueryException.class);
Assert.assertEquals("Query cancelled", exception.getErrorCode()); Assert.assertEquals("Query cancelled", exception.getErrorCode());
} }
@ -1725,8 +1770,8 @@ public class SqlResourceTest extends CalciteTestBase
planLatchSupplier.set(new NonnullPair<>(planLatch, true)); planLatchSupplier.set(new NonnullPair<>(planLatch, true));
CountDownLatch execLatch = new CountDownLatch(1); CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit( Future<Response> future = executorService.submit(
() -> postForResponse( () -> postForSyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq() makeRegularUserReq()
) )
@ -1738,9 +1783,10 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
MockHttpServletResponse queryResponse = future.get(); Response queryResponse = future.get();
Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); assertStatusAndCommonHeaders(queryResponse, Status.INTERNAL_SERVER_ERROR.getStatusCode());
QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class);
QueryException exception = deserializeResponse(queryResponse, QueryException.class);
Assert.assertEquals("Query cancelled", exception.getErrorCode()); Assert.assertEquals("Query cancelled", exception.getErrorCode());
} }
@ -1753,7 +1799,7 @@ public class SqlResourceTest extends CalciteTestBase
CountDownLatch execLatch = new CountDownLatch(1); CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit( Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse( () -> postForAsyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"),
makeRegularUserReq() makeRegularUserReq()
) )
@ -1778,7 +1824,7 @@ public class SqlResourceTest extends CalciteTestBase
CountDownLatch execLatch = new CountDownLatch(1); CountDownLatch execLatch = new CountDownLatch(1);
executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); executeLatchSupplier.set(new NonnullPair<>(execLatch, false));
Future<MockHttpServletResponse> future = executorService.submit( Future<MockHttpServletResponse> future = executorService.submit(
() -> postForResponse( () -> postForAsyncResponse(
createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM forbiddenDatasource"), createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM forbiddenDatasource"),
makeSuperUserReq() makeSuperUserReq()
) )
@ -1827,23 +1873,16 @@ public class SqlResourceTest extends CalciteTestBase
public void testQueryContextKeyNotAllowed() throws Exception public void testQueryContextKeyNotAllowed() throws Exception
{ {
Map<String, Object> queryContext = ImmutableMap.of(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, "all"); Map<String, Object> queryContext = ImmutableMap.of(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY, "all");
final QueryException queryContextException = doPost( QueryException exception = postSyncForException(
new SqlQuery( new SqlQuery("SELECT 1337", ResultFormat.OBJECT, false, false, false, queryContext, null),
"SELECT 1337", Status.BAD_REQUEST.getStatusCode()
ResultFormat.OBJECT, );
false,
false, Assert.assertNotNull(exception);
false, Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), exception.getErrorCode());
queryContext,
null
)
).lhs;
Assert.assertNotNull(queryContextException);
Assert.assertEquals(PlanningError.VALIDATION_ERROR.getErrorCode(), queryContextException.getErrorCode());
MatcherAssert.assertThat( MatcherAssert.assertThat(
queryContextException.getMessage(), exception.getMessage(),
CoreMatchers.containsString( CoreMatchers.containsString("Cannot execute query with context parameter [sqlInsertSegmentGranularity]")
"Cannot execute query with context parameter [sqlInsertSegmentGranularity]")
); );
checkSqlRequestLog(false); checkSqlRequestLog(false);
} }
@ -1922,7 +1961,7 @@ public class SqlResourceTest extends CalciteTestBase
private Pair<QueryException, String> doPostRaw(final SqlQuery query, final MockHttpServletRequest req) private Pair<QueryException, String> doPostRaw(final SqlQuery query, final MockHttpServletRequest req)
throws Exception throws Exception
{ {
MockHttpServletResponse response = postForResponse(query, req); MockHttpServletResponse response = postForAsyncResponse(query, req);
if (response.getStatus() == 200) { if (response.getStatus() == 200) {
return Pair.of(null, new String(response.baos.toByteArray(), StandardCharsets.UTF_8)); return Pair.of(null, new String(response.baos.toByteArray(), StandardCharsets.UTF_8));
@ -1932,14 +1971,122 @@ public class SqlResourceTest extends CalciteTestBase
} }
@Nonnull @Nonnull
private MockHttpServletResponse postForResponse(SqlQuery query, MockHttpServletRequest req) private MockHttpServletResponse postForAsyncResponse(SqlQuery query, MockHttpServletRequest req)
{ {
MockHttpServletResponse response = MockHttpServletResponse.forRequest(req); MockHttpServletResponse response = MockHttpServletResponse.forRequest(req);
final Object explicitQueryId = query.getContext().get("queryId");
final Object explicitSqlQueryId = query.getContext().get("sqlQueryId");
Assert.assertNull(resource.doPost(query, req)); Assert.assertNull(resource.doPost(query, req));
final Object actualQueryId = response.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER);
final Object actualSqlQueryId = response.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER);
validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId);
return response; return response;
} }
private void assertStatusAndCommonHeaders(MockHttpServletResponse queryResponse, int statusCode)
{
Assert.assertEquals(statusCode, queryResponse.getStatus());
Assert.assertEquals("application/json", queryResponse.getContentType());
Assert.assertNotNull(queryResponse.getHeader(QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertNotNull(queryResponse.getHeader(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
}
private <T> T deserializeResponse(MockHttpServletResponse resp, Class<T> clazz) throws IOException
{
return JSON_MAPPER.readValue(resp.baos.toByteArray(), clazz);
}
private Response postForSyncResponse(SqlQuery query, MockHttpServletRequest req)
{
final Object explicitQueryId = query.getContext().get("queryId");
final Object explicitSqlQueryId = query.getContext().get("sqlQueryId");
final Response response = resource.doPost(query, req);
final Object actualQueryId = getHeader(response, QueryResource.QUERY_ID_RESPONSE_HEADER);
final Object actualSqlQueryId = getHeader(response, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER);
validateQueryIds(explicitQueryId, explicitSqlQueryId, actualQueryId, actualSqlQueryId);
return response;
}
private QueryException postSyncForException(String s, int expectedStatus) throws IOException
{
return postSyncForException(createSimpleQueryWithId("id", s), expectedStatus);
}
private QueryException postSyncForException(SqlQuery query, int expectedStatus) throws IOException
{
final Response response = postForSyncResponse(query, req);
assertStatusAndCommonHeaders(response, expectedStatus);
return deserializeResponse(response, QueryException.class);
}
private <T> T deserializeResponse(Response resp, Class<T> clazz) throws IOException
{
return JSON_MAPPER.readValue(responseToByteArray(resp), clazz);
}
private byte[] responseToByteArray(Response resp) throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
((StreamingOutput) resp.getEntity()).write(baos);
return baos.toByteArray();
}
private String getContentType(Response resp)
{
return getHeader(resp, HttpHeaders.CONTENT_TYPE).toString();
}
@Nullable
private Object getHeader(Response resp, String header)
{
final List<Object> objects = resp.getMetadata().get(header);
if (objects == null) {
return null;
}
return Iterables.getOnlyElement(objects);
}
private void assertStatusAndCommonHeaders(Response queryResponse, int statusCode)
{
Assert.assertEquals(statusCode, queryResponse.getStatus());
Assert.assertEquals("application/json", getContentType(queryResponse));
Assert.assertNotNull(getHeader(queryResponse, QueryResource.QUERY_ID_RESPONSE_HEADER));
Assert.assertNotNull(getHeader(queryResponse, SqlResource.SQL_QUERY_ID_RESPONSE_HEADER));
}
private void validateQueryIds(
Object explicitQueryId,
Object explicitSqlQueryId,
Object actualQueryId,
Object actualSqlQueryId
)
{
if (explicitQueryId == null) {
if (null != explicitSqlQueryId) {
Assert.assertEquals(explicitSqlQueryId, actualQueryId);
Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId);
} else {
Assert.assertNotNull(actualQueryId);
Assert.assertNotNull(actualSqlQueryId);
}
} else {
if (explicitSqlQueryId == null) {
Assert.assertEquals(explicitQueryId, actualQueryId);
Assert.assertEquals(explicitQueryId, actualSqlQueryId);
} else {
Assert.assertEquals(explicitQueryId, actualQueryId);
Assert.assertEquals(explicitSqlQueryId, actualSqlQueryId);
}
}
}
private MockHttpServletRequest makeSuperUserReq() private MockHttpServletRequest makeSuperUserReq()
{ {
return makeExpectedReq(CalciteTests.SUPER_USER_AUTH_RESULT); return makeExpectedReq(CalciteTests.SUPER_USER_AUTH_RESULT);
@ -1954,6 +2101,7 @@ public class SqlResourceTest extends CalciteTestBase
{ {
MockHttpServletRequest req = new MockHttpServletRequest(); MockHttpServletRequest req = new MockHttpServletRequest();
req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult);
req.remoteAddr = "1.2.3.4";
return req; return req;
} }