diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java index 0482c720826..1a1acaa008e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java @@ -74,10 +74,8 @@ import org.apache.druid.query.QueryException; import org.apache.druid.rpc.HttpResponseException; import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.server.QueryResponse; -import org.apache.druid.server.security.Access; import org.apache.druid.server.security.AuthenticationResult; import org.apache.druid.server.security.AuthorizationUtils; -import org.apache.druid.server.security.AuthorizerMapper; import org.apache.druid.server.security.ForbiddenException; import org.apache.druid.sql.DirectStatement; import org.apache.druid.sql.HttpStatement; @@ -105,7 +103,6 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.StreamingOutput; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -120,7 +117,6 @@ public class SqlStatementResource public static final String RESULT_FORMAT = "__resultFormat"; private static final Logger log = new Logger(SqlStatementResource.class); private final SqlStatementFactory msqSqlStatementFactory; - private final AuthorizerMapper authorizerMapper; private final ObjectMapper jsonMapper; private final OverlordClient overlordClient; private final StorageConnector storageConnector; @@ -129,14 +125,12 @@ public class SqlStatementResource @Inject public SqlStatementResource( final @MSQ SqlStatementFactory msqSqlStatementFactory, - final AuthorizerMapper authorizerMapper, final ObjectMapper jsonMapper, final OverlordClient overlordClient, final @MultiStageQuery StorageConnector storageConnector ) { this.msqSqlStatementFactory = msqSqlStatementFactory; - this.authorizerMapper = authorizerMapper; this.jsonMapper = jsonMapper; this.overlordClient = overlordClient; this.storageConnector = storageConnector; @@ -151,16 +145,8 @@ public class SqlStatementResource @Produces(MediaType.APPLICATION_JSON) public Response isEnabled(@Context final HttpServletRequest request) { - // All authenticated users are authorized for this API: check an empty resource list. - final Access authResult = AuthorizationUtils.authorizeAllResourceActions( - request, - Collections.emptyList(), - authorizerMapper - ); - - if (!authResult.isAllowed()) { - throw new ForbiddenException(authResult.toString()); - } + // All authenticated users are authorized for this API. + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request); return Response.ok(ImmutableMap.of("enabled", true)).build(); } @@ -240,14 +226,7 @@ public class SqlStatementResource ) { try { - Access authResult = AuthorizationUtils.authorizeAllResourceActions( - req, - Collections.emptyList(), - authorizerMapper - ); - if (!authResult.isAllowed()) { - throw new ForbiddenException(authResult.toString()); - } + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); Optional sqlStatementResult = getStatementStatus( @@ -288,14 +267,7 @@ public class SqlStatementResource ) { try { - Access authResult = AuthorizationUtils.authorizeAllResourceActions( - req, - Collections.emptyList(), - authorizerMapper - ); - if (!authResult.isAllowed()) { - throw new ForbiddenException(authResult.toString()); - } + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); if (page != null && page < 0) { @@ -376,14 +348,7 @@ public class SqlStatementResource { try { - Access authResult = AuthorizationUtils.authorizeAllResourceActions( - req, - Collections.emptyList(), - authorizerMapper - ); - if (!authResult.isAllowed()) { - throw new ForbiddenException(authResult.toString()); - } + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); Optional sqlStatementResult = getStatementStatus( @@ -776,7 +741,7 @@ public class SqlStatementResource finalStage.getId(), (int) pageInformation.getId(), (int) pageInformation.getId() -// we would always have partition number == worker number + // we would always have partition number == worker number )); } catch (Exception e) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java index ecabbd0fa5c..187601cea0e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlTaskResource.java @@ -60,7 +60,6 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.StreamingOutput; import java.io.IOException; -import java.util.Collections; /** * Endpoint for SQL execution using MSQ tasks. @@ -108,17 +107,7 @@ public class SqlTaskResource @Produces(MediaType.APPLICATION_JSON) public Response doGetEnabled(@Context final HttpServletRequest request) { - // All authenticated users are authorized for this API: check an empty resource list. - final Access authResult = AuthorizationUtils.authorizeAllResourceActions( - request, - Collections.emptyList(), - authorizerMapper - ); - - if (!authResult.isAllowed()) { - throw new ForbiddenException(authResult.toString()); - } - + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request); return Response.ok(ImmutableMap.of("enabled", true)).build(); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java index 4b1a53b798a..0a20dd12066 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java @@ -72,7 +72,6 @@ public class SqlMSQStatementResourcePostTest extends MSQTestBase { resource = new SqlStatementResource( sqlStatementFactory, - CalciteTests.TEST_AUTHORIZER_MAPPER, objectMapper, indexingServiceClient, localFileStorageConnector @@ -273,7 +272,6 @@ public class SqlMSQStatementResourcePostTest extends MSQTestBase { SqlStatementResource resourceWithDurableStorage = new SqlStatementResource( sqlStatementFactory, - CalciteTests.TEST_AUTHORIZER_MAPPER, objectMapper, indexingServiceClient, NilStorageConnector.getInstance() diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java index d7f04d82777..a603dcb9173 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlStatementResourceTest.java @@ -648,7 +648,6 @@ public class SqlStatementResourceTest extends MSQTestBase setupMocks(overlordClient); resource = new SqlStatementResource( sqlStatementFactory, - CalciteTests.TEST_AUTHORIZER_MAPPER, objectMapper, overlordClient, new LocalFileStorageConnector(tmpFolder.newFolder("local")) diff --git a/server/src/main/java/org/apache/druid/server/security/AuthorizationUtils.java b/server/src/main/java/org/apache/druid/server/security/AuthorizationUtils.java index 2b444982647..2b2afa9a9cd 100644 --- a/server/src/main/java/org/apache/druid/server/security/AuthorizationUtils.java +++ b/server/src/main/java/org/apache/druid/server/security/AuthorizationUtils.java @@ -22,6 +22,7 @@ package org.apache.druid.server.security; import com.google.common.base.Function; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.ISE; import javax.servlet.http.HttpServletRequest; @@ -175,6 +176,23 @@ public class AuthorizationUtils return access; } + /** + * Sets the {@link AuthConfig#DRUID_AUTHORIZATION_CHECKED} attribute in the {@link HttpServletRequest} to true. This method is generally used + * when no {@link ResourceAction} need to be checked for the API. If resources are present, users should call + * {@link AuthorizationUtils#authorizeAllResourceActions(HttpServletRequest, Iterable, AuthorizerMapper)} + */ + public static void setRequestAuthorizationAttributeIfNeeded(final HttpServletRequest request) + { + if (request.getAttribute(AuthConfig.DRUID_ALLOW_UNSECURED_PATH) != null) { + // do nothing since request allows unsecured paths to proceed. Generally, this is used for custom urls. + return; + } + if (request.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED) != null) { + throw DruidException.defensive("Request already had authorization check."); + } + request.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + } + /** * Filter a collection of resources by applying the resourceActionGenerator to each resource, return an iterable * containing the filtered resources. diff --git a/server/src/test/java/org/apache/druid/server/security/AuthorizationUtilsTest.java b/server/src/test/java/org/apache/druid/server/security/AuthorizationUtilsTest.java index a9d87a4fd79..a66cdde9487 100644 --- a/server/src/test/java/org/apache/druid/server/security/AuthorizationUtilsTest.java +++ b/server/src/test/java/org/apache/druid/server/security/AuthorizationUtilsTest.java @@ -20,6 +20,7 @@ package org.apache.druid.server.security; import com.google.common.base.Function; +import org.apache.druid.server.mocks.MockHttpServletRequest; import org.junit.Assert; import org.junit.Test; @@ -111,4 +112,12 @@ public class AuthorizationUtilsTest ); } } + + @Test + public void testAuthorizationAttributeIfNeeded() + { + MockHttpServletRequest request = new MockHttpServletRequest(); + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request); + Assert.assertEquals(true, request.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)); + } }