diff --git a/core/src/main/java/org/apache/druid/query/QueryException.java b/core/src/main/java/org/apache/druid/query/QueryException.java index 10267b52539..93f17b6cff9 100644 --- a/core/src/main/java/org/apache/druid/query/QueryException.java +++ b/core/src/main/java/org/apache/druid/query/QueryException.java @@ -21,7 +21,6 @@ package org.apache.druid.query; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; -import com.google.common.annotations.VisibleForTesting; import org.apache.druid.common.exception.SanitizableException; import javax.annotation.Nullable; @@ -30,12 +29,123 @@ import java.net.InetAddress; import java.util.function.Function; /** - * Base serializable error response - * + * Base serializable error response. + *

+ * The Object Model that QueryException follows is a little non-intuitive as the primary way that a QueryException is + * generated is through a child class. However, those child classes are *not* equivalent to a QueryException, instead + * they act as a Factory of QueryException objects. This can be seen in two different places. + *

+ * 1. When sanitize() is called, the response is a QueryException without any indication of which original exception + * occurred. + * 2. When these objects get serialized across the wire the recipient deserializes a QueryException. The client is + * never expected, and fundamentally is not allowed to, ever deserialize a child class of QueryException. + *

+ * For this reason, QueryException must contain all potential state that any of its child classes could ever want to + * push across the wire. Additionally, any catch clauses expecting one of the child Exceptions must know that it is + * running inside of code where the exception has not traveled across the wire. If there is a chance that the + * exception could have been serialized across the wire, the code must catch a QueryException and check the errorCode + * instead. + *

+ * As a corollary, adding new state or adjusting the logic of this class must always be done in a backwards-compatible + * fashion across all child classes of QueryException. + *

+ * If there is any need to do different logic based on the type of error that has happened, the only reliable method + * of discerning the type of the error is to look at the errorCode String. Because these Strings are considered part + * of the API, they are not allowed to change and must maintain their same semantics. The known errorCode Strings + * are pulled together as public static fields on this class in order to make it more clear what options exist. + *

* QueryResource and SqlResource are expected to emit the JSON form of this object when errors happen. */ public class QueryException extends RuntimeException implements SanitizableException { + /** + * Error codes + */ + public static final String JSON_PARSE_ERROR_CODE = "Json parse failed"; + public static final String BAD_QUERY_CONTEXT_ERROR_CODE = "Query context parse failed"; + public static final String QUERY_CAPACITY_EXCEEDED_ERROR_CODE = "Query capacity exceeded"; + public static final String QUERY_INTERRUPTED_ERROR_CODE = "Query interrupted"; + // Note: the proper spelling is with a single "l", but the version with + // two "l"s is documented, we can't change the text of the message. + public static final String QUERY_CANCELED_ERROR_CODE = "Query cancelled"; + public static final String UNAUTHORIZED_ERROR_CODE = "Unauthorized request"; + public static final String UNSUPPORTED_OPERATION_ERROR_CODE = "Unsupported operation"; + public static final String TRUNCATED_RESPONSE_CONTEXT_ERROR_CODE = "Truncated response context"; + public static final String UNKNOWN_EXCEPTION_ERROR_CODE = "Unknown exception"; + public static final String QUERY_TIMEOUT_ERROR_CODE = "Query timeout"; + public static final String QUERY_UNSUPPORTED_ERROR_CODE = "Unsupported query"; + public static final String RESOURCE_LIMIT_EXCEEDED_ERROR_CODE = "Resource limit exceeded"; + public static final String SQL_PARSE_FAILED_ERROR_CODE = "SQL parse failed"; + public static final String PLAN_VALIDATION_FAILED_ERROR_CODE = "Plan validation failed"; + public static final String SQL_QUERY_UNSUPPORTED_ERROR_CODE = "SQL query is unsupported"; + + public enum FailType + { + USER_ERROR(400), + UNAUTHORIZED(401), + CAPACITY_EXCEEDED(429), + UNKNOWN(500), + CANCELED(500), + QUERY_RUNTIME_FAILURE(500), + UNSUPPORTED(501), + TIMEOUT(504); + + private final int expectedStatus; + + FailType(int expectedStatus) + { + this.expectedStatus = expectedStatus; + } + + public int getExpectedStatus() + { + return expectedStatus; + } + } + + public static FailType fromErrorCode(String errorCode) + { + if (errorCode == null) { + return FailType.UNKNOWN; + } + + switch (errorCode) { + case QUERY_CANCELED_ERROR_CODE: + return FailType.CANCELED; + + // These error codes are generally expected to come from a QueryInterruptedException + case QUERY_INTERRUPTED_ERROR_CODE: + case UNSUPPORTED_OPERATION_ERROR_CODE: + case UNKNOWN_EXCEPTION_ERROR_CODE: + case TRUNCATED_RESPONSE_CONTEXT_ERROR_CODE: + return FailType.QUERY_RUNTIME_FAILURE; + case UNAUTHORIZED_ERROR_CODE: + return FailType.UNAUTHORIZED; + + case QUERY_CAPACITY_EXCEEDED_ERROR_CODE: + return FailType.CAPACITY_EXCEEDED; + case QUERY_TIMEOUT_ERROR_CODE: + return FailType.TIMEOUT; + + // These error codes are expected to come from BadQueryExceptions + case JSON_PARSE_ERROR_CODE: + case BAD_QUERY_CONTEXT_ERROR_CODE: + case RESOURCE_LIMIT_EXCEEDED_ERROR_CODE: + // And these ones from the SqlPlanningException which are also BadQueryExceptions + case SQL_PARSE_FAILED_ERROR_CODE: + case PLAN_VALIDATION_FAILED_ERROR_CODE: + case SQL_QUERY_UNSUPPORTED_ERROR_CODE: + return FailType.USER_ERROR; + case QUERY_UNSUPPORTED_ERROR_CODE: + return FailType.UNSUPPORTED; + default: + return FailType.UNKNOWN; + } + } + + /** + * Implementation + */ private final String errorCode; private final String errorClass; private final String host; @@ -48,7 +158,6 @@ public class QueryException extends RuntimeException implements SanitizableExcep this.host = host; } - @VisibleForTesting @JsonCreator public QueryException( @JsonProperty("error") @Nullable String errorCode, @@ -105,4 +214,9 @@ public class QueryException extends RuntimeException implements SanitizableExcep { return new QueryException(errorCode, errorMessageTransformFunction.apply(getMessage()), null, null); } + + public FailType getFailType() + { + return fromErrorCode(errorCode); + } } diff --git a/core/src/main/java/org/apache/druid/query/QueryTimeoutException.java b/core/src/main/java/org/apache/druid/query/QueryTimeoutException.java index d3626e9c5e6..7bd4924000a 100644 --- a/core/src/main/java/org/apache/druid/query/QueryTimeoutException.java +++ b/core/src/main/java/org/apache/druid/query/QueryTimeoutException.java @@ -36,7 +36,6 @@ import javax.annotation.Nullable; public class QueryTimeoutException extends QueryException { private static final String ERROR_CLASS = QueryTimeoutException.class.getName(); - public static final String ERROR_CODE = "Query timeout"; public static final String ERROR_MESSAGE = "Query Timed Out!"; public static final int STATUS_CODE = 504; @@ -53,16 +52,16 @@ public class QueryTimeoutException extends QueryException public QueryTimeoutException() { - super(ERROR_CODE, ERROR_MESSAGE, ERROR_CLASS, resolveHostname()); + super(QUERY_TIMEOUT_ERROR_CODE, ERROR_MESSAGE, ERROR_CLASS, resolveHostname()); } public QueryTimeoutException(String errorMessage) { - super(ERROR_CODE, errorMessage, ERROR_CLASS, resolveHostname()); + super(QUERY_TIMEOUT_ERROR_CODE, errorMessage, ERROR_CLASS, resolveHostname()); } public QueryTimeoutException(String errorMessage, String host) { - super(ERROR_CODE, errorMessage, ERROR_CLASS, host); + super(QUERY_TIMEOUT_ERROR_CODE, errorMessage, ERROR_CLASS, host); } } diff --git a/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java b/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java index 51ff763762c..446e94b9667 100644 --- a/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java +++ b/core/src/test/java/org/apache/druid/query/QueryExceptionTest.java @@ -19,17 +19,12 @@ package org.apache.druid.query; +import org.apache.druid.query.QueryException.FailType; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; -import java.util.function.Function; +import java.util.concurrent.atomic.AtomicLong; -@RunWith(MockitoJUnitRunner.class) public class QueryExceptionTest { private static final String ERROR_CODE = "error code"; @@ -38,36 +33,72 @@ public class QueryExceptionTest private static final String ERROR_MESSAGE_ORIGINAL = "aaaa"; private static final String ERROR_MESSAGE_TRANSFORMED = "bbbb"; - @Mock - private Function trasformFunction; - @Test public void testSanitizeWithTransformFunctionReturningNull() { - Mockito.when(trasformFunction.apply(ArgumentMatchers.eq(ERROR_MESSAGE_ORIGINAL))).thenReturn(null); QueryException queryException = new QueryException(ERROR_CODE, ERROR_MESSAGE_ORIGINAL, ERROR_CLASS, HOST); - QueryException actual = queryException.sanitize(trasformFunction); + + AtomicLong callCount = new AtomicLong(0); + QueryException actual = queryException.sanitize(s -> { + callCount.incrementAndGet(); + Assert.assertEquals(ERROR_MESSAGE_ORIGINAL, s); + return null; + }); + Assert.assertNotNull(actual); Assert.assertEquals(actual.getErrorCode(), ERROR_CODE); Assert.assertNull(actual.getMessage()); Assert.assertNull(actual.getHost()); Assert.assertNull(actual.getErrorClass()); - Mockito.verify(trasformFunction).apply(ArgumentMatchers.eq(ERROR_MESSAGE_ORIGINAL)); - Mockito.verifyNoMoreInteractions(trasformFunction); + Assert.assertEquals(1, callCount.get()); } @Test public void testSanitizeWithTransformFunctionReturningNewString() { - Mockito.when(trasformFunction.apply(ArgumentMatchers.eq(ERROR_MESSAGE_ORIGINAL))).thenReturn(ERROR_MESSAGE_TRANSFORMED); QueryException queryException = new QueryException(ERROR_CODE, ERROR_MESSAGE_ORIGINAL, ERROR_CLASS, HOST); - QueryException actual = queryException.sanitize(trasformFunction); + + AtomicLong callCount = new AtomicLong(0); + QueryException actual = queryException.sanitize(s -> { + callCount.incrementAndGet(); + Assert.assertEquals(ERROR_MESSAGE_ORIGINAL, s); + return ERROR_MESSAGE_TRANSFORMED; + }); + Assert.assertNotNull(actual); Assert.assertEquals(actual.getErrorCode(), ERROR_CODE); Assert.assertEquals(actual.getMessage(), ERROR_MESSAGE_TRANSFORMED); Assert.assertNull(actual.getHost()); Assert.assertNull(actual.getErrorClass()); - Mockito.verify(trasformFunction).apply(ArgumentMatchers.eq(ERROR_MESSAGE_ORIGINAL)); - Mockito.verifyNoMoreInteractions(trasformFunction); + Assert.assertEquals(1, callCount.get()); + } + + @Test + public void testSanity() + { + expectFailTypeForCode(FailType.UNKNOWN, null); + expectFailTypeForCode(FailType.UNKNOWN, "Nobody knows me."); + expectFailTypeForCode(FailType.QUERY_RUNTIME_FAILURE, QueryException.UNKNOWN_EXCEPTION_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.JSON_PARSE_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.BAD_QUERY_CONTEXT_ERROR_CODE); + expectFailTypeForCode(FailType.CAPACITY_EXCEEDED, QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE); + expectFailTypeForCode(FailType.QUERY_RUNTIME_FAILURE, QueryException.QUERY_INTERRUPTED_ERROR_CODE); + expectFailTypeForCode(FailType.CANCELED, QueryException.QUERY_CANCELED_ERROR_CODE); + expectFailTypeForCode(FailType.UNAUTHORIZED, QueryException.UNAUTHORIZED_ERROR_CODE); + expectFailTypeForCode(FailType.QUERY_RUNTIME_FAILURE, QueryException.UNSUPPORTED_OPERATION_ERROR_CODE); + expectFailTypeForCode(FailType.QUERY_RUNTIME_FAILURE, QueryException.TRUNCATED_RESPONSE_CONTEXT_ERROR_CODE); + expectFailTypeForCode(FailType.TIMEOUT, QueryException.QUERY_TIMEOUT_ERROR_CODE); + expectFailTypeForCode(FailType.UNSUPPORTED, QueryException.QUERY_UNSUPPORTED_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.RESOURCE_LIMIT_EXCEEDED_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_PARSE_FAILED_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.PLAN_VALIDATION_FAILED_ERROR_CODE); + expectFailTypeForCode(FailType.USER_ERROR, QueryException.SQL_QUERY_UNSUPPORTED_ERROR_CODE); + } + + private void expectFailTypeForCode(FailType expected, String code) + { + QueryException exception = new QueryException(new Exception(), code, "java.lang.Exception", "test"); + + Assert.assertEquals(code, expected, exception.getFailType()); } } diff --git a/core/src/test/java/org/apache/druid/query/QueryTimeoutExceptionTest.java b/core/src/test/java/org/apache/druid/query/QueryTimeoutExceptionTest.java index ab187a1fb42..acc6b44d58f 100644 --- a/core/src/test/java/org/apache/druid/query/QueryTimeoutExceptionTest.java +++ b/core/src/test/java/org/apache/druid/query/QueryTimeoutExceptionTest.java @@ -19,7 +19,10 @@ package org.apache.druid.query; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.MapperFeature; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; import org.junit.Assert; import org.junit.Test; @@ -30,7 +33,26 @@ public class QueryTimeoutExceptionTest @Test public void testSerde() throws IOException { - final ObjectMapper mapper = new ObjectMapper(); + // We re-create the configuration from DefaultObjectMapper here because this is in `core` and + // DefaultObjectMapper is in `processing`. Hopefully that distinction disappears at some point + // in time, but it exists today and moving things one way or the other quickly turns into just + // chunking it all together. + final ObjectMapper mapper = new ObjectMapper() + { + { + configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + configure(MapperFeature.AUTO_DETECT_GETTERS, false); + // See https://github.com/FasterXML/jackson-databind/issues/170 + // configure(MapperFeature.AUTO_DETECT_CREATORS, false); + configure(MapperFeature.AUTO_DETECT_FIELDS, false); + configure(MapperFeature.AUTO_DETECT_IS_GETTERS, false); + configure(MapperFeature.AUTO_DETECT_SETTERS, false); + configure(MapperFeature.ALLOW_FINAL_FIELDS_AS_MUTATORS, false); + configure(SerializationFeature.INDENT_OUTPUT, false); + configure(SerializationFeature.FLUSH_AFTER_WRITE_VALUE, false); + } + }; + QueryTimeoutException timeoutException = mapper.readValue( mapper.writeValueAsBytes(new QueryTimeoutException()), QueryTimeoutException.class diff --git a/integration-tests/src/test/java/org/apache/druid/tests/query/ITSqlCancelTest.java b/integration-tests/src/test/java/org/apache/druid/tests/query/ITSqlCancelTest.java index ad8dd3cb11f..d5bc1e30204 100644 --- a/integration-tests/src/test/java/org/apache/druid/tests/query/ITSqlCancelTest.java +++ b/integration-tests/src/test/java/org/apache/druid/tests/query/ITSqlCancelTest.java @@ -27,7 +27,6 @@ import org.apache.druid.java.util.common.RE; import org.apache.druid.java.util.http.client.response.StatusResponseHolder; import org.apache.druid.query.BaseQuery; import org.apache.druid.query.QueryException; -import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.sql.http.SqlQuery; import org.apache.druid.testing.IntegrationTestingConfig; import org.apache.druid.testing.clients.SqlResourceTestClient; @@ -107,10 +106,10 @@ public class ITSqlCancelTest throw new ISE("Query is not canceled after cancel request"); } QueryException queryException = jsonMapper.readValue(queryResponse.getContent(), QueryException.class); - if (!QueryInterruptedException.QUERY_CANCELED.equals(queryException.getErrorCode())) { + if (!"Query cancelled".equals(queryException.getErrorCode())) { throw new ISE( "Expected error code [%s], actual [%s]", - QueryInterruptedException.QUERY_CANCELED, + "Query cancelled", queryException.getErrorCode() ); } @@ -138,7 +137,11 @@ public class ITSqlCancelTest final StatusResponseHolder queryResponse = queryResponseFuture.get(30, TimeUnit.SECONDS); if (!queryResponse.getStatus().equals(HttpResponseStatus.OK)) { - throw new ISE("Cancel request failed with status[%s] and content[%s]", queryResponse.getStatus(), queryResponse.getContent()); + throw new ISE( + "Cancel request failed with status[%s] and content[%s]", + queryResponse.getStatus(), + queryResponse.getContent() + ); } } } diff --git a/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java b/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java index 8be18edf18a..47ba3c90777 100644 --- a/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java +++ b/processing/src/main/java/org/apache/druid/query/BadJsonQueryException.java @@ -25,12 +25,11 @@ import com.fasterxml.jackson.core.JsonParseException; public class BadJsonQueryException extends BadQueryException { - public static final String ERROR_CODE = "Json parse failed"; public static final String ERROR_CLASS = JsonParseException.class.getName(); public BadJsonQueryException(JsonParseException e) { - this(ERROR_CODE, e.getMessage(), ERROR_CLASS); + this(JSON_PARSE_ERROR_CODE, e.getMessage(), ERROR_CLASS); } @JsonCreator diff --git a/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java b/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java index 29f63b1f40e..cbfb0ca410c 100644 --- a/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java +++ b/processing/src/main/java/org/apache/druid/query/BadQueryContextException.java @@ -24,17 +24,11 @@ import com.fasterxml.jackson.annotation.JsonProperty; public class BadQueryContextException extends BadQueryException { - public static final String ERROR_CODE = "Query context parse failed"; public static final String ERROR_CLASS = BadQueryContextException.class.getName(); - public BadQueryContextException(Exception e) - { - this(ERROR_CODE, e.getMessage(), ERROR_CLASS); - } - public BadQueryContextException(String msg) { - this(ERROR_CODE, msg, ERROR_CLASS); + this(BAD_QUERY_CONTEXT_ERROR_CODE, msg, ERROR_CLASS); } @JsonCreator diff --git a/processing/src/main/java/org/apache/druid/query/QueryCapacityExceededException.java b/processing/src/main/java/org/apache/druid/query/QueryCapacityExceededException.java index f62eb9166d8..694fbb780cb 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryCapacityExceededException.java +++ b/processing/src/main/java/org/apache/druid/query/QueryCapacityExceededException.java @@ -32,7 +32,7 @@ import org.apache.druid.java.util.common.StringUtils; *

  • When the query is rejected by QueryScheduler.
  • *
  • When the query cannot acquire enough merge buffers for groupBy v2
  • * - * + *

    * As a {@link QueryException} it is expected to be serialied to a json response, but will be mapped to * {@link #STATUS_CODE} instead of the default HTTP 500 status. */ @@ -43,17 +43,16 @@ public class QueryCapacityExceededException extends QueryException private static final String LANE_ERROR_MESSAGE_TEMPLATE = "Too many concurrent queries for lane '%s', query capacity of %s exceeded. Please try your query again later."; private static final String ERROR_CLASS = QueryCapacityExceededException.class.getName(); - public static final String ERROR_CODE = "Query capacity exceeded"; public static final int STATUS_CODE = 429; public QueryCapacityExceededException(int capacity) { - super(ERROR_CODE, makeTotalErrorMessage(capacity), ERROR_CLASS, null); + super(QUERY_CAPACITY_EXCEEDED_ERROR_CODE, makeTotalErrorMessage(capacity), ERROR_CLASS, null); } public QueryCapacityExceededException(String lane, int capacity) { - super(ERROR_CODE, makeLaneErrorMessage(lane, capacity), ERROR_CLASS, null); + super(QUERY_CAPACITY_EXCEEDED_ERROR_CODE, makeLaneErrorMessage(lane, capacity), ERROR_CLASS, null); } /** @@ -62,7 +61,12 @@ public class QueryCapacityExceededException extends QueryException */ public static QueryCapacityExceededException withErrorMessageAndResolvedHost(String errorMessage) { - return new QueryCapacityExceededException(ERROR_CODE, errorMessage, ERROR_CLASS, resolveHostname()); + return new QueryCapacityExceededException( + QUERY_CAPACITY_EXCEEDED_ERROR_CODE, + errorMessage, + ERROR_CLASS, + resolveHostname() + ); } @JsonCreator diff --git a/processing/src/main/java/org/apache/druid/query/QueryInterruptedException.java b/processing/src/main/java/org/apache/druid/query/QueryInterruptedException.java index ae67039242f..91760aa7c10 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryInterruptedException.java +++ b/processing/src/main/java/org/apache/druid/query/QueryInterruptedException.java @@ -29,7 +29,7 @@ import java.util.concurrent.CancellationException; /** * Exception representing a failed query. The name "QueryInterruptedException" is a misnomer; this is actually * used on the client side for *all* kinds of failed queries. - * + *

    * Fields: * - "errorCode" is a well-defined errorCode code taken from a specific list (see the static constants). "Unknown exception" * represents all wrapped exceptions other than interrupt, cancellation, resource limit exceeded, unauthorized @@ -37,21 +37,12 @@ import java.util.concurrent.CancellationException; * - "errorMessage" is the toString of the wrapped exception * - "errorClass" is the class of the wrapped exception * - "host" is the host that the errorCode occurred on - * + *

    * The QueryResource is expected to emit the JSON form of this object when errors happen, and the DirectDruidClient * deserializes and wraps them. */ public class QueryInterruptedException extends QueryException { - public static final String QUERY_INTERRUPTED = "Query interrupted"; - // Note: the proper spelling is with a single "l", but the version with - // two "l"s is documented, we can't change the text of the message. - public static final String QUERY_CANCELED = "Query cancelled"; - public static final String UNAUTHORIZED = "Unauthorized request"; - public static final String UNSUPPORTED_OPERATION = "Unsupported operation"; - public static final String TRUNCATED_RESPONSE_CONTEXT = "Truncated response context"; - public static final String UNKNOWN_EXCEPTION = "Unknown exception"; - @JsonCreator public QueryInterruptedException( @JsonProperty("error") @Nullable String errorCode, @@ -96,15 +87,15 @@ public class QueryInterruptedException extends QueryException if (e instanceof QueryInterruptedException) { return ((QueryInterruptedException) e).getErrorCode(); } else if (e instanceof InterruptedException) { - return QUERY_INTERRUPTED; + return QUERY_INTERRUPTED_ERROR_CODE; } else if (e instanceof CancellationException) { - return QUERY_CANCELED; + return QUERY_CANCELED_ERROR_CODE; } else if (e instanceof UnsupportedOperationException) { - return UNSUPPORTED_OPERATION; + return UNSUPPORTED_OPERATION_ERROR_CODE; } else if (e instanceof TruncatedResponseContextException) { - return TRUNCATED_RESPONSE_CONTEXT; + return TRUNCATED_RESPONSE_CONTEXT_ERROR_CODE; } else { - return UNKNOWN_EXCEPTION; + return UNKNOWN_EXCEPTION_ERROR_CODE; } } diff --git a/processing/src/main/java/org/apache/druid/query/QueryUnsupportedException.java b/processing/src/main/java/org/apache/druid/query/QueryUnsupportedException.java index bde1f9d14e1..81d82a94871 100644 --- a/processing/src/main/java/org/apache/druid/query/QueryUnsupportedException.java +++ b/processing/src/main/java/org/apache/druid/query/QueryUnsupportedException.java @@ -29,14 +29,13 @@ import javax.annotation.Nullable; * This exception is for the query engine to surface when a query cannot be run. This can be due to the * following reasons: 1) The query is not supported yet. 2) The query is not something Druid would ever supports. * For these cases, the exact causes and details should also be documented in Druid user facing documents. - * + *

    * As a {@link QueryException} it is expected to be serialized to a json response with a proper HTTP error code * ({@link #STATUS_CODE}). */ public class QueryUnsupportedException extends QueryException { private static final String ERROR_CLASS = QueryUnsupportedException.class.getName(); - public static final String ERROR_CODE = "Unsupported query"; public static final int STATUS_CODE = 501; @JsonCreator @@ -52,6 +51,6 @@ public class QueryUnsupportedException extends QueryException public QueryUnsupportedException(String errorMessage) { - super(ERROR_CODE, errorMessage, ERROR_CLASS, resolveHostname()); + super(QUERY_UNSUPPORTED_ERROR_CODE, errorMessage, ERROR_CLASS, resolveHostname()); } } diff --git a/processing/src/main/java/org/apache/druid/query/ResourceLimitExceededException.java b/processing/src/main/java/org/apache/druid/query/ResourceLimitExceededException.java index 169d774cabf..a41699cf442 100644 --- a/processing/src/main/java/org/apache/druid/query/ResourceLimitExceededException.java +++ b/processing/src/main/java/org/apache/druid/query/ResourceLimitExceededException.java @@ -25,7 +25,7 @@ import org.apache.druid.java.util.common.StringUtils; /** * Exception indicating that an operation failed because it exceeded some configured resource limit. - * + *

    * This is a {@link BadQueryException} because it likely indicates a user's misbehavior when this exception is thrown. * The resource limitations set by Druid cluster operators are typically less flexible than the parameters of * a user query, so when a user query requires too many resources, the likely remedy is that the user query @@ -33,8 +33,6 @@ import org.apache.druid.java.util.common.StringUtils; */ public class ResourceLimitExceededException extends BadQueryException { - public static final String ERROR_CODE = "Resource limit exceeded"; - public static ResourceLimitExceededException withMessage(String message, Object... arguments) { return new ResourceLimitExceededException(StringUtils.nonStrictFormat(message, arguments)); @@ -47,7 +45,7 @@ public class ResourceLimitExceededException extends BadQueryException public ResourceLimitExceededException(String message) { - this(ERROR_CODE, message, ResourceLimitExceededException.class.getName()); + this(RESOURCE_LIMIT_EXCEEDED_ERROR_CODE, message, ResourceLimitExceededException.class.getName()); } @JsonCreator diff --git a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java index a943297d173..6727782cc40 100644 --- a/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java +++ b/processing/src/main/java/org/apache/druid/query/context/ResponseContext.java @@ -80,18 +80,18 @@ import java.util.stream.Collectors; *

  • Manages headers size by dropping fields when the header would get too * large.
  • * - * + *

    * A result is that the information the context, when inspected by a calling * query, may be incomplete if some of it was previously dropped by the * called query. * *

    API

    - * + *

    * The query profile needs to obtain the full, untruncated information. To do this * it piggy-backs on the set operations to obtain the full value. To ensure this * is possible, code that works with standard values should call the set (or add) * functions provided which will do the needed map update. - */ + */ @PublicApi public abstract class ResponseContext { @@ -118,7 +118,7 @@ public abstract class ResponseContext /** * Merges two values of type T. - * + *

    * This method may modify "oldValue" but must not modify "newValue". */ Object mergeValues(Object oldValue, Object newValue); @@ -317,7 +317,8 @@ public abstract class ResponseContext true, true, new TypeReference>() { - }) + } + ) { @Override @SuppressWarnings("unchecked") @@ -334,14 +335,15 @@ public abstract class ResponseContext */ public static final Key UNCOVERED_INTERVALS_OVERFLOWED = new BooleanKey( "uncoveredIntervalsOverflowed", - true); + true + ); /** * Map of most relevant query ID to remaining number of responses from query nodes. * The value is initialized in {@code CachingClusteredClient} when it initializes the connection to the query nodes, * and is updated whenever they respond (@code DirectDruidClient). {@code RetryQueryRunner} uses this value to * check if the {@link #MISSING_SEGMENTS} is valid. - * + *

    * Currently, the broker doesn't run subqueries in parallel, the remaining number of responses will be updated * one by one per subquery. However, since it can be parallelized to run subqueries simultaneously, we store them * in a ConcurrentHashMap. @@ -351,7 +353,8 @@ public abstract class ResponseContext public static final Key REMAINING_RESPONSES_FROM_QUERY_SERVERS = new AbstractKey( "remainingResponsesFromQueryServers", false, true, - Object.class) + Object.class + ) { @Override @SuppressWarnings("unchecked") @@ -361,7 +364,8 @@ public abstract class ResponseContext final NonnullPair pair = (NonnullPair) idAndNumResponses; map.compute( pair.lhs, - (id, remaining) -> remaining == null ? pair.rhs : remaining + pair.rhs); + (id, remaining) -> remaining == null ? pair.rhs : remaining + pair.rhs + ); return map; } }; @@ -372,7 +376,10 @@ public abstract class ResponseContext public static final Key MISSING_SEGMENTS = new AbstractKey( "missingSegments", true, true, - new TypeReference>() {}) + new TypeReference>() + { + } + ) { @Override @SuppressWarnings("unchecked") @@ -396,7 +403,10 @@ public abstract class ResponseContext public static final Key QUERY_TOTAL_BYTES_GATHERED = new AbstractKey( "queryTotalBytesGathered", false, false, - new TypeReference() {}) + new TypeReference() + { + } + ) { @Override public Object mergeValues(Object oldValue, Object newValue) @@ -410,7 +420,8 @@ public abstract class ResponseContext */ public static final Key QUERY_FAIL_DEADLINE_MILLIS = new LongKey( "queryFailTime", - false); + false + ); /** * This variable indicates when a running query should be expired, @@ -418,17 +429,19 @@ public abstract class ResponseContext */ public static final Key TIMEOUT_AT = new LongKey( "timeoutAt", - false); + false + ); /** * The number of rows scanned by {@link org.apache.druid.query.scan.ScanQueryEngine}. - * + *

    * Named "count" for backwards compatibility with older data servers that still send this, even though it's now * marked as internal. */ public static final Key NUM_SCANNED_ROWS = new CounterKey( "count", - false); + false + ); /** * The total CPU time for threads related to Sequence processing of the query. @@ -437,14 +450,16 @@ public abstract class ResponseContext */ public static final Key CPU_CONSUMED_NANOS = new CounterKey( "cpuConsumed", - false); + false + ); /** * Indicates if a {@link ResponseContext} was truncated during serialization. */ public static final Key TRUNCATED = new BooleanKey( "truncated", - false); + false + ); /** * One and only global list of keys. This is a semi-constant: it is mutable @@ -461,20 +476,21 @@ public abstract class ResponseContext private final ConcurrentMap registeredKeys = new ConcurrentSkipListMap<>(); static { - instance().registerKeys(new Key[] - { - UNCOVERED_INTERVALS, - UNCOVERED_INTERVALS_OVERFLOWED, - REMAINING_RESPONSES_FROM_QUERY_SERVERS, - MISSING_SEGMENTS, - ETAG, - QUERY_TOTAL_BYTES_GATHERED, - QUERY_FAIL_DEADLINE_MILLIS, - TIMEOUT_AT, - NUM_SCANNED_ROWS, - CPU_CONSUMED_NANOS, - TRUNCATED, - }); + instance().registerKeys( + new Key[]{ + UNCOVERED_INTERVALS, + UNCOVERED_INTERVALS_OVERFLOWED, + REMAINING_RESPONSES_FROM_QUERY_SERVERS, + MISSING_SEGMENTS, + ETAG, + QUERY_TOTAL_BYTES_GATHERED, + QUERY_FAIL_DEADLINE_MILLIS, + TIMEOUT_AT, + NUM_SCANNED_ROWS, + CPU_CONSUMED_NANOS, + TRUNCATED, + } + ); } /** @@ -701,8 +717,10 @@ public abstract class ResponseContext public void addRemainingResponse(String id, int count) { - addValue(Keys.REMAINING_RESPONSES_FROM_QUERY_SERVERS, - new NonnullPair<>(id, count)); + addValue( + Keys.REMAINING_RESPONSES_FROM_QUERY_SERVERS, + new NonnullPair<>(id, count) + ); } public void addMissingSegments(List descriptors) @@ -820,7 +838,6 @@ public abstract class ResponseContext * * @param node {@link ArrayNode} which elements are being removed. * @param target the number of chars need to be removed. - * * @return the number of removed chars. */ private static int removeNodeElementsToSatisfyCharsLimit(ArrayNode node, int target) @@ -851,7 +868,7 @@ public abstract class ResponseContext private final String truncatedResult; private final String fullResult; - SerializationResult(@Nullable String truncatedResult, String fullResult) + public SerializationResult(@Nullable String truncatedResult, String fullResult) { this.truncatedResult = truncatedResult; this.fullResult = fullResult; diff --git a/server/src/main/java/org/apache/druid/client/JsonParserIterator.java b/server/src/main/java/org/apache/druid/client/JsonParserIterator.java index 97c772ed192..e14e24c2231 100644 --- a/server/src/main/java/org/apache/druid/client/JsonParserIterator.java +++ b/server/src/main/java/org/apache/druid/client/JsonParserIterator.java @@ -168,7 +168,11 @@ public class JsonParserIterator implements Iterator, Closeable } else if (checkTimeout()) { throw timeoutQuery(); } else { - // TODO: NettyHttpClient should check the actual cause of the failure and set it in the future properly. + // The InputStream is null and we have not timed out, there might be multiple reasons why we could hit this + // condition, guess that we are hitting it because of scatter-gather bytes. It would be better to be more + // explicit about why errors are happening than guessing, but this comment is being rewritten from a T-O-D-O, + // so the intent is just to document this better rather than do all of the logic to fix it. If/when we get + // this exception thrown for other reasons, it would be great to document what other reasons this can happen. throw ResourceLimitExceededException.withMessage( "Possibly max scatter-gather bytes limit reached while reading from url[%s].", url @@ -207,11 +211,11 @@ public class JsonParserIterator implements Iterator, Closeable /** * Converts the given exception to a proper type of {@link QueryException}. * The use cases of this method are: - * + *

    * - All non-QueryExceptions are wrapped with {@link QueryInterruptedException}. * - The QueryException from {@link DirectDruidClient} is converted to a more specific type of QueryException - * based on {@link QueryException#getErrorCode()}. During conversion, {@link QueryException#host} is overridden - * by {@link #host}. + * based on {@link QueryException#getErrorCode()}. During conversion, {@link QueryException#host} is overridden + * by {@link #host}. */ private QueryException convertException(Throwable cause) { @@ -219,9 +223,9 @@ public class JsonParserIterator implements Iterator, Closeable if (cause instanceof QueryException) { final QueryException queryException = (QueryException) cause; if (queryException.getErrorCode() == null) { - // errorCode should not be null now, but maybe could be null in the past.. + // errorCode should not be null now, but maybe could be null in the past... return new QueryInterruptedException( - queryException.getErrorCode(), + QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, queryException.getMessage(), queryException.getErrorClass(), host @@ -229,32 +233,37 @@ public class JsonParserIterator implements Iterator, Closeable } // Note: this switch clause is to restore the 'type' information of QueryExceptions which is lost during - // JSON serialization. This is not a good way to restore the correct exception type. Rather, QueryException - // should store its type when it is serialized, so that we can know the exact type when it is deserialized. + // JSON serialization. As documented on the QueryException class, the errorCode of QueryException is the only + // way to differentiate the cause of the exception. This code does not cover all possible exceptions that + // could come up and so, likely, doesn't produce exceptions reliably. The only safe way to catch and interact + // with a QueryException is to catch QueryException and check its errorCode. In some future code change, we + // should likely remove this switch entirely, but when we do that, we need to make sure to also adjust any + // points in the code that are catching the specific child Exceptions to instead catch QueryException and + // check the errorCode. switch (queryException.getErrorCode()) { // The below is the list of exceptions that can be thrown in historicals and propagated to the broker. - case QueryTimeoutException.ERROR_CODE: + case QueryException.QUERY_TIMEOUT_ERROR_CODE: return new QueryTimeoutException( queryException.getErrorCode(), queryException.getMessage(), queryException.getErrorClass(), host ); - case QueryCapacityExceededException.ERROR_CODE: + case QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE: return new QueryCapacityExceededException( queryException.getErrorCode(), queryException.getMessage(), queryException.getErrorClass(), host ); - case QueryUnsupportedException.ERROR_CODE: + case QueryException.QUERY_UNSUPPORTED_ERROR_CODE: return new QueryUnsupportedException( queryException.getErrorCode(), queryException.getMessage(), queryException.getErrorClass(), host ); - case ResourceLimitExceededException.ERROR_CODE: + case QueryException.RESOURCE_LIMIT_EXCEEDED_ERROR_CODE: return new ResourceLimitExceededException( queryException.getErrorCode(), queryException.getMessage(), diff --git a/server/src/main/java/org/apache/druid/server/QueryResource.java b/server/src/main/java/org/apache/druid/server/QueryResource.java index 743ca9e60ba..84e6acf24c1 100644 --- a/server/src/main/java/org/apache/druid/server/QueryResource.java +++ b/server/src/main/java/org/apache/druid/server/QueryResource.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SequenceWriter; import com.fasterxml.jackson.databind.module.SimpleModule; import com.fasterxml.jackson.datatype.joda.ser.DateTimeSerializer; import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; @@ -30,7 +31,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.Iterables; -import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; import org.apache.druid.client.DirectDruidClient; import org.apache.druid.guice.LazySingleton; @@ -38,20 +38,13 @@ import org.apache.druid.guice.annotations.Json; import org.apache.druid.guice.annotations.Self; import org.apache.druid.guice.annotations.Smile; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.guava.Sequence; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.emitter.EmittingLogger; import org.apache.druid.query.BadJsonQueryException; -import org.apache.druid.query.BadQueryException; import org.apache.druid.query.Query; -import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; -import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryToolChest; -import org.apache.druid.query.QueryUnsupportedException; import org.apache.druid.query.TruncatedResponseContextException; import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext.Keys; @@ -64,7 +57,9 @@ import org.apache.druid.server.security.ForbiddenException; import org.joda.time.DateTime; import javax.annotation.Nullable; +import javax.servlet.AsyncContext; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.POST; @@ -72,12 +67,10 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; -import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.StreamingOutput; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -90,6 +83,8 @@ import java.util.concurrent.atomic.AtomicLong; public class QueryResource implements QueryCountStatsProvider { protected static final EmittingLogger log = new EmittingLogger(QueryResource.class); + public static final EmittingLogger NO_STACK_LOGGER = log.noStackTrace(); + @Deprecated // use SmileMediaTypes.APPLICATION_JACKSON_SMILE protected static final String APPLICATION_SMILE = "application/smile"; @@ -116,6 +111,7 @@ public class QueryResource implements QueryCountStatsProvider private final AtomicLong failedQueryCount = new AtomicLong(); private final AtomicLong interruptedQueryCount = new AtomicLong(); private final AtomicLong timedOutQueryCount = new AtomicLong(); + private final QueryResourceQueryMetricCounter counter = new QueryResourceQueryMetricCounter(); @Inject public QueryResource( @@ -171,23 +167,28 @@ public class QueryResource implements QueryCountStatsProvider @POST @Produces({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE, APPLICATION_SMILE}) + @Nullable public Response doPost( final InputStream in, @QueryParam("pretty") final String pretty, - - // used to get request content-type,Accept header, remote address and auth-related headers @Context final HttpServletRequest req ) throws IOException { final QueryLifecycle queryLifecycle = queryLifecycleFactory.factorize(); - final ResourceIOReaderWriter ioReaderWriter = createResourceIOReaderWriter(req, pretty != null); + final ResourceIOReaderWriter io = createResourceIOReaderWriter(req, pretty != null); final String currThreadName = Thread.currentThread().getName(); try { - final Query query = readQuery(req, in, ioReaderWriter); + final Query query; + try { + query = readQuery(req, in, io); + } + catch (QueryException e) { + return io.getResponseWriter().buildNonOkResponse(e.getFailType().getExpectedStatus(), e); + } + queryLifecycle.initialize(query); - final String queryId = queryLifecycle.getQueryId(); final String queryThreadName = queryLifecycle.threadName(currThreadName); Thread.currentThread().setName(queryThreadName); @@ -195,137 +196,88 @@ public class QueryResource implements QueryCountStatsProvider log.debug("Got query [%s]", queryLifecycle.getQuery()); } - final Access authResult = queryLifecycle.authorize(req); + final Access authResult; + try { + authResult = queryLifecycle.authorize(req); + } + catch (RuntimeException e) { + final QueryException qe; + + if (e instanceof QueryException) { + qe = (QueryException) e; + } else { + qe = new QueryInterruptedException(e); + } + + return io.getResponseWriter().buildNonOkResponse(qe.getFailType().getExpectedStatus(), qe); + } + if (!authResult.isAllowed()) { throw new ForbiddenException(authResult.toString()); } - final QueryResponse queryResponse = queryLifecycle.execute(); - final Sequence results = queryResponse.getResults(); - final ResponseContext responseContext = queryResponse.getResponseContext(); - final String prevEtag = getPreviousEtag(req); - - if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) { - queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1); - successfulQueryCount.incrementAndGet(); - return Response.notModified().build(); - } - - final Yielder yielder = Yielders.each(results); + // We use an async context not because we are actually going to run this async, but because we want to delay + // the decision of what the response code should be until we have gotten the first few data points to return. + // Returning a Response object from this point forward requires that object to know the status code, which we + // don't actually know until we are in the accumulator, but if we try to return a Response object from the + // accumulator, we cannot properly stream results back, because the accumulator won't release control of the + // Response until it has consumed the underlying Sequence. + final AsyncContext asyncContext = req.startAsync(); try { - final ObjectWriter jsonWriter = queryLifecycle.newOutputWriter(ioReaderWriter); - - Response.ResponseBuilder responseBuilder = Response - .ok( - new StreamingOutput() - { - @Override - public void write(OutputStream outputStream) throws WebApplicationException - { - Exception e = null; - - CountingOutputStream os = new CountingOutputStream(outputStream); - try { - // json serializer will always close the yielder - jsonWriter.writeValue(os, yielder); - - os.flush(); // Some types of OutputStream suppress flush errors in the .close() method. - os.close(); - } - catch (Exception ex) { - e = ex; - log.noStackTrace().error(ex, "Unable to send query response."); - throw new RuntimeException(ex); - } - finally { - Thread.currentThread().setName(currThreadName); - - queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), os.getCount()); - - if (e == null) { - successfulQueryCount.incrementAndGet(); - } else { - failedQueryCount.incrementAndGet(); - } - } - } - }, - ioReaderWriter.getResponseWriter().getResponseType() - ) - .header(QUERY_ID_RESPONSE_HEADER, queryId); - - attachResponseContextToHttpResponse(queryId, responseContext, responseBuilder, jsonMapper, - responseContextConfig, selfNode - ); - - return responseBuilder.build(); - } - catch (QueryException e) { - // make sure to close yielder if anything happened before starting to serialize the response. - yielder.close(); - throw e; - } - catch (Exception e) { - // make sure to close yielder if anything happened before starting to serialize the response. - yielder.close(); - throw new RuntimeException(e); + new QueryResourceQueryResultPusher(req, queryLifecycle, io, (HttpServletResponse) asyncContext.getResponse()) + .push(); } finally { - // do not close yielder here, since we do not want to close the yielder prior to - // StreamingOutput having iterated over all the results + asyncContext.complete(); } } - catch (QueryInterruptedException e) { - interruptedQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1); - return ioReaderWriter.getResponseWriter().gotError(e); - } - catch (QueryTimeoutException timeout) { - timedOutQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(timeout, req.getRemoteAddr(), -1); - return ioReaderWriter.getResponseWriter().gotTimeout(timeout); - } - catch (QueryCapacityExceededException cap) { - failedQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(cap, req.getRemoteAddr(), -1); - return ioReaderWriter.getResponseWriter().gotLimited(cap); - } - catch (QueryUnsupportedException unsupported) { - failedQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(unsupported, req.getRemoteAddr(), -1); - return ioReaderWriter.getResponseWriter().gotUnsupported(unsupported); - } - catch (BadQueryException e) { - interruptedQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1); - return ioReaderWriter.getResponseWriter().gotBadQuery(e); - } - catch (ForbiddenException e) { - // don't do anything for an authorization failure, ForbiddenExceptionMapper will catch this later and - // send an error response if this is thrown. - throw e; - } catch (Exception e) { - failedQueryCount.incrementAndGet(); - queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1); + if (e instanceof ForbiddenException && !req.isAsyncStarted()) { + // We can only pass through the Forbidden exception if we haven't started async yet. + throw e; + } + log.warn(e, "Uncaught exception from query processing. This should be caught and handled directly."); - log.noStackTrace() - .makeAlert(e, "Exception handling request") - .addData( - "query", - queryLifecycle.getQuery() != null - ? jsonMapper.writeValueAsString(queryLifecycle.getQuery()) - : "unparseable query" - ) - .addData("peer", req.getRemoteAddr()) - .emit(); - - return ioReaderWriter.getResponseWriter().gotError(e); + // Just fall back to the async context. + AsyncContext asyncContext = req.startAsync(); + try { + final HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + // If the response is committed, we actually processed and started doing things with the request, + // so the best we can do is just complete in the finally and hope for the best. + if (!response.isCommitted()) { + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + response.setContentType(MediaType.APPLICATION_JSON); + try (OutputStream out = response.getOutputStream()) { + final QueryException responseException = new QueryException( + QueryException.UNKNOWN_EXCEPTION_ERROR_CODE, + "Unhandled exception made it to the top", + e.getClass().getName(), + req.getRemoteHost() + ); + out.write(jsonMapper.writeValueAsBytes(responseException)); + } + } + } + finally { + asyncContext.complete(); + } } finally { Thread.currentThread().setName(currThreadName); } + return null; + } + + public interface QueryMetricCounter + { + void incrementSuccess(); + + void incrementFailed(); + + void incrementInterrupted(); + + void incrementTimedOut(); } public static void attachResponseContextToHttpResponse( @@ -416,16 +368,20 @@ public class QueryResource implements QueryCountStatsProvider // response type defaults to Content-Type if 'Accept' header not provided String responseType = Strings.isNullOrEmpty(acceptHeader) ? requestType : acceptHeader; - boolean isRequestSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(requestType) || APPLICATION_SMILE.equals(requestType); - boolean isResponseSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(responseType) || APPLICATION_SMILE.equals(responseType); + boolean isRequestSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(requestType) || APPLICATION_SMILE.equals( + requestType); + boolean isResponseSmile = SmileMediaTypes.APPLICATION_JACKSON_SMILE.equals(responseType) + || APPLICATION_SMILE.equals(responseType); return new ResourceIOReaderWriter( isRequestSmile ? smileMapper : jsonMapper, - new ResourceIOWriter(isResponseSmile ? SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON, - isResponseSmile ? smileMapper : jsonMapper, - isResponseSmile ? serializeDateTimeAsLongSmileMapper : serializeDateTimeAsLongJsonMapper, - pretty - )); + new ResourceIOWriter( + isResponseSmile ? SmileMediaTypes.APPLICATION_JACKSON_SMILE : MediaType.APPLICATION_JSON, + isResponseSmile ? smileMapper : jsonMapper, + isResponseSmile ? serializeDateTimeAsLongSmileMapper : serializeDateTimeAsLongJsonMapper, + pretty + ) + ); } protected static class ResourceIOReaderWriter @@ -504,26 +460,6 @@ public class QueryResource implements QueryCountStatsProvider ); } - Response gotTimeout(QueryTimeoutException e) throws IOException - { - return buildNonOkResponse(QueryTimeoutException.STATUS_CODE, e); - } - - Response gotLimited(QueryCapacityExceededException e) throws IOException - { - return buildNonOkResponse(QueryCapacityExceededException.STATUS_CODE, e); - } - - Response gotUnsupported(QueryUnsupportedException e) throws IOException - { - return buildNonOkResponse(QueryUnsupportedException.STATUS_CODE, e); - } - - Response gotBadQuery(BadQueryException e) throws IOException - { - return buildNonOkResponse(BadQueryException.STATUS_CODE, e); - } - Response buildNonOkResponse(int status, Exception e) throws JsonProcessingException { return Response.status(status) @@ -565,4 +501,142 @@ public class QueryResource implements QueryCountStatsProvider builder.header(HEADER_ETAG, entityTag); } } + + private class QueryResourceQueryMetricCounter implements QueryMetricCounter + { + @Override + public void incrementSuccess() + { + successfulQueryCount.incrementAndGet(); + } + + @Override + public void incrementFailed() + { + failedQueryCount.incrementAndGet(); + } + + @Override + public void incrementInterrupted() + { + interruptedQueryCount.incrementAndGet(); + } + + @Override + public void incrementTimedOut() + { + timedOutQueryCount.incrementAndGet(); + } + } + + private class QueryResourceQueryResultPusher extends QueryResultPusher + { + private final HttpServletRequest req; + private final QueryLifecycle queryLifecycle; + private final ResourceIOReaderWriter io; + + public QueryResourceQueryResultPusher( + HttpServletRequest req, + QueryLifecycle queryLifecycle, + ResourceIOReaderWriter io, + HttpServletResponse response + ) + { + super( + response, + QueryResource.this.jsonMapper, + QueryResource.this.responseContextConfig, + QueryResource.this.selfNode, + QueryResource.this.counter, + queryLifecycle.getQueryId(), + MediaType.valueOf(io.getResponseWriter().getResponseType()) + ); + this.req = req; + this.queryLifecycle = queryLifecycle; + this.io = io; + } + + @Override + public ResultsWriter start() + { + return new ResultsWriter() + { + @Override + public QueryResponse start(HttpServletResponse response) + { + final QueryResponse queryResponse = queryLifecycle.execute(); + final ResponseContext responseContext = queryResponse.getResponseContext(); + final String prevEtag = getPreviousEtag(req); + + if (prevEtag != null && prevEtag.equals(responseContext.getEntityTag())) { + queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), -1); + counter.incrementSuccess(); + response.setStatus(HttpServletResponse.SC_NOT_MODIFIED); + return null; + } + + return queryResponse; + } + + @Override + public Writer makeWriter(OutputStream out) throws IOException + { + final ObjectWriter objectWriter = queryLifecycle.newOutputWriter(io); + final SequenceWriter sequenceWriter = objectWriter.writeValuesAsArray(out); + return new Writer() + { + + @Override + public void writeResponseStart() + { + // Do nothing + } + + @Override + public void writeRow(Object obj) throws IOException + { + sequenceWriter.write(obj); + } + + @Override + public void writeResponseEnd() + { + // Do nothing + } + + @Override + public void close() throws IOException + { + sequenceWriter.close(); + } + }; + } + + @Override + public void recordSuccess(long numBytes) + { + queryLifecycle.emitLogsAndMetrics(null, req.getRemoteAddr(), numBytes); + } + + @Override + public void recordFailure(Exception e) + { + queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1); + } + + @Override + public void close() + { + + } + }; + } + + @Override + public void writeException(Exception e, OutputStream out) throws IOException + { + final ObjectWriter objectWriter = queryLifecycle.newOutputWriter(io); + out.write(objectWriter.writeValueAsBytes(e)); + } + } } diff --git a/server/src/main/java/org/apache/druid/server/QueryResultPusher.java b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java new file mode 100644 index 00000000000..44e1f5d6c23 --- /dev/null +++ b/server/src/main/java/org/apache/druid/server/QueryResultPusher.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.io.CountingOutputStream; +import org.apache.druid.client.DirectDruidClient; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.RE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.guava.Accumulator; +import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.QueryException; +import org.apache.druid.query.QueryInterruptedException; +import org.apache.druid.query.TruncatedResponseContextException; +import org.apache.druid.query.context.ResponseContext; +import org.apache.druid.server.security.ForbiddenException; + +import javax.annotation.Nullable; +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.io.Closeable; +import java.io.IOException; +import java.io.OutputStream; + +public abstract class QueryResultPusher +{ + private static final Logger log = new Logger(QueryResultPusher.class); + + private final HttpServletResponse response; + private final String queryId; + private final ObjectMapper jsonMapper; + private final ResponseContextConfig responseContextConfig; + private final DruidNode selfNode; + private final QueryResource.QueryMetricCounter counter; + private final MediaType contentType; + + private StreamingHttpResponseAccumulator accumulator = null; + + public QueryResultPusher( + HttpServletResponse response, + ObjectMapper jsonMapper, + ResponseContextConfig responseContextConfig, + DruidNode selfNode, + QueryResource.QueryMetricCounter counter, + String queryId, + MediaType contentType + ) + { + this.response = response; + this.queryId = queryId; + this.jsonMapper = jsonMapper; + this.responseContextConfig = responseContextConfig; + this.selfNode = selfNode; + this.counter = counter; + this.contentType = contentType; + } + + /** + * Builds a ResultsWriter to start the lifecycle of the QueryResultPusher. The ResultsWriter encapsulates the logic + * to run the query, serialize it and also report success/failure. + *

    + * This response must not be null. The job of this ResultsWriter is largely to provide lifecycle management to + * the query running and reporting, so this object must never be null. + *

    + * This start() method should do as little work as possible, it should really just make the ResultsWriter and return. + * + * @return a new ResultsWriter + */ + public abstract ResultsWriter start(); + + public abstract void writeException(Exception e, OutputStream out) throws IOException; + + public void push() + { + response.setHeader(QueryResource.QUERY_ID_RESPONSE_HEADER, queryId); + + ResultsWriter resultsWriter = null; + try { + resultsWriter = start(); + + + final QueryResponse queryResponse = resultsWriter.start(response); + if (queryResponse == null) { + // It's already been handled... + return; + } + + final Sequence results = queryResponse.getResults(); + + accumulator = new StreamingHttpResponseAccumulator(queryResponse.getResponseContext(), resultsWriter); + + results.accumulate(null, accumulator); + accumulator.flush(); + + counter.incrementSuccess(); + accumulator.close(); + resultsWriter.recordSuccess(accumulator.getNumBytesSent()); + } + catch (QueryException e) { + handleQueryException(resultsWriter, e); + return; + } + catch (RuntimeException re) { + if (re instanceof ForbiddenException) { + // Forbidden exceptions are special, they get thrown instead of serialized. They happen before the response + // has been committed because the response is committed after results are returned. And, if we started + // returning results before a ForbiddenException gets thrown, that means that we've already leaked stuff + // that should not have been leaked. I.e. it means, we haven't validated the authorization early enough. + if (response.isCommitted()) { + log.error(re, "Got a forbidden exception for query[%s] after the response was already committed.", queryId); + } + throw re; + } + handleQueryException(resultsWriter, new QueryInterruptedException(re)); + return; + } + catch (IOException ioEx) { + handleQueryException(resultsWriter, new QueryInterruptedException(ioEx)); + return; + } + finally { + if (accumulator != null) { + try { + accumulator.close(); + } + catch (IOException e) { + log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId); + } + } + if (resultsWriter == null) { + log.warn("resultsWriter was null for query[%s], work was maybe done in start() that shouldn't be.", queryId); + } else { + try { + resultsWriter.close(); + } + catch (IOException e) { + log.warn(e, "Suppressing exception closing accumulator for query[%s]", queryId); + } + } + } + } + + private void handleQueryException(ResultsWriter resultsWriter, QueryException e) + { + if (accumulator != null && accumulator.isInitialized()) { + // We already started sending a response when we got the error message. In this case we just give up + // and hope that the partial stream generates a meaningful failure message for our client. We could consider + // also throwing the exception body into the response to make it easier for the client to choke if it manages + // to parse a meaningful object out, but that's potentially an API change so we leave that as an exercise for + // the future. + + resultsWriter.recordFailure(e); + + // This case is always a failure because the error happened mid-stream of sending results back. Therefore, + // we do not believe that the response stream was actually useable + counter.incrementFailed(); + return; + } + + if (response.isCommitted()) { + QueryResource.NO_STACK_LOGGER.warn(e, "Response was committed without the accumulator writing anything!?"); + } + + final QueryException.FailType failType = e.getFailType(); + switch (failType) { + case USER_ERROR: + case UNAUTHORIZED: + case QUERY_RUNTIME_FAILURE: + case CANCELED: + counter.incrementInterrupted(); + break; + case CAPACITY_EXCEEDED: + case UNSUPPORTED: + counter.incrementFailed(); + break; + case TIMEOUT: + counter.incrementTimedOut(); + break; + case UNKNOWN: + log.warn( + e, + "Unknown errorCode[%s], support needs to be added for error handling.", + e.getErrorCode() + ); + counter.incrementFailed(); + } + final int responseStatus = failType.getExpectedStatus(); + + response.setStatus(responseStatus); + response.setHeader("Content-Type", contentType.toString()); + try (ServletOutputStream out = response.getOutputStream()) { + writeException(e, out); + } + catch (IOException ioException) { + log.warn( + ioException, + "Suppressing IOException thrown sending error response for query[%s]", + queryId + ); + } + + resultsWriter.recordFailure(e); + } + + public interface ResultsWriter extends Closeable + { + /** + * Runs the query and returns a ResultsWriter from running the query. + *

    + * This also serves as a hook for any logic that runs on the metadata from a QueryResponse. If this method + * returns {@code null} then the Pusher believes that the response was already handled and skips the rest + * of its logic. As such, any implementation that returns null must make sure that the response has been set + * with a meaningful status, etc. + *

    + * Even if this method returns null, close() should still be called on this object. + * + * @return QueryResponse or null if no more work to do. + */ + @Nullable + QueryResponse start(HttpServletResponse response); + + Writer makeWriter(OutputStream out) throws IOException; + + void recordSuccess(long numBytes); + + void recordFailure(Exception e); + } + + public interface Writer extends Closeable + { + /** + * Start of the response, called once per writer. + */ + void writeResponseStart() throws IOException; + + /** + * Write a row + * + * @param obj object representing the row + */ + void writeRow(Object obj) throws IOException; + + /** + * End of the response. Must allow the user to know that they have read all data successfully. + */ + void writeResponseEnd() throws IOException; + } + + public class StreamingHttpResponseAccumulator implements Accumulator, Closeable + { + private final ResponseContext responseContext; + private final ResultsWriter resultsWriter; + + private boolean closed = false; + private boolean initialized = false; + private CountingOutputStream out = null; + private Writer writer = null; + + public StreamingHttpResponseAccumulator( + ResponseContext responseContext, + ResultsWriter resultsWriter + ) + { + this.responseContext = responseContext; + this.resultsWriter = resultsWriter; + } + + public long getNumBytesSent() + { + return out == null ? 0 : out.getCount(); + } + + public boolean isInitialized() + { + return initialized; + } + + /** + * Initializes the response. This is done lazily so that we can put various metadata that we only get once + * we have some of the response stream into the result. + *

    + * This is called once for each result object, but should only actually happen once. + * + * @return boolean if initialization occurred. False most of the team because initialization only happens once. + */ + public void initialize() + { + if (closed) { + throw new ISE("Cannot reinitialize after closing."); + } + + if (!initialized) { + response.setStatus(HttpServletResponse.SC_OK); + + Object entityTag = responseContext.remove(ResponseContext.Keys.ETAG); + if (entityTag != null) { + response.setHeader(QueryResource.HEADER_ETAG, entityTag.toString()); + } + + DirectDruidClient.removeMagicResponseContextFields(responseContext); + + // Limit the response-context header, see https://github.com/apache/druid/issues/2331 + // Note that Response.ResponseBuilder.header(String key,Object value).build() calls value.toString() + // and encodes the string using ASCII, so 1 char is = 1 byte + ResponseContext.SerializationResult serializationResult; + try { + serializationResult = responseContext.serializeWith( + jsonMapper, + responseContextConfig.getMaxResponseContextHeaderSize() + ); + } + catch (JsonProcessingException e) { + QueryResource.log.info(e, "Problem serializing to JSON!?"); + serializationResult = new ResponseContext.SerializationResult("Could not serialize", "Could not serialize"); + } + + if (serializationResult.isTruncated()) { + final String logToPrint = StringUtils.format( + "Response Context truncated for id [%s]. Full context is [%s].", + queryId, + serializationResult.getFullResult() + ); + if (responseContextConfig.shouldFailOnTruncatedResponseContext()) { + QueryResource.log.error(logToPrint); + throw new QueryInterruptedException( + new TruncatedResponseContextException( + "Serialized response context exceeds the max size[%s]", + responseContextConfig.getMaxResponseContextHeaderSize() + ), + selfNode.getHostAndPortToUse() + ); + } else { + QueryResource.log.warn(logToPrint); + } + } + + response.setHeader(QueryResource.HEADER_RESPONSE_CONTEXT, serializationResult.getResult()); + response.setHeader("Content-Type", contentType.toString()); + + try { + out = new CountingOutputStream(response.getOutputStream()); + writer = resultsWriter.makeWriter(out); + } + catch (IOException e) { + throw new RE(e, "Problems setting up response stream for query[%s]!?", queryId); + } + + try { + writer.writeResponseStart(); + } + catch (IOException e) { + throw new RE(e, "Could not start the response for query[%s]!?", queryId); + } + + initialized = true; + } + } + + @Override + public Response accumulate(Response retVal, Object in) + { + if (!initialized) { + initialize(); + } + + try { + writer.writeRow(in); + } + catch (IOException ex) { + QueryResource.NO_STACK_LOGGER.warn(ex, "Unable to write query response."); + throw new RuntimeException(ex); + } + return null; + } + + public void flush() throws IOException + { + if (!initialized) { + initialize(); + } + writer.writeResponseEnd(); + } + + @Override + public void close() throws IOException + { + if (closed) { + return; + } + if (initialized && writer != null) { + writer.close(); + } + closed = true; + } + } +} diff --git a/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java b/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java index 8bd7d9fb23d..5ec56983649 100644 --- a/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java +++ b/server/src/main/java/org/apache/druid/server/security/PreResponseAuthorizationCheckFilter.java @@ -22,6 +22,7 @@ package org.apache.druid.server.security; import com.fasterxml.jackson.databind.ObjectMapper; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.emitter.EmittingLogger; +import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.server.DruidNode; @@ -126,7 +127,7 @@ public class PreResponseAuthorizationCheckFilter implements Filter response.addHeader("WWW-Authenticate", authScheme); } QueryInterruptedException unauthorizedError = new QueryInterruptedException( - QueryInterruptedException.UNAUTHORIZED, + QueryException.UNAUTHORIZED_ERROR_CODE, null, null, DruidNode.getDefaultHost() diff --git a/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java b/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java index b6010c7b94e..f896067646c 100644 --- a/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java +++ b/server/src/test/java/org/apache/druid/client/JsonParserIteratorTest.java @@ -75,7 +75,7 @@ public class JsonParserIteratorTest JAVA_TYPE, Futures.immediateFailedFuture( new QueryException( - QueryTimeoutException.ERROR_CODE, + QueryException.QUERY_TIMEOUT_ERROR_CODE, "timeout exception conversion test", null, HOST diff --git a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java index c6cae9e1e1b..fc739008958 100644 --- a/server/src/test/java/org/apache/druid/server/QueryResourceTest.java +++ b/server/src/test/java/org/apache/druid/server/QueryResourceTest.java @@ -26,6 +26,7 @@ import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.MoreExecutors; import com.google.inject.Injector; @@ -59,6 +60,9 @@ import org.apache.druid.query.timeboundary.TimeBoundaryResultValue; import org.apache.druid.server.initialization.ServerConfig; import org.apache.druid.server.log.TestRequestLogger; import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.apache.druid.server.mocks.ExceptionalInputStream; +import org.apache.druid.server.mocks.MockHttpServletRequest; +import org.apache.druid.server.mocks.MockHttpServletResponse; import org.apache.druid.server.scheduling.HiLoQueryLaningStrategy; import org.apache.druid.server.scheduling.ManualQueryPrioritizationStrategy; import org.apache.druid.server.scheduling.NoQueryLaningStrategy; @@ -73,21 +77,18 @@ import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.server.security.Resource; import org.apache.http.HttpStatus; -import org.easymock.EasyMock; import org.joda.time.Interval; -import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; -import javax.servlet.http.HttpServletRequest; +import javax.annotation.Nonnull; +import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.StreamingOutput; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collection; @@ -95,6 +96,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; public class QueryResourceTest @@ -103,7 +105,7 @@ public class QueryResourceTest private static final AuthenticationResult AUTHENTICATION_RESULT = new AuthenticationResult("druid", "druid", null, null); - private final HttpServletRequest testServletRequest = EasyMock.createMock(HttpServletRequest.class); + private final MockHttpServletRequest testServletRequest = new MockHttpServletRequest(); private static final QuerySegmentWalker TEST_SEGMENT_WALKER = new QuerySegmentWalker() { @@ -200,10 +202,10 @@ public class QueryResourceTest jsonMapper = injector.getInstance(ObjectMapper.class); smileMapper = injector.getInstance(Key.get(ObjectMapper.class, Smile.class)); - EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON).anyTimes(); - EasyMock.expect(testServletRequest.getHeader("Accept")).andReturn(MediaType.APPLICATION_JSON).anyTimes(); - EasyMock.expect(testServletRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); + testServletRequest.contentType = MediaType.APPLICATION_JSON; + testServletRequest.headers.put("Accept", MediaType.APPLICATION_JSON); + testServletRequest.remoteAddr = "localhost"; + queryScheduler = QueryStackTests.DEFAULT_NOOP_SCHEDULER; testRequestLogger = new TestRequestLogger(); queryResource = createQueryResource(ResponseContextConfig.newConfig(true)); @@ -232,23 +234,12 @@ public class QueryResourceTest ); } - @After - public void tearDown() - { - EasyMock.verify(testServletRequest); - } - @Test public void testGoodQuery() throws IOException { expectPermissiveHappyPathAuth(); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest - ); - Assert.assertNotNull(response); + Assert.assertEquals(200, expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY).getStatus()); } @Test @@ -279,28 +270,29 @@ public class QueryResourceTest expectPermissiveHappyPathAuth(); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest - ); - Assert.assertNotNull(response); - - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ((StreamingOutput) response.getEntity()).write(baos); - final List> responses = jsonMapper.readValue( - baos.toByteArray(), - new TypeReference>>() {} - ); - - Assert.assertNotNull(response); + final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + + final List> responses = jsonMapper.readValue( + response.baos.toByteArray(), + new TypeReference>>() + { + } + ); + Assert.assertEquals(0, responses.size()); Assert.assertEquals(1, testRequestLogger.getNativeQuerylogs().size()); Assert.assertNotNull(testRequestLogger.getNativeQuerylogs().get(0).getQuery()); Assert.assertNotNull(testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext()); - Assert.assertTrue(testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().containsKey(overrideConfigKey)); - Assert.assertEquals(overrideConfigValue, testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().get(overrideConfigKey)); + Assert.assertTrue(testRequestLogger.getNativeQuerylogs() + .get(0) + .getQuery() + .getContext() + .containsKey(overrideConfigKey)); + Assert.assertEquals( + overrideConfigValue, + testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().get(overrideConfigKey) + ); } @Test @@ -331,19 +323,13 @@ public class QueryResourceTest expectPermissiveHappyPathAuth(); - Response response = queryResource.doPost( - // SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY context has overrideConfigKey with value of -1 - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest - ); - Assert.assertNotNull(response); + final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY_LOW_PRIORITY); - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ((StreamingOutput) response.getEntity()).write(baos); final List> responses = jsonMapper.readValue( - baos.toByteArray(), - new TypeReference>>() {} + response.baos.toByteArray(), + new TypeReference>>() + { + } ); Assert.assertNotNull(response); @@ -352,23 +338,30 @@ public class QueryResourceTest Assert.assertEquals(1, testRequestLogger.getNativeQuerylogs().size()); Assert.assertNotNull(testRequestLogger.getNativeQuerylogs().get(0).getQuery()); Assert.assertNotNull(testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext()); - Assert.assertTrue(testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().containsKey(overrideConfigKey)); - Assert.assertEquals(-1, testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().get(overrideConfigKey)); + Assert.assertTrue(testRequestLogger.getNativeQuerylogs() + .get(0) + .getQuery() + .getContext() + .containsKey(overrideConfigKey)); + Assert.assertEquals( + -1, + testRequestLogger.getNativeQuerylogs().get(0).getQuery().getContext().get(overrideConfigKey) + ); } @Test public void testTruncatedResponseContextShouldFail() throws IOException { expectPermissiveHappyPathAuth(); + final QueryResource queryResource = createQueryResource(ResponseContextConfig.forTest(true, 0)); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest + MockHttpServletResponse response = expectAsyncRequestFlow( + testServletRequest, + SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8), + queryResource ); Assert.assertEquals(1, queryResource.getInterruptedQueryCount()); - Assert.assertNotNull(response); Assert.assertEquals(HttpStatus.SC_INTERNAL_SERVER_ERROR, response.getStatus()); final String expectedException = new QueryInterruptedException( new TruncatedResponseContextException("Serialized response context exceeds the max size[0]"), @@ -376,7 +369,7 @@ public class QueryResourceTest ).toString(); Assert.assertEquals( expectedException, - jsonMapper.readValue((byte[]) response.getEntity(), QueryInterruptedException.class).toString() + jsonMapper.readValue(response.baos.toByteArray(), QueryInterruptedException.class).toString() ); } @@ -386,223 +379,105 @@ public class QueryResourceTest expectPermissiveHappyPathAuth(); final QueryResource queryResource = createQueryResource(ResponseContextConfig.forTest(false, 0)); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest + final MockHttpServletResponse response = expectAsyncRequestFlow( + testServletRequest, + SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8), + queryResource ); - Assert.assertNotNull(response); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); } @Test public void testGoodQueryWithNullAcceptHeader() throws IOException { - final String acceptHeader = null; - final String contentTypeHeader = MediaType.APPLICATION_JSON; - EasyMock.reset(testServletRequest); + testServletRequest.headers.remove("Accept"); + expectPermissiveHappyPathAuth(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - - EasyMock.expect(testServletRequest.getHeader("Accept")).andReturn(acceptHeader).anyTimes(); - EasyMock.expect(testServletRequest.getContentType()).andReturn(contentTypeHeader).anyTimes(); - EasyMock.expect(testServletRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); - - EasyMock.replay(testServletRequest); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest - ); + final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); //since accept header is null, the response content type should be same as the value of 'Content-Type' header - Assert.assertEquals(contentTypeHeader, (response.getMetadata().get("Content-Type").get(0)).toString()); - Assert.assertNotNull(response); + Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); } @Test public void testGoodQueryWithEmptyAcceptHeader() throws IOException { - final String acceptHeader = ""; - final String contentTypeHeader = MediaType.APPLICATION_JSON; - EasyMock.reset(testServletRequest); + expectPermissiveHappyPathAuth(); + testServletRequest.headers.put("Accept", ""); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); + final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - - EasyMock.expect(testServletRequest.getHeader("Accept")).andReturn(acceptHeader).anyTimes(); - EasyMock.expect(testServletRequest.getContentType()).andReturn(contentTypeHeader).anyTimes(); - EasyMock.expect(testServletRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); - - EasyMock.replay(testServletRequest); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest - ); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); //since accept header is empty, the response content type should be same as the value of 'Content-Type' header - Assert.assertEquals(contentTypeHeader, (response.getMetadata().get("Content-Type").get(0)).toString()); - Assert.assertNotNull(response); + Assert.assertEquals(MediaType.APPLICATION_JSON, Iterables.getOnlyElement(response.headers.get("Content-Type"))); } @Test public void testGoodQueryWithJsonRequestAndSmileAcceptHeader() throws IOException { - //Doing a replay of testServletRequest for teardown to succeed. - //We dont use testServletRequest in this testcase - EasyMock.replay(testServletRequest); - - //Creating our own Smile Servlet request, as to not disturb the remaining tests. - // else refactoring required for this class. i know this kinda makes the class somewhat Dirty. - final HttpServletRequest smileRequest = EasyMock.createMock(HttpServletRequest.class); - - // Set Content-Type to JSON - EasyMock.expect(smileRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - smileRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + expectPermissiveHappyPathAuth(); // Set Accept to Smile - EasyMock.expect(smileRequest.getHeader("Accept")).andReturn(SmileMediaTypes.APPLICATION_JACKSON_SMILE).anyTimes(); - EasyMock.expect(smileRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(smileRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); + testServletRequest.headers.put("Accept", SmileMediaTypes.APPLICATION_JACKSON_SMILE); - EasyMock.replay(smileRequest); - Response response = queryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - smileRequest - ); + final MockHttpServletResponse response = expectAsyncRequestFlow(SIMPLE_TIMESERIES_QUERY); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); // Content-Type in response should be Smile - Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, (response.getMetadata().get("Content-Type").get(0)).toString()); - Assert.assertNotNull(response); - EasyMock.verify(smileRequest); + Assert.assertEquals( + SmileMediaTypes.APPLICATION_JACKSON_SMILE, + Iterables.getOnlyElement(response.headers.get("Content-Type")) + ); } @Test public void testGoodQueryWithSmileRequestAndSmileAcceptHeader() throws IOException { - //Doing a replay of testServletRequest for teardown to succeed. - //We dont use testServletRequest in this testcase - EasyMock.replay(testServletRequest); - - //Creating our own Smile Servlet request, as to not disturb the remaining tests. - // else refactoring required for this class. i know this kinda makes the class somewhat Dirty. - final HttpServletRequest smileRequest = EasyMock.createMock(HttpServletRequest.class); - - // Set Content-Type to Smile - EasyMock.expect(smileRequest.getContentType()).andReturn(SmileMediaTypes.APPLICATION_JACKSON_SMILE).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - smileRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + testServletRequest.contentType = SmileMediaTypes.APPLICATION_JACKSON_SMILE; + expectPermissiveHappyPathAuth(); // Set Accept to Smile - EasyMock.expect(smileRequest.getHeader("Accept")).andReturn(SmileMediaTypes.APPLICATION_JACKSON_SMILE).anyTimes(); - EasyMock.expect(smileRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(smileRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); + testServletRequest.headers.put("Accept", SmileMediaTypes.APPLICATION_JACKSON_SMILE); - EasyMock.replay(smileRequest); - Response response = queryResource.doPost( - // Write input in Smile encoding - new ByteArrayInputStream(smileMapper.writeValueAsBytes(jsonMapper.readTree(SIMPLE_TIMESERIES_QUERY))), - null /*pretty*/, - smileRequest + final MockHttpServletResponse response = expectAsyncRequestFlow( + testServletRequest, + smileMapper.writeValueAsBytes(jsonMapper.readTree( + SIMPLE_TIMESERIES_QUERY)) ); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); // Content-Type in response should be Smile - Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, (response.getMetadata().get("Content-Type").get(0)).toString()); - Assert.assertNotNull(response); - EasyMock.verify(smileRequest); + Assert.assertEquals( + SmileMediaTypes.APPLICATION_JACKSON_SMILE, + Iterables.getOnlyElement(response.headers.get("Content-Type")) + ); } @Test public void testGoodQueryWithSmileRequestNoSmileAcceptHeader() throws IOException { - //Doing a replay of testServletRequest for teardown to succeed. - //We dont use testServletRequest in this testcase - EasyMock.replay(testServletRequest); - - //Creating our own Smile Servlet request, as to not disturb the remaining tests. - // else refactoring required for this class. i know this kinda makes the class somewhat Dirty. - final HttpServletRequest smileRequest = EasyMock.createMock(HttpServletRequest.class); - - // Set Content-Type to Smile - EasyMock.expect(smileRequest.getContentType()).andReturn(SmileMediaTypes.APPLICATION_JACKSON_SMILE).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(smileRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - smileRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + testServletRequest.contentType = SmileMediaTypes.APPLICATION_JACKSON_SMILE; + expectPermissiveHappyPathAuth(); // DO NOT set Accept to Smile, Content-Type in response will be default to Content-Type in request - EasyMock.expect(smileRequest.getHeader("Accept")).andReturn(null).anyTimes(); - EasyMock.expect(smileRequest.getHeader(QueryResource.HEADER_IF_NONE_MATCH)).andReturn(null).anyTimes(); - EasyMock.expect(smileRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); + testServletRequest.headers.remove("Accept"); - EasyMock.replay(smileRequest); - Response response = queryResource.doPost( - // Write input in Smile encoding - new ByteArrayInputStream(smileMapper.writeValueAsBytes(jsonMapper.readTree(SIMPLE_TIMESERIES_QUERY))), - null /*pretty*/, - smileRequest + final MockHttpServletResponse response = expectAsyncRequestFlow( + testServletRequest, + smileMapper.writeValueAsBytes(jsonMapper.readTree(SIMPLE_TIMESERIES_QUERY)) ); Assert.assertEquals(HttpStatus.SC_OK, response.getStatus()); - // Content-Type in response will be default to Content-Type in request - Assert.assertEquals(SmileMediaTypes.APPLICATION_JACKSON_SMILE, (response.getMetadata().get("Content-Type").get(0)).toString()); - Assert.assertNotNull(response); - EasyMock.verify(smileRequest); + // Content-Type in response should default to Content-Type from request + Assert.assertEquals( + SmileMediaTypes.APPLICATION_JACKSON_SMILE, + Iterables.getOnlyElement(response.headers.get("Content-Type")) + ); } @Test public void testBadQuery() throws IOException { - EasyMock.replay(testServletRequest); Response response = queryResource.doPost( new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes(StandardCharsets.UTF_8)), null /*pretty*/, @@ -611,26 +486,22 @@ public class QueryResourceTest Assert.assertNotNull(response); Assert.assertEquals(Status.BAD_REQUEST.getStatusCode(), response.getStatus()); QueryException e = jsonMapper.readValue((byte[]) response.getEntity(), QueryException.class); - Assert.assertEquals(BadJsonQueryException.ERROR_CODE, e.getErrorCode()); + Assert.assertEquals(QueryException.JSON_PARSE_ERROR_CODE, e.getErrorCode()); Assert.assertEquals(BadJsonQueryException.ERROR_CLASS, e.getErrorClass()); } @Test public void testResourceLimitExceeded() throws IOException { - ByteArrayInputStream badQuery = EasyMock.createMock(ByteArrayInputStream.class); - EasyMock.expect(badQuery.read(EasyMock.anyObject(), EasyMock.anyInt(), EasyMock.anyInt())) - .andThrow(new ResourceLimitExceededException("You require too much of something")); - EasyMock.replay(badQuery, testServletRequest); Response response = queryResource.doPost( - badQuery, + new ExceptionalInputStream(() -> new ResourceLimitExceededException("You require too much of something")), null /*pretty*/, testServletRequest ); Assert.assertNotNull(response); Assert.assertEquals(Status.BAD_REQUEST.getStatusCode(), response.getStatus()); QueryException e = jsonMapper.readValue((byte[]) response.getEntity(), QueryException.class); - Assert.assertEquals(ResourceLimitExceededException.ERROR_CODE, e.getErrorCode()); + Assert.assertEquals(QueryException.RESOURCE_LIMIT_EXCEEDED_ERROR_CODE, e.getErrorCode()); Assert.assertEquals(ResourceLimitExceededException.class.getName(), e.getErrorClass()); } @@ -638,13 +509,8 @@ public class QueryResourceTest public void testUnsupportedQueryThrowsException() throws IOException { String errorMessage = "This will be support in Druid 9999"; - ByteArrayInputStream badQuery = EasyMock.createMock(ByteArrayInputStream.class); - EasyMock.expect(badQuery.read(EasyMock.anyObject(), EasyMock.anyInt(), EasyMock.anyInt())).andThrow( - new QueryUnsupportedException(errorMessage)); - EasyMock.replay(badQuery); - EasyMock.replay(testServletRequest); Response response = queryResource.doPost( - badQuery, + new ExceptionalInputStream(() -> new QueryUnsupportedException(errorMessage)), null /*pretty*/, testServletRequest ); @@ -652,28 +518,13 @@ public class QueryResourceTest Assert.assertEquals(QueryUnsupportedException.STATUS_CODE, response.getStatus()); QueryException ex = jsonMapper.readValue((byte[]) response.getEntity(), QueryException.class); Assert.assertEquals(errorMessage, ex.getMessage()); - Assert.assertEquals(QueryUnsupportedException.ERROR_CODE, ex.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, ex.getErrorCode()); } @Test public void testSecuredQuery() throws Exception { - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, false); - EasyMock.expectLastCall().times(1); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().times(1); - - EasyMock.replay(testServletRequest); + expectPermissiveHappyPathAuth(); AuthorizerMapper authMapper = new AuthorizerMapper(null) { @@ -721,27 +572,26 @@ public class QueryResourceTest queryResource.doPost( new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), null /*pretty*/, - testServletRequest + testServletRequest.mimic() ); Assert.fail("doPost did not throw ForbiddenException for an unauthorized query"); } catch (ForbiddenException e) { } - Response response = queryResource.doPost( - new ByteArrayInputStream("{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"}".getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest + final MockHttpServletResponse response = expectAsyncRequestFlow( + "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"}", + testServletRequest.mimic() ); - - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ((StreamingOutput) response.getEntity()).write(baos); - final List> responses = jsonMapper.readValue( - baos.toByteArray(), - new TypeReference>>() {} - ); - Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + + final List> responses = jsonMapper.readValue( + response.baos.toByteArray(), + new TypeReference>>() + { + } + ); + Assert.assertEquals(0, responses.size()); Assert.assertEquals(1, testRequestLogger.getNativeQuerylogs().size()); Assert.assertEquals( @@ -792,22 +642,16 @@ public class QueryResourceTest DRUID_NODE ); expectPermissiveHappyPathAuth(); - Response response = timeoutQueryResource.doPost( - new ByteArrayInputStream(SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8)), - null /*pretty*/, - testServletRequest + + final MockHttpServletResponse response = expectAsyncRequestFlow( + testServletRequest, + SIMPLE_TIMESERIES_QUERY.getBytes(StandardCharsets.UTF_8), + timeoutQueryResource ); - Assert.assertNotNull(response); Assert.assertEquals(QueryTimeoutException.STATUS_CODE, response.getStatus()); - QueryTimeoutException ex; - try { - ex = jsonMapper.readValue((byte[]) response.getEntity(), QueryTimeoutException.class); - } - catch (IOException e) { - throw new RuntimeException(e); - } + QueryTimeoutException ex = jsonMapper.readValue(response.baos.toByteArray(), QueryTimeoutException.class); Assert.assertEquals("Query Timed Out!", ex.getMessage()); - Assert.assertEquals(QueryTimeoutException.ERROR_CODE, ex.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_TIMEOUT_ERROR_CODE, ex.getErrorCode()); Assert.assertEquals(1, timeoutQueryResource.getTimedOutQueryCount()); } @@ -820,19 +664,7 @@ public class QueryResourceTest final CountDownLatch startAwaitLatch = new CountDownLatch(1); final CountDownLatch cancelledCountDownLatch = new CountDownLatch(1); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().times(1); - - EasyMock.replay(testServletRequest); + expectPermissiveHappyPathAuth(); AuthorizerMapper authMapper = new AuthorizerMapper(null) { @@ -858,7 +690,7 @@ public class QueryResourceTest // When the query is cancelled the control will reach here, // countdown the latch and rethrow the exception so that error response is returned for the query cancelledCountDownLatch.countDown(); - throw new RuntimeException(e); + throw new QueryInterruptedException(e); } return new Access(true); } else { @@ -895,26 +727,25 @@ public class QueryResourceTest ObjectMapper mapper = new DefaultObjectMapper(); Query query = mapper.readValue(queryString, Query.class); - ListenableFuture future = MoreExecutors.listeningDecorator( + AtomicReference responseFromEndpoint = new AtomicReference<>(); + + // We expect this future to get canceled so we have to grab the exception somewhere else. + ListenableFuture future = MoreExecutors.listeningDecorator( Execs.singleThreaded("test_query_resource_%s") ).submit( - new Runnable() - { - @Override - public void run() - { - try { - Response response = queryResource.doPost( - new ByteArrayInputStream(queryString.getBytes(StandardCharsets.UTF_8)), - null, - testServletRequest - ); - - Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); - } - catch (IOException e) { - throw new RuntimeException(e); - } + () -> { + try { + responseFromEndpoint.set(queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes(StandardCharsets.UTF_8)), + null, + testServletRequest + )); + return null; + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { waitFinishLatch.countDown(); } } @@ -924,19 +755,19 @@ public class QueryResourceTest startAwaitLatch.await(); Executors.newSingleThreadExecutor().submit( - new Runnable() { - @Override - public void run() - { - Response response = queryResource.cancelQuery("id_1", testServletRequest); - Assert.assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); - waitForCancellationLatch.countDown(); - waitFinishLatch.countDown(); - } + () -> { + Response response = queryResource.cancelQuery("id_1", testServletRequest); + Assert.assertEquals(Status.ACCEPTED.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); } ); waitFinishLatch.await(); cancelledCountDownLatch.await(); + + Assert.assertTrue(future.isCancelled()); + final Response response = responseFromEndpoint.get(); + Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); } @Test(timeout = 60_000L) @@ -946,23 +777,7 @@ public class QueryResourceTest final CountDownLatch waitFinishLatch = new CountDownLatch(2); final CountDownLatch startAwaitLatch = new CountDownLatch(1); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().times(1); - - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, false); - EasyMock.expectLastCall().times(1); - - EasyMock.replay(testServletRequest); + expectPermissiveHappyPathAuth(); AuthorizerMapper authMapper = new AuthorizerMapper(null) { @@ -1019,26 +834,25 @@ public class QueryResourceTest ObjectMapper mapper = new DefaultObjectMapper(); Query query = mapper.readValue(queryString, Query.class); - ListenableFuture future = MoreExecutors.listeningDecorator( + ListenableFuture future = MoreExecutors.listeningDecorator( Execs.singleThreaded("test_query_resource_%s") ).submit( - new Runnable() - { - @Override - public void run() - { - try { - startAwaitLatch.countDown(); - Response response = queryResource.doPost( - new ByteArrayInputStream(queryString.getBytes(StandardCharsets.UTF_8)), - null, - testServletRequest - ); - Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); - } - catch (IOException e) { - throw new RuntimeException(e); - } + () -> { + try { + startAwaitLatch.countDown(); + final MockHttpServletRequest localRequest = testServletRequest.mimic(); + final MockHttpServletResponse retVal = MockHttpServletResponse.forRequest(localRequest); + queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes(StandardCharsets.UTF_8)), + null, + localRequest + ); + return retVal; + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { waitFinishLatch.countDown(); } } @@ -1048,22 +862,19 @@ public class QueryResourceTest startAwaitLatch.await(); Executors.newSingleThreadExecutor().submit( - new Runnable() - { - @Override - public void run() - { - try { - queryResource.cancelQuery("id_1", testServletRequest); - } - catch (ForbiddenException e) { - waitForCancellationLatch.countDown(); - waitFinishLatch.countDown(); - } + () -> { + try { + queryResource.cancelQuery("id_1", testServletRequest.mimic()); + } + catch (ForbiddenException e) { + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); } } ); waitFinishLatch.await(); + + Assert.assertEquals(Response.Status.OK.getStatusCode(), future.get().getStatus()); } @Test(timeout = 10_000L) @@ -1099,13 +910,13 @@ public class QueryResourceTest Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; try { - ex = jsonMapper.readValue((byte[]) response.getEntity(), QueryCapacityExceededException.class); + ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); } Assert.assertEquals(QueryCapacityExceededException.makeTotalErrorMessage(2), ex.getMessage()); - Assert.assertEquals(QueryCapacityExceededException.ERROR_CODE, ex.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, ex.getErrorCode()); } ); waitAllFinished.await(); @@ -1140,7 +951,7 @@ public class QueryResourceTest Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; try { - ex = jsonMapper.readValue((byte[]) response.getEntity(), QueryCapacityExceededException.class); + ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); @@ -1149,7 +960,7 @@ public class QueryResourceTest QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 1), ex.getMessage() ); - Assert.assertEquals(QueryCapacityExceededException.ERROR_CODE, ex.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, ex.getErrorCode()); } ); @@ -1192,7 +1003,7 @@ public class QueryResourceTest Assert.assertEquals(QueryCapacityExceededException.STATUS_CODE, response.getStatus()); QueryCapacityExceededException ex; try { - ex = jsonMapper.readValue((byte[]) response.getEntity(), QueryCapacityExceededException.class); + ex = jsonMapper.readValue(response.baos.toByteArray(), QueryCapacityExceededException.class); } catch (IOException e) { throw new RuntimeException(e); @@ -1201,7 +1012,7 @@ public class QueryResourceTest QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 1), ex.getMessage() ); - Assert.assertEquals(QueryCapacityExceededException.ERROR_CODE, ex.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, ex.getErrorCode()); } ); waitTwoStarted.await(); @@ -1274,16 +1085,15 @@ public class QueryResourceTest ); } - private void assertResponseAndCountdownOrBlockForever(String query, CountDownLatch done, Consumer asserts) + private void assertResponseAndCountdownOrBlockForever( + String query, + CountDownLatch done, + Consumer asserts + ) { Executors.newSingleThreadExecutor().submit(() -> { try { - Response response = queryResource.doPost( - new ByteArrayInputStream(query.getBytes(StandardCharsets.UTF_8)), - null, - testServletRequest - ); - asserts.accept(response); + asserts.accept(expectAsyncRequestFlow(query, testServletRequest.mimic())); } catch (IOException e) { throw new RuntimeException(e); @@ -1294,18 +1104,46 @@ public class QueryResourceTest private void expectPermissiveHappyPathAuth() { - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); + testServletRequest.setAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT, AUTHENTICATION_RESULT); + } - EasyMock.expect(testServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AUTHENTICATION_RESULT) - .anyTimes(); + @Nonnull + private MockHttpServletResponse expectAsyncRequestFlow(String simpleTimeseriesQuery) throws IOException + { + return expectAsyncRequestFlow( + simpleTimeseriesQuery, + testServletRequest + ); + } - testServletRequest.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().anyTimes(); + @Nonnull + private MockHttpServletResponse expectAsyncRequestFlow(String query, MockHttpServletRequest req) throws IOException + { + return expectAsyncRequestFlow(req, query.getBytes(StandardCharsets.UTF_8)); + } - EasyMock.replay(testServletRequest); + @Nonnull + private MockHttpServletResponse expectAsyncRequestFlow( + MockHttpServletRequest req, + byte[] queryBytes + ) throws IOException + { + return expectAsyncRequestFlow(req, queryBytes, queryResource); + } + + @Nonnull + private MockHttpServletResponse expectAsyncRequestFlow( + MockHttpServletRequest req, + byte[] queryBytes, QueryResource queryResource + ) throws IOException + { + final MockHttpServletResponse response = MockHttpServletResponse.forRequest(req); + + Assert.assertNull(queryResource.doPost( + new ByteArrayInputStream(queryBytes), + null /*pretty*/, + req + )); + return response; } } diff --git a/server/src/test/java/org/apache/druid/server/coordinator/LoadQueuePeonTest.java b/server/src/test/java/org/apache/druid/server/coordinator/LoadQueuePeonTest.java index 3929aa64d59..6b0f54d7803 100644 --- a/server/src/test/java/org/apache/druid/server/coordinator/LoadQueuePeonTest.java +++ b/server/src/test/java/org/apache/druid/server/coordinator/LoadQueuePeonTest.java @@ -316,9 +316,11 @@ public class LoadQueuePeonTest extends CuratorTestBase jsonMapper, Execs.scheduledSingleThreaded("test_load_queue_peon_scheduled-%d"), Execs.singleThreaded("test_load_queue_peon-%d"), - // set time-out to 1 ms so that LoadQueuePeon will fail the assignment quickly + // The timeout here was set to 1ms, when this test was acting flakey. A cursory glance makes me wonder if + // there's a race where the timeout actually happens before other code can run. 1ms timeout seems aggressive. + // 100ms is a great price to pay if it removes the flakeyness new TestDruidCoordinatorConfig.Builder() - .withLoadTimeoutDelay(new Duration(1)) + .withLoadTimeoutDelay(new Duration(100)) .withCoordinatorKillMaxSegments(10) .withCoordinatorKillIgnoreDurationToRetain(false) .build() diff --git a/server/src/test/java/org/apache/druid/server/mocks/ExceptionalInputStream.java b/server/src/test/java/org/apache/druid/server/mocks/ExceptionalInputStream.java new file mode 100644 index 00000000000..6d7becbed65 --- /dev/null +++ b/server/src/test/java/org/apache/druid/server/mocks/ExceptionalInputStream.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.mocks; + +import org.apache.druid.java.util.common.RE; + +import java.io.IOException; +import java.io.InputStream; +import java.util.function.Supplier; + +public class ExceptionalInputStream extends InputStream +{ + private final Supplier supplier; + + public ExceptionalInputStream( + Supplier supplier + ) + { + this.supplier = supplier; + } + + @Override + public int read() throws IOException + { + final Exception throwMe = supplier.get(); + if (throwMe instanceof RuntimeException) { + throw (RuntimeException) throwMe; + } + if (throwMe instanceof IOException) { + throw (IOException) throwMe; + } + throw new RE(throwMe, "wrapped because cannot throw typed exception"); + } +} diff --git a/server/src/test/java/org/apache/druid/server/mocks/MockAsyncContext.java b/server/src/test/java/org/apache/druid/server/mocks/MockAsyncContext.java new file mode 100644 index 00000000000..4c74acbcb76 --- /dev/null +++ b/server/src/test/java/org/apache/druid/server/mocks/MockAsyncContext.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.mocks; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncListener; +import javax.servlet.ServletContext; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import java.util.concurrent.atomic.AtomicBoolean; + +public class MockAsyncContext implements AsyncContext +{ + public ServletRequest request; + public ServletResponse response; + + private final AtomicBoolean completed = new AtomicBoolean(); + + @Override + public ServletRequest getRequest() + { + if (request == null) { + throw new UnsupportedOperationException(); + } else { + return request; + } + } + + @Override + public ServletResponse getResponse() + { + if (response == null) { + throw new UnsupportedOperationException(); + } else { + return response; + } + } + + @Override + public boolean hasOriginalRequestAndResponse() + { + throw new UnsupportedOperationException(); + } + + @Override + public void dispatch() + { + throw new UnsupportedOperationException(); + } + + @Override + public void dispatch(String path) + { + throw new UnsupportedOperationException(); + } + + @Override + public void dispatch(ServletContext context, String path) + { + throw new UnsupportedOperationException(); + } + + @Override + public void complete() + { + completed.set(true); + } + + public boolean isCompleted() + { + return completed.get(); + } + + @Override + public void start(Runnable run) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addListener(AsyncListener listener) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addListener( + AsyncListener listener, + ServletRequest servletRequest, + ServletResponse servletResponse + ) + { + throw new UnsupportedOperationException(); + } + + @Override + public T createListener(Class clazz) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setTimeout(long timeout) + { + throw new UnsupportedOperationException(); + } + + @Override + public long getTimeout() + { + throw new UnsupportedOperationException(); + } +} diff --git a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java new file mode 100644 index 00000000000..34a425ea88c --- /dev/null +++ b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletRequest.java @@ -0,0 +1,504 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.mocks; + +import javax.servlet.AsyncContext; +import javax.servlet.DispatcherType; +import javax.servlet.RequestDispatcher; +import javax.servlet.ServletContext; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.Part; +import java.io.BufferedReader; +import java.security.Principal; +import java.util.Collection; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; + +public class MockHttpServletRequest implements HttpServletRequest +{ + public String contentType = null; + public String remoteAddr = null; + + public LinkedHashMap headers = new LinkedHashMap<>(); + public LinkedHashMap attributes = new LinkedHashMap<>(); + + public Supplier asyncContextSupplier; + + private AsyncContext currAsyncContext = null; + + @Override + public String getAuthType() + { + throw new UnsupportedOperationException(); + } + + @Override + public Cookie[] getCookies() + { + throw new UnsupportedOperationException(); + } + + @Override + public long getDateHeader(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public String getHeader(String name) + { + return headers.get(name); + } + + @Override + public Enumeration getHeaders(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public Enumeration getHeaderNames() + { + throw new UnsupportedOperationException(); + } + + @Override + public int getIntHeader(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public String getMethod() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getPathInfo() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getPathTranslated() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getContextPath() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getQueryString() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getRemoteUser() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isUserInRole(String role) + { + throw new UnsupportedOperationException(); + } + + @Override + public Principal getUserPrincipal() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getRequestedSessionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getRequestURI() + { + throw new UnsupportedOperationException(); + } + + @Override + public StringBuffer getRequestURL() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getServletPath() + { + throw new UnsupportedOperationException(); + } + + @Override + public HttpSession getSession(boolean create) + { + throw new UnsupportedOperationException(); + } + + @Override + public HttpSession getSession() + { + throw new UnsupportedOperationException(); + } + + @Override + public String changeSessionId() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRequestedSessionIdValid() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRequestedSessionIdFromCookie() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRequestedSessionIdFromURL() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isRequestedSessionIdFromUrl() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean authenticate(HttpServletResponse response) + { + throw new UnsupportedOperationException(); + } + + @Override + public void login(String username, String password) + { + throw new UnsupportedOperationException(); + } + + @Override + public void logout() + { + throw new UnsupportedOperationException(); + } + + @Override + public Collection getParts() + { + throw new UnsupportedOperationException(); + } + + @Override + public Part getPart(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public T upgrade(Class handlerClass) + { + throw new UnsupportedOperationException(); + } + + @Override + public Object getAttribute(String name) + { + return attributes.get(name); + } + + @Override + public Enumeration getAttributeNames() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getCharacterEncoding() + { + throw new UnsupportedOperationException(); + } + + @Override + public void setCharacterEncoding(String env) + { + throw new UnsupportedOperationException(); + } + + @Override + public int getContentLength() + { + throw new UnsupportedOperationException(); + } + + @Override + public long getContentLengthLong() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getContentType() + { + return contentType; + } + + @Override + public ServletInputStream getInputStream() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getParameter(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public Enumeration getParameterNames() + { + throw new UnsupportedOperationException(); + } + + @Override + public String[] getParameterValues(String name) + { + throw new UnsupportedOperationException(); + } + + @Override + public Map getParameterMap() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getProtocol() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getScheme() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getServerName() + { + throw new UnsupportedOperationException(); + } + + @Override + public int getServerPort() + { + throw new UnsupportedOperationException(); + } + + @Override + public BufferedReader getReader() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getRemoteAddr() + { + return remoteAddr; + } + + @Override + public String getRemoteHost() + { + throw new UnsupportedOperationException(); + } + + @Override + public void setAttribute(String name, Object o) + { + attributes.put(name, o); + } + + @Override + public void removeAttribute(String name) + { + attributes.remove(name); + } + + @Override + public Locale getLocale() + { + throw new UnsupportedOperationException(); + } + + @Override + public Enumeration getLocales() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isSecure() + { + throw new UnsupportedOperationException(); + } + + @Override + public RequestDispatcher getRequestDispatcher(String path) + { + throw new UnsupportedOperationException(); + } + + @Override + public String getRealPath(String path) + { + throw new UnsupportedOperationException(); + } + + @Override + public int getRemotePort() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getLocalName() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getLocalAddr() + { + throw new UnsupportedOperationException(); + } + + @Override + public int getLocalPort() + { + throw new UnsupportedOperationException(); + } + + @Override + public ServletContext getServletContext() + { + throw new UnsupportedOperationException(); + } + + @Override + public AsyncContext startAsync() + { + if (asyncContextSupplier == null) { + throw new UnsupportedOperationException(); + } else { + if (currAsyncContext == null) { + currAsyncContext = asyncContextSupplier.get(); + if (currAsyncContext instanceof MockAsyncContext) { + MockAsyncContext mocked = (MockAsyncContext) currAsyncContext; + if (mocked.request == null) { + mocked.request = this; + } + } + } + return currAsyncContext; + } + } + + @Override + public AsyncContext startAsync( + ServletRequest servletRequest, + ServletResponse servletResponse + ) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isAsyncStarted() + { + return currAsyncContext != null; + } + + @Override + public boolean isAsyncSupported() + { + return true; + } + + @Override + public AsyncContext getAsyncContext() + { + return currAsyncContext; + } + + @Override + public DispatcherType getDispatcherType() + { + throw new UnsupportedOperationException(); + } + + public void newAsyncContext(Supplier supplier) + { + asyncContextSupplier = supplier; + currAsyncContext = null; + } + + public MockHttpServletRequest mimic() + { + MockHttpServletRequest retVal = new MockHttpServletRequest(); + + retVal.asyncContextSupplier = asyncContextSupplier; + retVal.attributes.putAll(attributes); + retVal.headers.putAll(headers); + retVal.contentType = contentType; + retVal.remoteAddr = remoteAddr; + + return retVal; + } +} diff --git a/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java new file mode 100644 index 00000000000..a480c0ef656 --- /dev/null +++ b/server/src/test/java/org/apache/druid/server/mocks/MockHttpServletResponse.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.server.mocks; + +import com.google.common.collect.Multimap; +import com.google.common.collect.Multimaps; + +import javax.annotation.Nullable; +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletResponse; +import java.io.ByteArrayOutputStream; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Locale; + +public class MockHttpServletResponse implements HttpServletResponse +{ + public static MockHttpServletResponse forRequest(MockHttpServletRequest req) + { + MockHttpServletResponse response = new MockHttpServletResponse(); + req.newAsyncContext(() -> { + final MockAsyncContext retVal = new MockAsyncContext(); + retVal.response = response; + return retVal; + }); + return response; + } + + public Multimap headers = Multimaps.newListMultimap( + new LinkedHashMap<>(), ArrayList::new + ); + + public final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + private int statusCode; + private String contentType; + + @Override + public void addCookie(Cookie cookie) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean containsHeader(String name) + { + return headers.containsKey(name); + } + + @Override + public String encodeURL(String url) + { + throw new UnsupportedOperationException(); + } + + @Override + public String encodeRedirectURL(String url) + { + throw new UnsupportedOperationException(); + } + + @Override + public String encodeUrl(String url) + { + throw new UnsupportedOperationException(); + } + + @Override + public String encodeRedirectUrl(String url) + { + throw new UnsupportedOperationException(); + } + + @Override + public void sendError(int sc, String msg) + { + throw new UnsupportedOperationException(); + } + + @Override + public void sendError(int sc) + { + throw new UnsupportedOperationException(); + } + + @Override + public void sendRedirect(String location) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setDateHeader(String name, long date) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addDateHeader(String name, long date) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setHeader(String name, String value) + { + headers.put(name, value); + } + + @Override + public void addHeader(String name, String value) + { + headers.put(name, value); + } + + @Override + public void setIntHeader(String name, int value) + { + throw new UnsupportedOperationException(); + } + + @Override + public void addIntHeader(String name, int value) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setStatus(int sc) + { + statusCode = sc; + } + + @Override + public void setStatus(int sc, String sm) + { + throw new UnsupportedOperationException(); + } + + @Override + public int getStatus() + { + return statusCode; + } + + @Nullable + @Override + public String getHeader(String name) + { + final Collection vals = headers.get(name); + if (vals == null || vals.isEmpty()) { + return null; + } + return vals.iterator().next(); + } + + @Override + public Collection getHeaders(String name) + { + return headers.get(name); + } + + @Override + public Collection getHeaderNames() + { + return headers.keySet(); + } + + @Override + public String getCharacterEncoding() + { + throw new UnsupportedOperationException(); + } + + @Override + public String getContentType() + { + return contentType; + } + + @Override + public ServletOutputStream getOutputStream() + { + return new ServletOutputStream() + { + @Override + public boolean isReady() + { + return true; + } + + @Override + public void setWriteListener(WriteListener writeListener) + { + throw new UnsupportedOperationException(); + } + + @Override + public void write(int b) + { + baos.write(b); + } + + @Override + public void write(byte[] b) + { + baos.write(b, 0, b.length); + } + + @Override + public void write(byte[] b, int off, int len) + { + baos.write(b, off, len); + } + }; + } + + @Override + public PrintWriter getWriter() + { + throw new UnsupportedOperationException(); + } + + @Override + public void setCharacterEncoding(String charset) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setContentLength(int len) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setContentLengthLong(long len) + { + throw new UnsupportedOperationException(); + } + + @Override + public void setContentType(String type) + { + this.contentType = type; + } + + @Override + public void setBufferSize(int size) + { + throw new UnsupportedOperationException(); + } + + @Override + public int getBufferSize() + { + throw new UnsupportedOperationException(); + } + + @Override + public void flushBuffer() + { + throw new UnsupportedOperationException(); + } + + @Override + public void resetBuffer() + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isCommitted() + { + return baos.size() > 0; + } + + @Override + public void reset() + { + throw new UnsupportedOperationException(); + } + + @Override + public void setLocale(Locale loc) + { + throw new UnsupportedOperationException(); + } + + @Override + public Locale getLocale() + { + throw new UnsupportedOperationException(); + } +} diff --git a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java index 8540a4ca4ca..6facaa54778 100644 --- a/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java +++ b/services/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java @@ -54,7 +54,6 @@ import org.apache.druid.query.Druids; import org.apache.druid.query.MapQueryToolChestWarehouse; import org.apache.druid.query.Query; import org.apache.druid.query.QueryException; -import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.timeseries.TimeseriesQuery; import org.apache.druid.segment.TestHelper; import org.apache.druid.server.initialization.BaseJettyTest; @@ -289,7 +288,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertEquals(errorMessage, captor.getValue().getMessage()); Assert.assertEquals(IllegalStateException.class.getName(), ((QueryException) captor.getValue()).getErrorClass()); } @@ -314,7 +313,8 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest new DefaultGenericQueryMetricsFactory(), new AuthenticatorMapper(ImmutableMap.of()), new Properties(), - new ServerConfig() { + new ServerConfig() + { @Override public boolean isShowDetailedJettyErrors() { @@ -333,7 +333,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertNull(captor.getValue().getMessage()); Assert.assertNull(((QueryException) captor.getValue()).getErrorClass()); Assert.assertNull(((QueryException) captor.getValue()).getHost()); @@ -359,7 +359,8 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest new DefaultGenericQueryMetricsFactory(), new AuthenticatorMapper(ImmutableMap.of()), new Properties(), - new ServerConfig() { + new ServerConfig() + { @Override public boolean isShowDetailedJettyErrors() { @@ -378,7 +379,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertEquals(errorMessage, captor.getValue().getMessage()); Assert.assertNull(((QueryException) captor.getValue()).getErrorClass()); Assert.assertNull(((QueryException) captor.getValue()).getHost()); @@ -412,7 +413,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertEquals(errorMessage, captor.getValue().getMessage()); Assert.assertEquals(IOException.class.getName(), ((QueryException) captor.getValue()).getErrorClass()); } @@ -438,7 +439,8 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest new DefaultGenericQueryMetricsFactory(), new AuthenticatorMapper(ImmutableMap.of()), new Properties(), - new ServerConfig() { + new ServerConfig() + { @Override public boolean isShowDetailedJettyErrors() { @@ -457,7 +459,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertNull(captor.getValue().getMessage()); Assert.assertNull(((QueryException) captor.getValue()).getErrorClass()); Assert.assertNull(((QueryException) captor.getValue()).getHost()); @@ -484,7 +486,8 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest new DefaultGenericQueryMetricsFactory(), new AuthenticatorMapper(ImmutableMap.of()), new Properties(), - new ServerConfig() { + new ServerConfig() + { @Override public boolean isShowDetailedJettyErrors() { @@ -503,7 +506,7 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); Mockito.verify(mockMapper).writeValue(ArgumentMatchers.eq(outputStream), captor.capture()); Assert.assertTrue(captor.getValue() instanceof QueryException); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, ((QueryException) captor.getValue()).getErrorCode()); + Assert.assertEquals("Unknown exception", ((QueryException) captor.getValue()).getErrorCode()); Assert.assertEquals(errorMessage, captor.getValue().getMessage()); Assert.assertNull(((QueryException) captor.getValue()).getErrorClass()); Assert.assertNull(((QueryException) captor.getValue()).getHost()); @@ -749,10 +752,13 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest final HandlerList handlerList = new HandlerList(); handlerList.setHandlers( - new Handler[]{JettyServerInitUtils.wrapWithDefaultGzipHandler( - root, - ServerConfig.DEFAULT_GZIP_INFLATE_BUFFER_SIZE, - Deflater.DEFAULT_COMPRESSION)} + new Handler[]{ + JettyServerInitUtils.wrapWithDefaultGzipHandler( + root, + ServerConfig.DEFAULT_GZIP_INFLATE_BUFFER_SIZE, + Deflater.DEFAULT_COMPRESSION + ) + } ); server.setHandler(handlerList); } diff --git a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java index ad24274ceb0..62830063d21 100644 --- a/sql/src/main/java/org/apache/druid/sql/DirectStatement.java +++ b/sql/src/main/java/org/apache/druid/sql/DirectStatement.java @@ -24,6 +24,7 @@ import org.apache.calcite.tools.ValidationException; import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.query.QueryException; import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.server.QueryResponse; import org.apache.druid.server.security.ResourceAction; @@ -271,7 +272,7 @@ public class DirectStatement extends AbstractStatement implements Cancelable { if (state == State.CANCELLED) { throw new QueryInterruptedException( - QueryInterruptedException.QUERY_CANCELED, + QueryException.QUERY_CANCELED_ERROR_CODE, StringUtils.format("Query is canceled [%s]", sqlQueryId()), null, null diff --git a/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java b/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java index fb4e4f439b1..60a2b89bf8f 100644 --- a/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java +++ b/sql/src/main/java/org/apache/druid/sql/SqlPlanningException.java @@ -32,11 +32,12 @@ import org.apache.druid.query.BadQueryException; */ public class SqlPlanningException extends BadQueryException { + public enum PlanningError { - SQL_PARSE_ERROR("SQL parse failed", SqlParseException.class.getName()), - VALIDATION_ERROR("Plan validation failed", ValidationException.class.getName()), - UNSUPPORTED_SQL_ERROR("SQL query is unsupported", RelOptPlanner.CannotPlanException.class.getName()); + SQL_PARSE_ERROR(SQL_PARSE_FAILED_ERROR_CODE, SqlParseException.class.getName()), + VALIDATION_ERROR(PLAN_VALIDATION_FAILED_ERROR_CODE, ValidationException.class.getName()), + UNSUPPORTED_SQL_ERROR(SQL_QUERY_UNSUPPORTED_ERROR_CODE, RelOptPlanner.CannotPlanException.class.getName()); private final String errorCode; private final String errorClass; diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index 968a722caa5..c98b091b9c7 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -19,58 +19,48 @@ package org.apache.druid.sql.http; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Preconditions; -import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; import org.apache.calcite.plan.RelOptPlanner; import org.apache.druid.common.exception.SanitizableException; import org.apache.druid.guice.annotations.NativeQuery; import org.apache.druid.guice.annotations.Self; import org.apache.druid.java.util.common.StringUtils; -import org.apache.druid.java.util.common.guava.Yielder; -import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.logger.Logger; -import org.apache.druid.query.BadQueryException; -import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryInterruptedException; -import org.apache.druid.query.QueryTimeoutException; -import org.apache.druid.query.QueryUnsupportedException; import org.apache.druid.server.DruidNode; import org.apache.druid.server.QueryResource; import org.apache.druid.server.QueryResponse; +import org.apache.druid.server.QueryResultPusher; import org.apache.druid.server.ResponseContextConfig; import org.apache.druid.server.initialization.ServerConfig; import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthorizationUtils; import org.apache.druid.server.security.AuthorizerMapper; -import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.server.security.ResourceAction; import org.apache.druid.sql.DirectStatement.ResultSet; import org.apache.druid.sql.HttpStatement; -import org.apache.druid.sql.SqlExecutionReporter; import org.apache.druid.sql.SqlLifecycleManager; import org.apache.druid.sql.SqlLifecycleManager.Cancelable; import org.apache.druid.sql.SqlPlanningException; import org.apache.druid.sql.SqlRowTransformer; import org.apache.druid.sql.SqlStatementFactory; -import org.apache.druid.utils.CloseableUtils; import javax.annotation.Nullable; +import javax.servlet.AsyncContext; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; -import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.StreamingOutput; import java.io.IOException; import java.io.OutputStream; import java.util.List; @@ -84,6 +74,7 @@ public class SqlResource public static final String SQL_HEADER_RESPONSE_HEADER = "X-Druid-SQL-Header-Included"; public static final String SQL_HEADER_VALUE = "yes"; private static final Logger log = new Logger(SqlResource.class); + public static final SqlResourceQueryMetricCounter QUERY_METRIC_COUNTER = new SqlResourceQueryMetricCounter(); private final ObjectMapper jsonMapper; private final AuthorizerMapper authorizerMapper; @@ -117,10 +108,11 @@ public class SqlResource @POST @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) + @Nullable public Response doPost( final SqlQuery sqlQuery, @Context final HttpServletRequest req - ) throws IOException + ) { final HttpStatement stmt = sqlStatementFactory.httpStatement(sqlQuery, req); final String sqlQueryId = stmt.sqlQueryId(); @@ -128,160 +120,29 @@ public class SqlResource try { Thread.currentThread().setName(StringUtils.format("sql[%s]", sqlQueryId)); - ResultSet resultSet = stmt.plan(); - final QueryResponse response = resultSet.run(); - final SqlRowTransformer rowTransformer = resultSet.createRowTransformer(); - final Yielder finalYielder = Yielders.each(response.getResults()); - final Response.ResponseBuilder responseBuilder = Response - .ok( - new StreamingOutput() - { - @Override - public void write(OutputStream output) throws IOException, WebApplicationException - { - Exception e = null; - CountingOutputStream os = new CountingOutputStream(output); - Yielder yielder = finalYielder; + // We use an async context not because we are actually going to run this async, but because we want to delay + // the decision of what the response code should be until we have gotten the first few data points to return. + // Returning a Response object from this point forward requires that object to know the status code, which we + // don't actually know until we are in the accumulator, but if we try to return a Response object from the + // accumulator, we cannot properly stream results back, because the accumulator won't release control of the + // Response until it has consumed the underlying Sequence. + final AsyncContext asyncContext = req.startAsync(); - try (final ResultFormat.Writer writer = sqlQuery.getResultFormat() - .createFormatter(os, jsonMapper)) { - writer.writeResponseStart(); - - if (sqlQuery.includeHeader()) { - writer.writeHeader( - rowTransformer.getRowType(), - sqlQuery.includeTypesHeader(), - sqlQuery.includeSqlTypesHeader() - ); - } - - while (!yielder.isDone()) { - final Object[] row = yielder.get(); - writer.writeRowStart(); - for (int i = 0; i < rowTransformer.getFieldList().size(); i++) { - final Object value = rowTransformer.transform(row, i); - writer.writeRowField(rowTransformer.getFieldList().get(i), value); - } - writer.writeRowEnd(); - yielder = yielder.next(null); - } - - writer.writeResponseEnd(); - } - catch (Exception ex) { - e = ex; - log.error(ex, "Unable to send SQL response [%s]", sqlQueryId); - throw new RuntimeException(ex); - } - finally { - final Exception finalE = e; - - CloseableUtils.closeAll( - yielder, - () -> endLifecycle(stmt, finalE, os.getCount()) - ); - } - } - } - ) - .header(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId); - - if (sqlQuery.includeHeader()) { - responseBuilder.header(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); + try { + QueryResultPusher pusher = new SqlResourceQueryResultPusher(asyncContext, sqlQueryId, stmt, sqlQuery); + pusher.push(); + return null; + } + finally { + asyncContext.complete(); } - - QueryResource.attachResponseContextToHttpResponse( - sqlQueryId, - response.getResponseContext(), - responseBuilder, - jsonMapper, - responseContextConfig, - selfNode - ); - - return responseBuilder.build(); - } - catch (QueryCapacityExceededException cap) { - endLifecycle(stmt, cap, -1); - return buildNonOkResponse(QueryCapacityExceededException.STATUS_CODE, cap, sqlQueryId); - } - catch (QueryUnsupportedException unsupported) { - endLifecycle(stmt, unsupported, -1); - return buildNonOkResponse(QueryUnsupportedException.STATUS_CODE, unsupported, sqlQueryId); - } - catch (QueryTimeoutException timeout) { - endLifecycle(stmt, timeout, -1); - return buildNonOkResponse(QueryTimeoutException.STATUS_CODE, timeout, sqlQueryId); - } - catch (BadQueryException e) { - endLifecycle(stmt, e, -1); - return buildNonOkResponse(BadQueryException.STATUS_CODE, e, sqlQueryId); - } - catch (ForbiddenException e) { - endLifecycleWithoutEmittingMetrics(stmt); - throw (ForbiddenException) serverConfig.getErrorResponseTransformStrategy() - .transformIfNeeded(e); // let ForbiddenExceptionMapper handle this - } - catch (RelOptPlanner.CannotPlanException e) { - endLifecycle(stmt, e, -1); - SqlPlanningException spe = new SqlPlanningException(SqlPlanningException.PlanningError.UNSUPPORTED_SQL_ERROR, - e.getMessage()); - return buildNonOkResponse(BadQueryException.STATUS_CODE, spe, sqlQueryId); - } - // Calcite throws a java.lang.AssertionError which is type error not exception. - // Using throwable will catch all. - catch (Throwable e) { - log.warn(e, "Failed to handle query: %s", sqlQuery); - endLifecycle(stmt, e, -1); - - return buildNonOkResponse( - Status.INTERNAL_SERVER_ERROR.getStatusCode(), - QueryInterruptedException.wrapIfNeeded(e), - sqlQueryId - ); } finally { Thread.currentThread().setName(currThreadName); } } - private void endLifecycleWithoutEmittingMetrics( - HttpStatement stmt - ) - { - stmt.closeQuietly(); - } - - private void endLifecycle( - HttpStatement stmt, - @Nullable final Throwable e, - final long bytesWritten - ) - { - SqlExecutionReporter reporter = stmt.reporter(); - if (e == null) { - reporter.succeeded(bytesWritten); - } else { - reporter.failed(e); - } - stmt.close(); - } - - private Response buildNonOkResponse(int status, SanitizableException e, String sqlQueryId) - throws JsonProcessingException - { - return Response.status(status) - .type(MediaType.APPLICATION_JSON_TYPE) - .entity( - jsonMapper.writeValueAsBytes( - serverConfig.getErrorResponseTransformStrategy().transformIfNeeded(e) - ) - ) - .header(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId) - .build(); - } - @DELETE @Path("{id}") @Produces(MediaType.APPLICATION_JSON) @@ -320,4 +181,187 @@ public class SqlResource return Response.status(Status.FORBIDDEN).build(); } } + + /** + * The SqlResource only generates metrics and doesn't keep track of aggregate counts of successful/failed/interrupted + * queries, so this implementation is effectively just a noop. + */ + private static class SqlResourceQueryMetricCounter implements QueryResource.QueryMetricCounter + { + @Override + public void incrementSuccess() + { + + } + + @Override + public void incrementFailed() + { + + } + + @Override + public void incrementInterrupted() + { + + } + + @Override + public void incrementTimedOut() + { + + } + } + + private class SqlResourceQueryResultPusher extends QueryResultPusher + { + private final String sqlQueryId; + private final HttpStatement stmt; + private final SqlQuery sqlQuery; + + public SqlResourceQueryResultPusher( + AsyncContext asyncContext, + String sqlQueryId, + HttpStatement stmt, + SqlQuery sqlQuery + ) + { + super( + (HttpServletResponse) asyncContext.getResponse(), + SqlResource.this.jsonMapper, + SqlResource.this.responseContextConfig, + SqlResource.this.selfNode, + SqlResource.QUERY_METRIC_COUNTER, + sqlQueryId, + MediaType.APPLICATION_JSON_TYPE + ); + this.sqlQueryId = sqlQueryId; + this.stmt = stmt; + this.sqlQuery = sqlQuery; + } + + @Override + public ResultsWriter start() + { + return new ResultsWriter() + { + private ResultSet thePlan; + + @Override + @Nullable + @SuppressWarnings({"unchecked", "rawtypes"}) + public QueryResponse start(HttpServletResponse response) + { + response.setHeader(SQL_QUERY_ID_RESPONSE_HEADER, sqlQueryId); + + final QueryResponse retVal; + try { + thePlan = stmt.plan(); + retVal = thePlan.run(); + } + catch (RelOptPlanner.CannotPlanException e) { + throw new SqlPlanningException( + SqlPlanningException.PlanningError.UNSUPPORTED_SQL_ERROR, + e.getMessage() + ); + } + // There is a claim that Calcite sometimes throws a java.lang.AssertionError, but we do not have a test that can + // reproduce it checked into the code (the best we have is something that uses mocks to throw an Error, which is + // dubious at best). We keep this just in case, but it might be best to remove it and see where the + // AssertionErrors are coming from and do something to ensure that they don't actually make it out of Calcite + catch (AssertionError e) { + log.warn(e, "AssertionError killed query: %s", sqlQuery); + + // We wrap the exception here so that we get the sanitization. java.lang.AssertionError apparently + // doesn't implement org.apache.druid.common.exception.SanitizableException. + throw new QueryInterruptedException(e); + } + + if (sqlQuery.includeHeader()) { + response.setHeader(SQL_HEADER_RESPONSE_HEADER, SQL_HEADER_VALUE); + } + + return (QueryResponse) retVal; + } + + @Override + public Writer makeWriter(OutputStream out) throws IOException + { + ResultFormat.Writer writer = sqlQuery.getResultFormat().createFormatter(out, jsonMapper); + final SqlRowTransformer rowTransformer = thePlan.createRowTransformer(); + + return new Writer() + { + + @Override + public void writeResponseStart() throws IOException + { + writer.writeResponseStart(); + + if (sqlQuery.includeHeader()) { + writer.writeHeader( + rowTransformer.getRowType(), + sqlQuery.includeTypesHeader(), + sqlQuery.includeSqlTypesHeader() + ); + } + } + + @Override + public void writeRow(Object obj) throws IOException + { + Object[] row = (Object[]) obj; + + writer.writeRowStart(); + for (int i = 0; i < rowTransformer.getFieldList().size(); i++) { + final Object value = rowTransformer.transform(row, i); + writer.writeRowField(rowTransformer.getFieldList().get(i), value); + } + writer.writeRowEnd(); + } + + @Override + public void writeResponseEnd() throws IOException + { + writer.writeResponseEnd(); + } + + @Override + public void close() throws IOException + { + writer.close(); + } + }; + } + + @Override + public void recordSuccess(long numBytes) + { + stmt.reporter().succeeded(numBytes); + } + + @Override + public void recordFailure(Exception e) + { + stmt.reporter().failed(e); + } + + @Override + public void close() + { + stmt.close(); + } + }; + } + + @Override + public void writeException(Exception ex, OutputStream out) throws IOException + { + if (ex instanceof SanitizableException) { + ex = serverConfig.getErrorResponseTransformStrategy().transformIfNeeded((SanitizableException) ex); + } + out.write(jsonMapper.writeValueAsBytes(ex)); + } + + } } diff --git a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java index 42fa66cbaa7..cb9dd0e07aa 100644 --- a/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java +++ b/sql/src/test/java/org/apache/druid/sql/http/SqlResourceTest.java @@ -26,11 +26,11 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.common.util.concurrent.MoreExecutors; import org.apache.calcite.avatica.SqlType; -import org.apache.commons.io.output.NullOutputStream; import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.exception.AllowedRegexErrorResponseTransformStrategy; import org.apache.druid.common.exception.ErrorResponseTransformStrategy; @@ -43,6 +43,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.concurrent.Execs; import org.apache.druid.java.util.common.guava.LazySequence; import org.apache.druid.java.util.common.guava.Sequence; +import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.metrics.StubServiceEmitter; import org.apache.druid.math.expr.ExprMacroTable; @@ -53,7 +54,6 @@ import org.apache.druid.query.Query; import org.apache.druid.query.QueryCapacityExceededException; import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryException; -import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryRunnerFactoryConglomerate; import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryUnsupportedException; @@ -67,6 +67,8 @@ import org.apache.druid.server.QueryStackTests; import org.apache.druid.server.ResponseContextConfig; import org.apache.druid.server.initialization.ServerConfig; import org.apache.druid.server.log.TestRequestLogger; +import org.apache.druid.server.mocks.MockHttpServletRequest; +import org.apache.druid.server.mocks.MockHttpServletResponse; import org.apache.druid.server.scheduling.HiLoQueryLaningStrategy; import org.apache.druid.server.scheduling.ManualQueryPrioritizationStrategy; import org.apache.druid.server.security.Access; @@ -97,7 +99,6 @@ import org.apache.druid.sql.calcite.util.CalciteTestBase; import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.QueryLogHook; import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker; -import org.easymock.EasyMock; import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import org.junit.After; @@ -107,13 +108,12 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import javax.annotation.Nonnull; import javax.servlet.http.HttpServletRequest; -import javax.ws.rs.core.MultivaluedMap; import javax.ws.rs.core.Response; import javax.ws.rs.core.Response.Status; -import javax.ws.rs.core.StreamingOutput; -import java.io.ByteArrayOutputStream; import java.nio.charset.StandardCharsets; +import java.util.AbstractList; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -160,7 +160,7 @@ public class SqlResourceTest extends CalciteTestBase private SpecificSegmentsQuerySegmentWalker walker; private TestRequestLogger testRequestLogger; private SqlResource resource; - private HttpServletRequest req; + private MockHttpServletRequest req; private ListeningExecutorService executorService; private SqlLifecycleManager lifecycleManager; private NativeSqlEngine engine; @@ -223,7 +223,7 @@ public class SqlResourceTest extends CalciteTestBase final DruidOperatorTable operatorTable = CalciteTests.createOperatorTable(); final ExprMacroTable macroTable = CalciteTests.createExprMacroTable(); - req = request(true); + req = request(); testRequestLogger = new TestRequestLogger(); @@ -310,9 +310,9 @@ public class SqlResourceTest extends CalciteTestBase ); } - HttpServletRequest request(boolean ok) + MockHttpServletRequest request() { - return makeExpectedReq(CalciteTests.REGULAR_USER_AUTH_RESULT, ok); + return makeExpectedReq(CalciteTests.REGULAR_USER_AUTH_RESULT); } @After @@ -326,21 +326,19 @@ public class SqlResourceTest extends CalciteTestBase } @Test - public void testUnauthorized() throws Exception + public void testUnauthorized() { - HttpServletRequest testRequest = request(false); - try { - resource.doPost( + postForResponse( createSimpleQueryWithId("id", "select count(*) from forbiddenDatasource"), - testRequest + request() ); Assert.fail("doPost did not throw ForbiddenException for an unauthorized query"); } catch (ForbiddenException e) { // expected } - Assert.assertEquals(0, testRequestLogger.getSqlQueryLogs().size()); + Assert.assertEquals(1, testRequestLogger.getSqlQueryLogs().size()); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -382,10 +380,10 @@ public class SqlResourceTest extends CalciteTestBase mockRespContext.put(ResponseContext.Keys.instance().keyOf("uncoveredIntervalsOverflowed"), "true"); responseContextSupplier.set(mockRespContext); - final Response response = resource.doPost(sqlQuery, makeRegularUserReq()); + final MockHttpServletResponse response = postForResponse(sqlQuery, makeRegularUserReq()); Map responseContext = JSON_MAPPER.readValue( - (String) response.getMetadata().getFirst("X-Druid-Response-Context"), + Iterables.getOnlyElement(response.headers.get("X-Druid-Response-Context")), Map.class ); Assert.assertEquals( @@ -396,9 +394,7 @@ public class SqlResourceTest extends CalciteTestBase responseContext ); - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ((StreamingOutput) response.getEntity()).write(baos); - Object results = JSON_MAPPER.readValue(baos.toByteArray(), Object.class); + Object results = JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class); Assert.assertEquals( ImmutableList.of( @@ -648,7 +644,7 @@ public class SqlResourceTest extends CalciteTestBase } @Test - public void testArrayResultFormatWithErrorAfterFirstRow() throws Exception + public void testArrayResultFormatWithErrorAfterSecondRow() throws Exception { sequenceMapFnSupplier.set(errorAfterSecondRowMapFn()); @@ -723,41 +719,100 @@ public class SqlResourceTest extends CalciteTestBase final String query = "SELECT *, CASE dim2 WHEN '' THEN dim2 END FROM foo LIMIT 2"; final String nullStr = NullHandling.replaceWithDefault() ? "" : null; - Assert.assertEquals( - ImmutableList.of( - EXPECTED_COLUMNS_FOR_RESULT_FORMAT_TESTS, - EXPECTED_TYPES_FOR_RESULT_FORMAT_TESTS, - EXPECTED_SQL_TYPES_FOR_RESULT_FORMAT_TESTS, - Arrays.asList( - "2000-01-01T00:00:00.000Z", - "", - "a", - "[\"a\",\"b\"]", - 1, - 1.0, - 1.0, - "org.apache.druid.hll.VersionOneHyperLogLogCollector", - nullStr - ), - Arrays.asList( - "2000-01-02T00:00:00.000Z", - "10.1", - nullStr, - "[\"b\",\"c\"]", - 1, - 2.0, - 2.0, - "org.apache.druid.hll.VersionOneHyperLogLogCollector", - nullStr - ) - ), - doPost( - new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), - new TypeReference>>() - { - } - ).rhs + final String hllStr = "org.apache.druid.hll.VersionOneHyperLogLogCollector"; + List[] expectedQueryResults = new List[]{ + Arrays.asList("2000-01-01T00:00:00.000Z", "", "a", "[\"a\",\"b\"]", 1, 1.0, 1.0, hllStr, nullStr), + Arrays.asList("2000-01-02T00:00:00.000Z", "10.1", nullStr, "[\"b\",\"c\"]", 1, 2.0, 2.0, hllStr, nullStr) + }; + + MockHttpServletResponse response = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), + req.mimic() ); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals("yes", response.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( + new ArrayList() + { + { + add(EXPECTED_COLUMNS_FOR_RESULT_FORMAT_TESTS); + add(EXPECTED_TYPES_FOR_RESULT_FORMAT_TESTS); + add(EXPECTED_SQL_TYPES_FOR_RESULT_FORMAT_TESTS); + addAll(Arrays.asList(expectedQueryResults)); + } + }, + JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class) + ); + + MockHttpServletResponse responseNoSqlTypesHeader = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, true, true, false, null, null), + req.mimic() + ); + + Assert.assertEquals(200, responseNoSqlTypesHeader.getStatus()); + Assert.assertEquals("yes", responseNoSqlTypesHeader.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( + new ArrayList() + { + { + add(EXPECTED_COLUMNS_FOR_RESULT_FORMAT_TESTS); + add(EXPECTED_TYPES_FOR_RESULT_FORMAT_TESTS); + addAll(Arrays.asList(expectedQueryResults)); + } + }, + JSON_MAPPER.readValue(responseNoSqlTypesHeader.baos.toByteArray(), Object.class) + ); + + MockHttpServletResponse responseNoTypesHeader = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, true, false, true, null, null), + req.mimic() + ); + + Assert.assertEquals(200, responseNoTypesHeader.getStatus()); + Assert.assertEquals("yes", responseNoTypesHeader.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( + new ArrayList() + { + { + add(EXPECTED_COLUMNS_FOR_RESULT_FORMAT_TESTS); + add(EXPECTED_SQL_TYPES_FOR_RESULT_FORMAT_TESTS); + addAll(Arrays.asList(expectedQueryResults)); + } + }, + JSON_MAPPER.readValue(responseNoTypesHeader.baos.toByteArray(), Object.class) + ); + + MockHttpServletResponse responseNoTypes = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, true, false, false, null, null), + req.mimic() + ); + + Assert.assertEquals(200, responseNoTypes.getStatus()); + Assert.assertEquals("yes", responseNoTypes.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( + new ArrayList() + { + { + add(EXPECTED_COLUMNS_FOR_RESULT_FORMAT_TESTS); + addAll(Arrays.asList(expectedQueryResults)); + } + }, + JSON_MAPPER.readValue(responseNoTypes.baos.toByteArray(), Object.class) + ); + + MockHttpServletResponse responseNoHeader = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, false, false, false, null, null), + req.mimic() + ); + + Assert.assertEquals(200, responseNoHeader.getStatus()); + Assert.assertNull(responseNoHeader.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( + Arrays.asList(expectedQueryResults), + JSON_MAPPER.readValue(responseNoHeader.baos.toByteArray(), Object.class) + ); + } @Test @@ -765,6 +820,15 @@ public class SqlResourceTest extends CalciteTestBase { // Test a query that returns null header for some of the columns final String query = "SELECT (1, 2) FROM INFORMATION_SCHEMA.COLUMNS LIMIT 1"; + + MockHttpServletResponse response = postForResponse( + new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), + req + ); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals("yes", response.getHeader("X-Druid-SQL-Header-Included")); + Assert.assertEquals( ImmutableList.of( Collections.singletonList("EXPR$0"), @@ -777,12 +841,7 @@ public class SqlResourceTest extends CalciteTestBase ) ) ), - doPost( - new SqlQuery(query, ResultFormat.ARRAY, true, true, true, null, null), - new TypeReference>>() - { - } - ).rhs + JSON_MAPPER.readValue(response.baos.toByteArray(), Object.class) ); } @@ -1313,7 +1372,7 @@ public class SqlResourceTest extends CalciteTestBase ).lhs; Assert.assertNotNull(exception); - Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorCode(), exception.getErrorCode()); + Assert.assertEquals("SQL query is unsupported", exception.getErrorCode()); Assert.assertEquals(PlanningError.UNSUPPORTED_SQL_ERROR.getErrorClass(), exception.getErrorClass()); Assert.assertTrue( exception.getMessage() @@ -1365,7 +1424,7 @@ public class SqlResourceTest extends CalciteTestBase ).lhs; Assert.assertNotNull(exception); - Assert.assertEquals(exception.getErrorCode(), ResourceLimitExceededException.ERROR_CODE); + Assert.assertEquals(exception.getErrorCode(), QueryException.RESOURCE_LIMIT_EXCEEDED_ERROR_CODE); Assert.assertEquals(exception.getErrorClass(), ResourceLimitExceededException.class.getName()); checkSqlRequestLog(false); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); @@ -1396,18 +1455,18 @@ public class SqlResourceTest extends CalciteTestBase ).lhs; Assert.assertNotNull(exception); - Assert.assertEquals(QueryUnsupportedException.ERROR_CODE, exception.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_UNSUPPORTED_ERROR_CODE, exception.getErrorCode()); Assert.assertEquals(QueryUnsupportedException.class.getName(), exception.getErrorClass()); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @Test - public void testErrorResponseReturnSameQueryIdWhenSetInContext() throws Exception + public void testErrorResponseReturnSameQueryIdWhenSetInContext() { String queryId = "id123"; String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final Response response = resource.doPost( + final MockHttpServletResponse response = postForResponse( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1420,18 +1479,18 @@ public class SqlResourceTest extends CalciteTestBase req ); Assert.assertNotEquals(200, response.getStatus()); - final MultivaluedMap headers = response.getMetadata(); - Assert.assertTrue(headers.containsKey(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); - Assert.assertEquals(1, headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER).size()); - Assert.assertEquals(queryId, headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER).get(0)); + Assert.assertEquals( + queryId, + Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)) + ); } @Test - public void testErrorResponseReturnNewQueryIdWhenNotSetInContext() throws Exception + public void testErrorResponseReturnNewQueryIdWhenNotSetInContext() { String errorMessage = "This will be supported in Druid 9999"; failOnExecute(errorMessage); - final Response response = resource.doPost( + final MockHttpServletResponse response = postForResponse( new SqlQuery( "SELECT ANSWER TO LIFE", ResultFormat.OBJECT, @@ -1444,10 +1503,9 @@ public class SqlResourceTest extends CalciteTestBase req ); Assert.assertNotEquals(200, response.getStatus()); - final MultivaluedMap headers = response.getMetadata(); - Assert.assertTrue(headers.containsKey(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER)); - Assert.assertEquals(1, headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER).size()); - Assert.assertFalse(Strings.isNullOrEmpty(headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER).get(0).toString())); + Assert.assertFalse( + Strings.isNullOrEmpty(Iterables.getOnlyElement(response.headers.get(SqlResource.SQL_QUERY_ID_RESPONSE_HEADER))) + ); } @Test @@ -1493,7 +1551,7 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertNotNull(exception); Assert.assertNull(exception.getMessage()); Assert.assertNull(exception.getHost()); - Assert.assertEquals(exception.getErrorCode(), QueryUnsupportedException.ERROR_CODE); + Assert.assertEquals(exception.getErrorCode(), QueryException.QUERY_UNSUPPORTED_ERROR_CODE); Assert.assertNull(exception.getErrorClass()); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1527,7 +1585,7 @@ public class SqlResourceTest extends CalciteTestBase String errorMessage = "could not assert"; failOnExecute(errorMessage); onExecute = s -> { - throw new Error(errorMessage); + throw new AssertionError(errorMessage); }; final QueryException exception = doPost( new SqlQuery( @@ -1544,7 +1602,7 @@ public class SqlResourceTest extends CalciteTestBase Assert.assertNotNull(exception); Assert.assertNull(exception.getMessage()); Assert.assertNull(exception.getHost()); - Assert.assertEquals(QueryInterruptedException.UNKNOWN_EXCEPTION, exception.getErrorCode()); + Assert.assertEquals("Unknown exception", exception.getErrorCode()); Assert.assertNull(exception.getErrorClass()); Assert.assertTrue(lifecycleManager.getAll("id").isEmpty()); } @@ -1589,7 +1647,7 @@ public class SqlResourceTest extends CalciteTestBase success++; } else { QueryException interruped = result.lhs; - Assert.assertEquals(QueryCapacityExceededException.ERROR_CODE, interruped.getErrorCode()); + Assert.assertEquals(QueryException.QUERY_CAPACITY_EXCEEDED_ERROR_CODE, interruped.getErrorCode()); Assert.assertEquals( QueryCapacityExceededException.makeLaneErrorMessage(HiLoQueryLaningStrategy.LOW, 2), interruped.getMessage() @@ -1625,7 +1683,7 @@ public class SqlResourceTest extends CalciteTestBase ) ).lhs; Assert.assertNotNull(timeoutException); - Assert.assertEquals(timeoutException.getErrorCode(), QueryTimeoutException.ERROR_CODE); + Assert.assertEquals(timeoutException.getErrorCode(), QueryException.QUERY_TIMEOUT_ERROR_CODE); Assert.assertEquals(timeoutException.getErrorClass(), QueryTimeoutException.class.getName()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); } @@ -1639,27 +1697,24 @@ public class SqlResourceTest extends CalciteTestBase validateAndAuthorizeLatchSupplier.set(new NonnullPair<>(validateAndAuthorizeLatch, true)); CountDownLatch planLatch = new CountDownLatch(1); planLatchSupplier.set(new NonnullPair<>(planLatch, false)); - Future future = executorService.submit( - () -> resource.doPost( + Future future = executorService.submit( + () -> postForResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) ); Assert.assertTrue(validateAndAuthorizeLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); Assert.assertTrue(lifecycleAddLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - Response response = resource.cancelQuery(sqlQueryId, mockRequestForCancel()); + Response cancelResponse = resource.cancelQuery(sqlQueryId, makeRequestForCancel()); planLatch.countDown(); - Assert.assertEquals(Status.ACCEPTED.getStatusCode(), response.getStatus()); + Assert.assertEquals(Status.ACCEPTED.getStatusCode(), cancelResponse.getStatus()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); - response = future.get(); - Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); - QueryException exception = JSON_MAPPER.readValue((byte[]) response.getEntity(), QueryException.class); - Assert.assertEquals( - QueryInterruptedException.QUERY_CANCELED, - exception.getErrorCode() - ); + MockHttpServletResponse queryResponse = future.get(); + Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); + QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class); + Assert.assertEquals("Query cancelled", exception.getErrorCode()); } @Test @@ -1670,26 +1725,23 @@ public class SqlResourceTest extends CalciteTestBase planLatchSupplier.set(new NonnullPair<>(planLatch, true)); CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); - Future future = executorService.submit( - () -> resource.doPost( + Future future = executorService.submit( + () -> postForResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) ); Assert.assertTrue(planLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - Response response = resource.cancelQuery(sqlQueryId, mockRequestForCancel()); + Response cancelResponse = resource.cancelQuery(sqlQueryId, makeRequestForCancel()); execLatch.countDown(); - Assert.assertEquals(Status.ACCEPTED.getStatusCode(), response.getStatus()); + Assert.assertEquals(Status.ACCEPTED.getStatusCode(), cancelResponse.getStatus()); Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty()); - response = future.get(); - Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); - QueryException exception = JSON_MAPPER.readValue((byte[]) response.getEntity(), QueryException.class); - Assert.assertEquals( - QueryInterruptedException.QUERY_CANCELED, - exception.getErrorCode() - ); + MockHttpServletResponse queryResponse = future.get(); + Assert.assertEquals(Status.INTERNAL_SERVER_ERROR.getStatusCode(), queryResponse.getStatus()); + QueryException exception = JSON_MAPPER.readValue(queryResponse.baos.toByteArray(), QueryException.class); + Assert.assertEquals("Query cancelled", exception.getErrorCode()); } @Test @@ -1700,39 +1752,21 @@ public class SqlResourceTest extends CalciteTestBase planLatchSupplier.set(new NonnullPair<>(planLatch, true)); CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); - Future future = executorService.submit( - () -> resource.doPost( + Future future = executorService.submit( + () -> postForResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM foo"), makeRegularUserReq() ) ); Assert.assertTrue(planLatch.await(WAIT_TIMEOUT_SECS, TimeUnit.SECONDS)); - Response response = resource.cancelQuery("invalidQuery", mockRequestForCancel()); - Assert.assertEquals(Status.NOT_FOUND.getStatusCode(), response.getStatus()); + Response cancelResponse = resource.cancelQuery("invalidQuery", makeRequestForCancel()); + Assert.assertEquals(Status.NOT_FOUND.getStatusCode(), cancelResponse.getStatus()); Assert.assertFalse(lifecycleManager.getAll(sqlQueryId).isEmpty()); execLatch.countDown(); - response = future.get(); - // The response that we get is the actual object created by the SqlResource. The StreamingOutput object that - // the SqlResource returns at the time of writing has resources opened up (the query is already running) which - // need to be closed. As such, the StreamingOutput needs to actually be called in order to cause that close - // to occur, so we must get the entity out and call `.write(OutputStream)` on it to invoke the code. - try { - ((StreamingOutput) response.getEntity()).write(NullOutputStream.NULL_OUTPUT_STREAM); - } - catch (IllegalStateException e) { - // When we actually attempt to write to the output stream, we seem to run into multi-threading issues likely - // with our test setup. Instead of figuring out how to make the thing work, given that we don't actually - // care about the response, we are going to just ensure that it was the expected exception and ignore it. - // It's possible that this test starts failing suddenly if someone changes the message of the exception, it - // should be safe to just update the expected message here too if that happens. - Assert.assertEquals( - "DefaultQueryMetrics must not be modified from multiple threads. If it is needed to gather dimension or metric information from multiple threads or from an async thread, this information should explicitly be passed between threads (e. g. using Futures), or this DefaultQueryMetrics's ownerThread should be reassigned explicitly", - e.getMessage() - ); - } - Assert.assertEquals(Status.OK.getStatusCode(), response.getStatus()); + MockHttpServletResponse queryResponse = future.get(); + Assert.assertEquals(Status.OK.getStatusCode(), queryResponse.getStatus()); } @Test @@ -1743,21 +1777,21 @@ public class SqlResourceTest extends CalciteTestBase planLatchSupplier.set(new NonnullPair<>(planLatch, true)); CountDownLatch execLatch = new CountDownLatch(1); executeLatchSupplier.set(new NonnullPair<>(execLatch, false)); - Future future = executorService.submit( - () -> resource.doPost( + Future future = executorService.submit( + () -> postForResponse( createSimpleQueryWithId(sqlQueryId, "SELECT DISTINCT dim1 FROM forbiddenDatasource"), makeSuperUserReq() ) ); Assert.assertTrue(planLatch.await(3, TimeUnit.SECONDS)); - Response response = resource.cancelQuery(sqlQueryId, mockRequestForCancel()); - Assert.assertEquals(Status.FORBIDDEN.getStatusCode(), response.getStatus()); + Response cancelResponse = resource.cancelQuery(sqlQueryId, makeRequestForCancel()); + Assert.assertEquals(Status.FORBIDDEN.getStatusCode(), cancelResponse.getStatus()); Assert.assertFalse(lifecycleManager.getAll(sqlQueryId).isEmpty()); execLatch.countDown(); - response = future.get(); - Assert.assertEquals(Status.OK.getStatusCode(), response.getStatus()); + MockHttpServletResponse queryResponse = future.get(); + Assert.assertEquals(Status.OK.getStatusCode(), queryResponse.getStatus()); } @Test @@ -1782,7 +1816,7 @@ public class SqlResourceTest extends CalciteTestBase ) ).lhs; Assert.assertNotNull(queryContextException); - Assert.assertEquals(BadQueryContextException.ERROR_CODE, queryContextException.getErrorCode()); + Assert.assertEquals(QueryException.BAD_QUERY_CONTEXT_ERROR_CODE, queryContextException.getErrorCode()); Assert.assertEquals(BadQueryContextException.ERROR_CLASS, queryContextException.getErrorClass()); Assert.assertTrue(queryContextException.getMessage().contains("2000'")); checkSqlRequestLog(false); @@ -1859,7 +1893,7 @@ public class SqlResourceTest extends CalciteTestBase return doPostRaw(query, req); } - private Pair>> doPost(final SqlQuery query, HttpServletRequest req) + private Pair>> doPost(final SqlQuery query, MockHttpServletRequest req) throws Exception { return doPost(query, req, new TypeReference>>() @@ -1871,7 +1905,7 @@ public class SqlResourceTest extends CalciteTestBase @SuppressWarnings("unchecked") private Pair doPost( final SqlQuery query, - final HttpServletRequest req, + final MockHttpServletRequest req, final TypeReference typeReference ) throws Exception { @@ -1885,83 +1919,48 @@ public class SqlResourceTest extends CalciteTestBase } // Returns either an error or a result. - private Pair doPostRaw(final SqlQuery query, final HttpServletRequest req) throws Exception + private Pair doPostRaw(final SqlQuery query, final MockHttpServletRequest req) + throws Exception { - final Response response = resource.doPost(query, req); - if (response.getStatus() == 200) { - final StreamingOutput output = (StreamingOutput) response.getEntity(); - final ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try { - output.write(baos); - } - catch (Exception ignored) { - // Suppress errors and return the response so far. Similar to what the real web server would do, if it - // started writing a 200 OK and then threw an exception in the middle. - } + MockHttpServletResponse response = postForResponse(query, req); - return Pair.of( - null, - new String(baos.toByteArray(), StandardCharsets.UTF_8) - ); + if (response.getStatus() == 200) { + return Pair.of(null, new String(response.baos.toByteArray(), StandardCharsets.UTF_8)); } else { - return Pair.of( - JSON_MAPPER.readValue((byte[]) response.getEntity(), QueryException.class), - null - ); + return Pair.of(JSON_MAPPER.readValue(response.baos.toByteArray(), QueryException.class), null); } } - private HttpServletRequest makeSuperUserReq() + @Nonnull + private MockHttpServletResponse postForResponse(SqlQuery query, MockHttpServletRequest req) + { + MockHttpServletResponse response = MockHttpServletResponse.forRequest(req); + + Assert.assertNull(resource.doPost(query, req)); + return response; + } + + private MockHttpServletRequest makeSuperUserReq() { return makeExpectedReq(CalciteTests.SUPER_USER_AUTH_RESULT); } - private HttpServletRequest makeRegularUserReq() + private MockHttpServletRequest makeRegularUserReq() { return makeExpectedReq(CalciteTests.REGULAR_USER_AUTH_RESULT); } - private HttpServletRequest makeExpectedReq(AuthenticationResult authenticationResult) + private MockHttpServletRequest makeExpectedReq(AuthenticationResult authenticationResult) { - return makeExpectedReq(authenticationResult, true); - } - - private HttpServletRequest makeExpectedReq(AuthenticationResult authenticationResult, boolean ok) - { - HttpServletRequest req = EasyMock.createStrictMock(HttpServletRequest.class); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(authenticationResult) - .anyTimes(); - EasyMock.expect(req.getRemoteAddr()).andReturn(null).once(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(authenticationResult) - .anyTimes(); - req.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, ok); - EasyMock.expectLastCall().anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(authenticationResult) - .anyTimes(); - EasyMock.replay(req); + MockHttpServletRequest req = new MockHttpServletRequest(); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, authenticationResult); return req; } - private HttpServletRequest mockRequestForCancel() + private MockHttpServletRequest makeRequestForCancel() { - HttpServletRequest req = EasyMock.createNiceMock(HttpServletRequest.class); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(CalciteTests.REGULAR_USER_AUTH_RESULT) - .anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)).andReturn(null).anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - req.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().anyTimes(); - EasyMock.replay(req); + MockHttpServletRequest req = new MockHttpServletRequest(); + req.attributes.put(AuthConfig.DRUID_AUTHENTICATION_RESULT, CalciteTests.REGULAR_USER_AUTH_RESULT); return req; } @@ -1969,13 +1968,30 @@ public class SqlResourceTest extends CalciteTestBase { return results -> { final AtomicLong rows = new AtomicLong(); - return results.map(row -> { - if (rows.incrementAndGet() == 3) { - throw new ISE("Oh no!"); - } else { - return row; - } - }); + return results + .flatMap( + row -> Sequences.simple(new AbstractList() + { + @Override + public Object[] get(int index) + { + return row; + } + + @Override + public int size() + { + return 1000; + } + }) + ) + .map(row -> { + if (rows.incrementAndGet() == 3) { + throw new ISE("Oh no!"); + } else { + return row; + } + }); }; }