Throw BadQueryContextException if context params cannot be parsed (#12680)

This commit is contained in:
Tejaswini Bandlamudi 2022-06-24 09:21:25 +05:30 committed by GitHub
parent d29343cbe3
commit 1fc2f6e4b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 94 additions and 7 deletions

View File

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
public class BadQueryContextException extends BadQueryException
{
public static final String ERROR_CODE = "Query context parse failed";
public static final String ERROR_CLASS = BadQueryContextException.class.getName();
public BadQueryContextException(Exception e)
{
this(ERROR_CODE, e.getMessage(), ERROR_CLASS);
}
@JsonCreator
private BadQueryContextException(
@JsonProperty("error") String errorCode,
@JsonProperty("errorMessage") String errorMessage,
@JsonProperty("errorClass") String errorClass
)
{
super(errorCode, errorMessage, errorClass);
}
}

View File

@ -418,10 +418,15 @@ public class QueryContexts
public static <T> long getTimeout(Query<T> query, long defaultTimeout) public static <T> long getTimeout(Query<T> query, long defaultTimeout)
{ {
try {
final long timeout = parseLong(query, TIMEOUT_KEY, defaultTimeout); final long timeout = parseLong(query, TIMEOUT_KEY, defaultTimeout);
Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout); Preconditions.checkState(timeout >= 0, "Timeout must be a non negative value, but was [%s]", timeout);
return timeout; return timeout;
} }
catch (NumberFormatException e) {
throw new BadQueryContextException(e);
}
}
public static <T> Query<T> withTimeout(Query<T> query, long timeout) public static <T> Query<T> withTimeout(Query<T> query, long timeout)
{ {

View File

@ -181,6 +181,21 @@ public class QueryContextsTest
QueryContexts.getBrokerServiceName(queryContext); QueryContexts.getBrokerServiceName(queryContext);
} }
@Test
public void testGetTimeout_withNonNumericValue()
{
Map<String, Object> queryContext = new HashMap<>();
queryContext.put(QueryContexts.TIMEOUT_KEY, "2000'");
exception.expect(BadQueryContextException.class);
QueryContexts.getTimeout(new TestQuery(
new TableDataSource("test"),
new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.of("0/100"))),
false,
queryContext
));
}
@Test @Test
public void testDefaultEnableQueryDebugging() public void testDefaultEnableQueryDebugging()
{ {

View File

@ -52,7 +52,6 @@ import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.QueryToolChest; import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.QueryUnsupportedException; import org.apache.druid.query.QueryUnsupportedException;
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.query.TruncatedResponseContextException; import org.apache.druid.query.TruncatedResponseContextException;
import org.apache.druid.query.context.ResponseContext; import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.query.context.ResponseContext.Keys; import org.apache.druid.query.context.ResponseContext.Keys;
@ -326,7 +325,7 @@ public class QueryResource implements QueryCountStatsProvider
queryLifecycle.emitLogsAndMetrics(unsupported, req.getRemoteAddr(), -1); queryLifecycle.emitLogsAndMetrics(unsupported, req.getRemoteAddr(), -1);
return ioReaderWriter.getResponseWriter().gotUnsupported(unsupported); return ioReaderWriter.getResponseWriter().gotUnsupported(unsupported);
} }
catch (BadJsonQueryException | ResourceLimitExceededException e) { catch (BadQueryException e) {
interruptedQueryCount.incrementAndGet(); interruptedQueryCount.incrementAndGet();
queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1); queryLifecycle.emitLogsAndMetrics(e, req.getRemoteAddr(), -1);
return ioReaderWriter.getResponseWriter().gotBadQuery(e); return ioReaderWriter.getResponseWriter().gotBadQuery(e);

View File

@ -38,7 +38,6 @@ import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException; import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException; import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.QueryUnsupportedException; import org.apache.druid.query.QueryUnsupportedException;
import org.apache.druid.query.ResourceLimitExceededException;
import org.apache.druid.server.initialization.ServerConfig; import org.apache.druid.server.initialization.ServerConfig;
import org.apache.druid.server.security.Access; import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.AuthorizationUtils; import org.apache.druid.server.security.AuthorizationUtils;
@ -196,7 +195,7 @@ public class SqlResource
endLifecycle(sqlQueryId, lifecycle, timeout, remoteAddr, -1); endLifecycle(sqlQueryId, lifecycle, timeout, remoteAddr, -1);
return buildNonOkResponse(QueryTimeoutException.STATUS_CODE, timeout, sqlQueryId); return buildNonOkResponse(QueryTimeoutException.STATUS_CODE, timeout, sqlQueryId);
} }
catch (SqlPlanningException | ResourceLimitExceededException e) { catch (BadQueryException e) {
endLifecycle(sqlQueryId, lifecycle, e, remoteAddr, -1); endLifecycle(sqlQueryId, lifecycle, e, remoteAddr, -1);
return buildNonOkResponse(BadQueryException.STATUS_CODE, e, sqlQueryId); return buildNonOkResponse(BadQueryException.STATUS_CODE, e, sqlQueryId);
} }

View File

@ -46,6 +46,7 @@ import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.emitter.service.ServiceEmitter; import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.math.expr.ExprMacroTable; import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.BadQueryContextException;
import org.apache.druid.query.BaseQuery; import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.DefaultQueryConfig; import org.apache.druid.query.DefaultQueryConfig;
import org.apache.druid.query.Query; import org.apache.druid.query.Query;
@ -1601,6 +1602,30 @@ public class SqlResourceTest extends CalciteTestBase
Assert.assertEquals(Status.OK.getStatusCode(), response.getStatus()); Assert.assertEquals(Status.OK.getStatusCode(), response.getStatus());
} }
@Test
public void testQueryContextException() throws Exception
{
final String sqlQueryId = "badQueryContextTimeout";
Map<String, Object> queryContext = ImmutableMap.of(QueryContexts.TIMEOUT_KEY, "2000'", BaseQuery.SQL_QUERY_ID, sqlQueryId);
final QueryException queryContextException = doPost(
new SqlQuery(
"SELECT 1337",
ResultFormat.OBJECT,
false,
false,
false,
queryContext,
null
)
).lhs;
Assert.assertNotNull(queryContextException);
Assert.assertEquals(BadQueryContextException.ERROR_CODE, queryContextException.getErrorCode());
Assert.assertEquals(BadQueryContextException.ERROR_CLASS, queryContextException.getErrorClass());
Assert.assertTrue(queryContextException.getMessage().contains("For input string: \"2000'\""));
checkSqlRequestLog(false);
Assert.assertTrue(lifecycleManager.getAll(sqlQueryId).isEmpty());
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private void checkSqlRequestLog(boolean success) private void checkSqlRequestLog(boolean success)
{ {