From 0d745ee1207df4990ab92082ec33c75f5777757e Mon Sep 17 00:00:00 2001 From: Parag Jain Date: Thu, 28 Apr 2016 18:50:28 -0500 Subject: [PATCH] Basic authorization support in Druid (#2424) - Introduce `AuthorizationInfo` interface, specific implementations of which would be provided by extensions - If the `druid.auth.enabled` is set to `true` then the `isAuthorized` method of `AuthorizationInfo` will be called to perform authorization checks - `AuthorizationInfo` object will be created in the servlet filters of specific extension and will be passed as a request attribute with attribute name as `AuthConfig.DRUID_AUTH_TOKEN` - As per the scope of this PR, all resources that needs to be secured are divided into 3 types - `DATASOURCE`, `CONFIG` and `STATE`. For any type of resource, possible actions are - `READ` or `WRITE` - Specific ResourceFilters are used to perform auth checks for all endpoints that corresponds to a specific resource type. This prevents duplication of logic and need to inject HttpServletRequest inside each endpoint. For example - `DatasourceResourceFilter` is used for endpoints where the datasource information is present after "datasources" segment in the request Path such as `/druid/coordinator/v1/datasources/`, `/druid/coordinator/v1/metadata/datasources/`, `/druid/v2/datasources/` - `RulesResourceFilter` is used where the datasource information is present after "rules" segment in the request Path such as `/druid/coordinator/v1/rules/` - `TaskResourceFilter` is used for endpoints is used where the datasource information is present after "task" segment in the request Path such as `druid/indexer/v1/task` - `ConfigResourceFilter` is used for endpoints like `/druid/coordinator/v1/config`, `/druid/indexer/v1/worker`, `/druid/worker/v1` etc - `StateResourceFilter` is used for endpoints like `/druid/broker/v1/loadstatus`, `/druid/coordinator/v1/leader`, `/druid/coordinator/v1/loadqueue`, `/druid/coordinator/v1/rules` etc - For endpoints where a list of resources is returned like `/druid/coordinator/v1/datasources`, `/druid/indexer/v1/completeTasks` etc. the list is filtered to return only the resources to which the requested user has access. In these cases, `HttpServletRequest` instance needs to be injected in the endpoint method. Note - JAX-RS specification provides an interface called `SecurityContext`. However, we did not use this but provided our own interface `AuthorizationInfo` mainly because it provides more flexibility. For example, `SecurityContext` has a method called `isUserInRole(String role)` which would be used for auth checks and if used then the mapping of what roles can access what resource needs to be modeled inside Druid either using some convention or some other means which is not very flexible as Druid has dynamic resources like datasources. Fixes #2355 with PR #2424 --- .../overlord/http/OverlordResource.java | 195 ++++++- .../http/security/TaskResourceFilter.java | 123 +++++ .../indexing/worker/http/WorkerResource.java | 9 + .../overlord/http/OverlordResourceTest.java | 499 ++++++------------ .../indexing/overlord/http/OverlordTest.java | 413 +++++++++++++++ .../security/SecurityResourceFilterTest.java | 146 +++++ .../druid/guice/security/DruidAuthModule.java | 44 ++ .../druid/initialization/Initialization.java | 3 + .../EventReceiverFirehoseFactory.java | 2 +- .../io/druid/server/ClientInfoResource.java | 52 +- .../java/io/druid/server/QueryManager.java | 22 +- .../java/io/druid/server/QueryResource.java | 58 +- .../java/io/druid/server/StatusResource.java | 3 + .../io/druid/server/http/BrokerResource.java | 3 + .../CoordinatorDynamicConfigsResource.java | 7 +- .../server/http/CoordinatorResource.java | 3 + .../server/http/DatasourcesResource.java | 34 +- .../druid/server/http/HistoricalResource.java | 3 + .../druid/server/http/IntervalsResource.java | 51 +- .../druid/server/http/InventoryViewUtils.java | 48 +- .../druid/server/http/MetadataResource.java | 103 +++- .../io/druid/server/http/RulesResource.java | 11 +- .../io/druid/server/http/ServersResource.java | 3 + .../io/druid/server/http/TiersResource.java | 3 + .../http/security/AbstractResourceFilter.java | 89 ++++ .../http/security/ConfigResourceFilter.java | 85 +++ .../security/DatasourceResourceFilter.java | 110 ++++ .../http/security/RulesResourceFilter.java | 106 ++++ .../http/security/StateResourceFilter.java | 97 ++++ .../metrics/EventReceiverFirehoseMonitor.java | 2 - .../java/io/druid/server/security/Access.java | 51 ++ .../java/io/druid/server/security/Action.java | 26 + .../io/druid/server/security/AuthConfig.java | 85 +++ .../server/security/AuthorizationInfo.java | 44 ++ .../io/druid/server/security/Resource.java | 69 +++ .../druid/server/security/ResourceType.java | 27 + .../druid/server/ClientInfoResourceTest.java | 3 +- .../io/druid/server/QueryResourceTest.java | 294 ++++++++++- .../server/http/DatasourcesResourceTest.java | 89 +++- .../server/http/IntervalsResourceTest.java | 30 +- .../druid/server/http/RulesResourceTest.java | 4 +- .../security/ResourceFilterTestHelper.java | 245 +++++++++ .../security/SecurityResourceFilterTest.java | 134 +++++ 43 files changed, 2999 insertions(+), 429 deletions(-) create mode 100644 indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java create mode 100644 indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java create mode 100644 indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java create mode 100644 server/src/main/java/io/druid/guice/security/DruidAuthModule.java create mode 100644 server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/StateResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/security/Access.java create mode 100644 server/src/main/java/io/druid/server/security/Action.java create mode 100644 server/src/main/java/io/druid/server/security/AuthConfig.java create mode 100644 server/src/main/java/io/druid/server/security/AuthorizationInfo.java create mode 100644 server/src/main/java/io/druid/server/security/Resource.java create mode 100644 server/src/main/java/io/druid/server/security/ResourceType.java create mode 100644 server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java create mode 100644 server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java index 706036e5e6f..4ef7d5246db 100644 --- a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java @@ -22,6 +22,10 @@ package io.druid.indexing.overlord.http; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Function; import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -30,7 +34,9 @@ import com.google.common.collect.Sets; import com.google.common.io.ByteSource; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; @@ -46,8 +52,17 @@ import io.druid.indexing.overlord.TaskRunnerWorkItem; import io.druid.indexing.overlord.TaskStorageQueryAdapter; import io.druid.indexing.overlord.WorkerTaskRunner; import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.http.security.TaskResourceFilter; import io.druid.indexing.overlord.setup.WorkerBehaviorConfig; import io.druid.metadata.EntryExistsException; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.tasklogs.TaskLogStreamer; import io.druid.timeline.DataSegment; import org.joda.time.DateTime; @@ -63,11 +78,13 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -85,6 +102,7 @@ public class OverlordResource private final TaskLogStreamer taskLogStreamer; private final JacksonConfigManager configManager; private final AuditManager auditManager; + private final AuthConfig authConfig; private AtomicReference workerConfigRef = null; @@ -94,7 +112,8 @@ public class OverlordResource TaskStorageQueryAdapter taskStorageQueryAdapter, TaskLogStreamer taskLogStreamer, JacksonConfigManager configManager, - AuditManager auditManager + AuditManager auditManager, + AuthConfig authConfig ) throws Exception { this.taskMaster = taskMaster; @@ -102,14 +121,35 @@ public class OverlordResource this.taskLogStreamer = taskLogStreamer; this.configManager = configManager; this.auditManager = auditManager; + this.authConfig = authConfig; } @POST @Path("/task") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response taskPost(final Task task) + public Response taskPost( + final Task task, + @Context final HttpServletRequest req + ) { + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSource = task.getDataSource(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + return asLeaderWith( taskMaster.getTaskQueue(), new Function() @@ -133,6 +173,7 @@ public class OverlordResource @GET @Path("/leader") + @ResourceFilters(StateResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response getLeader() { @@ -142,6 +183,7 @@ public class OverlordResource @GET @Path("/task/{taskid}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskPayload(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "payload", taskStorageQueryAdapter.getTask(taskid)); @@ -150,6 +192,7 @@ public class OverlordResource @GET @Path("/task/{taskid}/status") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskStatus(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "status", taskStorageQueryAdapter.getStatus(taskid)); @@ -158,6 +201,7 @@ public class OverlordResource @GET @Path("/task/{taskid}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskSegments(@PathParam("taskid") String taskid) { final Set segments = taskStorageQueryAdapter.getInsertedSegments(taskid); @@ -167,6 +211,7 @@ public class OverlordResource @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response doShutdown(@PathParam("taskid") final String taskid) { return asLeaderWith( @@ -186,6 +231,7 @@ public class OverlordResource @GET @Path("/worker") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfig() { if (workerConfigRef == null) { @@ -199,11 +245,12 @@ public class OverlordResource @POST @Path("/worker") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response setWorkerConfig( final WorkerBehaviorConfig workerBehaviorConfig, @HeaderParam(AuditManager.X_DRUID_AUTHOR) @DefaultValue("") final String author, @HeaderParam(AuditManager.X_DRUID_COMMENT) @DefaultValue("") final String comment, - @Context HttpServletRequest req + @Context final HttpServletRequest req ) { if (!configManager.set( @@ -222,6 +269,7 @@ public class OverlordResource @GET @Path("/worker/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfigHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count @@ -258,6 +306,7 @@ public class OverlordResource @POST @Path("/action") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doAction(final TaskActionHolder holder) { return asLeaderWith( @@ -292,7 +341,7 @@ public class OverlordResource @GET @Path("/waitingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getWaitingTasks() + public Response getWaitingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -302,7 +351,38 @@ public class OverlordResource { // A bit roundabout, but works as a way of figuring out what tasks haven't been handed // off to the runner yet: - final List activeTasks = taskStorageQueryAdapter.getActiveTasks(); + final List allActiveTasks = taskStorageQueryAdapter.getActiveTasks(); + final List activeTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + activeTasks = ImmutableList.copyOf( + Iterables.filter( + allActiveTasks, + new Predicate() + { + @Override + public boolean apply(Task input) + { + Resource resource = new Resource(input.getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + activeTasks = allActiveTasks; + } final Set runnersKnownTasks = Sets.newHashSet( Iterables.transform( taskRunner.getKnownTasks(), @@ -346,7 +426,7 @@ public class OverlordResource @GET @Path("/pendingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getPendingTasks() + public Response getPendingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -354,7 +434,13 @@ public class OverlordResource @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getPendingTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getPendingTasks(), req); + } else { + return taskRunner.getPendingTasks(); + } + } } ); @@ -363,7 +449,7 @@ public class OverlordResource @GET @Path("/runningTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getRunningTasks() + public Response getRunningTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -371,7 +457,12 @@ public class OverlordResource @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getRunningTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getRunningTasks(), req); + } else { + return taskRunner.getRunningTasks(); + } } } ); @@ -380,10 +471,50 @@ public class OverlordResource @GET @Path("/completeTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getCompleteTasks() + public Response getCompleteTasks(@Context final HttpServletRequest req) { + final List recentlyFinishedTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + recentlyFinishedTasks = ImmutableList.copyOf( + Iterables.filter( + taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + new Predicate() + { + @Override + public boolean apply(TaskStatus input) + { + final String taskId = input.getId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + recentlyFinishedTasks = taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(); + } + final List completeTasks = Lists.transform( - taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + recentlyFinishedTasks, new Function() { @Override @@ -406,6 +537,7 @@ public class OverlordResource @GET @Path("/workers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getWorkers() { return asLeaderWith( @@ -435,6 +567,7 @@ public class OverlordResource @GET @Path("/scaling") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getScalingState() { // Don't use asLeaderWith, since we want to return 200 instead of 503 when missing an autoscaler. @@ -449,6 +582,7 @@ public class OverlordResource @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(TaskResourceFilter.class) public Response doGetLog( @PathParam("taskid") final String taskid, @QueryParam("offset") @DefaultValue("0") final long offset @@ -528,6 +662,45 @@ public class OverlordResource } } + private Collection securedTaskRunnerWorkItem( + Collection collectionToFilter, + HttpServletRequest req + ) + { + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + collectionToFilter, + new Predicate() + { + @Override + public boolean apply(TaskRunnerWorkItem input) + { + final String taskId = input.getTaskId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + static class TaskResponseObject { private final String id; diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java new file mode 100644 index 00000000000..0866658c08a --- /dev/null +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java @@ -0,0 +1,123 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "task" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/indexer/v1/task/{taskid}/... + * Note - DO NOT use this filter at MiddleManager resources as TaskStorageQueryAdapter cannot be injected there + */ +public class TaskResourceFilter extends AbstractResourceFilter +{ + private final TaskStorageQueryAdapter taskStorageQueryAdapter; + + @Inject + public TaskResourceFilter(TaskStorageQueryAdapter taskStorageQueryAdapter, AuthConfig authConfig) { + super(authConfig); + this.taskStorageQueryAdapter = taskStorageQueryAdapter; + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String taskId = Preconditions.checkNotNull( + request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("task"); + } + } + ) + 1 + ).getPath() + ); + + Optional taskOptional = taskStorageQueryAdapter.getTask(taskId); + if (!taskOptional.isPresent()) { + throw new WebApplicationException( + Response.status(Response.Status.BAD_REQUEST) + .entity(String.format("Cannot find any task with id: [%s]", taskId)) + .build() + ); + } + final String dataSourceName = Preconditions.checkNotNull(taskOptional.get().getDataSource()); + + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException(Response.status(Response.Status.FORBIDDEN) + .entity( + String.format("Access-Check-Result: %s", authResult.toString()) + ) + .build()); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/indexer/v1/task/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java index 9bb3bdc44b6..49641462e91 100644 --- a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java @@ -27,10 +27,13 @@ import com.google.common.collect.Lists; import com.google.common.io.ByteSource; import com.google.inject.Inject; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.indexing.overlord.TaskRunner; import io.druid.indexing.overlord.TaskRunnerWorkItem; import io.druid.indexing.worker.Worker; import io.druid.indexing.worker.WorkerCuratorCoordinator; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import io.druid.tasklogs.TaskLogStreamer; import javax.ws.rs.DefaultValue; @@ -73,6 +76,7 @@ public class WorkerResource @POST @Path("/disable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doDisable() { try { @@ -93,6 +97,7 @@ public class WorkerResource @POST @Path("/enable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doEnable() { try { @@ -107,6 +112,7 @@ public class WorkerResource @GET @Path("/enabled") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response isEnabled() { try { @@ -122,6 +128,7 @@ public class WorkerResource @GET @Path("/tasks") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getTasks() { try { @@ -149,6 +156,7 @@ public class WorkerResource @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doShutdown(@PathParam("taskid") String taskid) { try { @@ -164,6 +172,7 @@ public class WorkerResource @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(StateResourceFilter.class) public Response doGetLog( @PathParam("taskid") String taskid, @QueryParam("offset") @DefaultValue("0") long offset diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java index 5ef4fd3c8c0..173bd905c37 100644 --- a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java @@ -22,379 +22,226 @@ package io.druid.indexing.overlord.http; import com.google.common.base.Function; import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; -import com.metamx.common.Pair; -import com.metamx.common.guava.CloseQuietly; -import com.metamx.emitter.EmittingLogger; -import com.metamx.emitter.service.ServiceEmitter; -import io.druid.concurrent.Execs; -import io.druid.curator.PotentiallyGzippedCompressionProvider; -import io.druid.curator.discovery.NoopServiceAnnouncer; import io.druid.indexing.common.TaskLocation; import io.druid.indexing.common.TaskStatus; -import io.druid.indexing.common.actions.TaskActionClientFactory; -import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.TaskToolbox; +import io.druid.indexing.common.actions.TaskActionClient; +import io.druid.indexing.common.task.AbstractTask; import io.druid.indexing.common.task.NoopTask; import io.druid.indexing.common.task.Task; -import io.druid.indexing.overlord.HeapMemoryTaskStorage; -import io.druid.indexing.overlord.TaskLockbox; import io.druid.indexing.overlord.TaskMaster; import io.druid.indexing.overlord.TaskRunner; -import io.druid.indexing.overlord.TaskRunnerFactory; -import io.druid.indexing.overlord.TaskRunnerListener; import io.druid.indexing.overlord.TaskRunnerWorkItem; -import io.druid.indexing.overlord.TaskStorage; import io.druid.indexing.overlord.TaskStorageQueryAdapter; -import io.druid.indexing.overlord.autoscaling.ScalingStats; -import io.druid.indexing.overlord.config.TaskQueueConfig; -import io.druid.server.DruidNode; -import io.druid.server.initialization.IndexerZkConfig; -import io.druid.server.initialization.ZkPathsConfig; -import io.druid.server.metrics.NoopServiceEmitter; -import org.apache.curator.framework.CuratorFramework; -import org.apache.curator.framework.CuratorFrameworkFactory; -import org.apache.curator.retry.RetryOneTime; -import org.apache.curator.test.TestingServer; -import org.apache.curator.test.Timing; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; -import org.joda.time.Period; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; -import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; public class OverlordResourceTest { - private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); - - private TestingServer server; - private Timing timing; - private CuratorFramework curator; - private TaskMaster taskMaster; - private TaskLockbox taskLockbox; - private TaskStorage taskStorage; - private TaskActionClientFactory taskActionClientFactory; - private CountDownLatch announcementLatch; - private DruidNode druidNode; private OverlordResource overlordResource; - private CountDownLatch[] taskCompletionCountDownLatches; - private CountDownLatch[] runTaskCountDownLatches; - - private void setupServerAndCurator() throws Exception - { - server = new TestingServer(); - timing = new Timing(); - curator = CuratorFrameworkFactory - .builder() - .connectString(server.getConnectString()) - .sessionTimeoutMs(timing.session()) - .connectionTimeoutMs(timing.connection()) - .retryPolicy(new RetryOneTime(1)) - .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) - .build(); - } - - private void tearDownServerAndCurator() - { - CloseQuietly.close(curator); - CloseQuietly.close(server); - } + private TaskMaster taskMaster; + private TaskStorageQueryAdapter tsqa; + private HttpServletRequest req; + private TaskRunner taskRunner; @Before public void setUp() throws Exception { - taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); - taskLockbox.syncFromStorage(); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); + taskRunner = EasyMock.createMock(TaskRunner.class); + taskMaster = EasyMock.createStrictMock(TaskMaster.class); + tsqa = EasyMock.createStrictMock(TaskStorageQueryAdapter.class); + req = EasyMock.createStrictMock(HttpServletRequest.class); - // for second Noop Task directly added to deep storage. - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); + EasyMock.expect(taskMaster.getTaskRunner()).andReturn( + Optional.of(taskRunner) + ).anyTimes(); - taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); - EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) - .andReturn(null).anyTimes(); - EasyMock.replay(taskLockbox, taskActionClientFactory); - - taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); - runTaskCountDownLatches = new CountDownLatch[2]; - runTaskCountDownLatches[0] = new CountDownLatch(1); - runTaskCountDownLatches[1] = new CountDownLatch(1); - taskCompletionCountDownLatches = new CountDownLatch[2]; - taskCompletionCountDownLatches[0] = new CountDownLatch(1); - taskCompletionCountDownLatches[1] = new CountDownLatch(1); - announcementLatch = new CountDownLatch(1); - IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); - setupServerAndCurator(); - curator.start(); - curator.blockUntilConnected(); - curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); - druidNode = new DruidNode("hey", "what", 1234); - ServiceEmitter serviceEmitter = new NoopServiceEmitter(); - taskMaster = new TaskMaster( - new TaskQueueConfig(null, new Period(1), null, new Period(10)), - taskLockbox, - taskStorage, - taskActionClientFactory, - druidNode, - indexerZkConfig, - new TaskRunnerFactory() - { - @Override - public MockTaskRunner build() - { - return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); - } - }, - curator, - new NoopServiceAnnouncer() - { - @Override - public void announce(DruidNode node) - { - announcementLatch.countDown(); - } - }, - serviceEmitter - ); - EmittingLogger.registerEmitter(serviceEmitter); - } - - @Test(timeout = 2000L) - public void testOverlordResource() throws Exception - { - // basic task master lifecycle test - taskMaster.start(); - announcementLatch.await(); - while (!taskMaster.isLeading()) { - // I believe the control will never reach here and thread will never sleep but just to be on safe side - Thread.sleep(10); - } - Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); - // Test Overlord resource stuff - overlordResource = new OverlordResource(taskMaster, new TaskStorageQueryAdapter(taskStorage), null, null, null); - Response response = overlordResource.getLeader(); - Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); - - final String taskId_0 = "0"; - NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); - response = overlordResource.taskPost(task_0); - Assert.assertEquals(200, response.getStatus()); - Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); - - // Duplicate task - should fail - response = overlordResource.taskPost(task_0); - Assert.assertEquals(400, response.getStatus()); - - // Task payload for task_0 should be present in taskStorage - response = overlordResource.getTaskPayload(taskId_0); - Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); - - // Task not present in taskStorage - should fail - response = overlordResource.getTaskPayload("whatever"); - Assert.assertEquals(404, response.getStatus()); - - // Task status of the submitted task should be running - response = overlordResource.getTaskStatus(taskId_0); - Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); - Assert.assertEquals( - TaskStatus.running(taskId_0).getStatusCode(), - ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + overlordResource = new OverlordResource( + taskMaster, + tsqa, + null, + null, + null, + new AuthConfig(true) ); - // Simulate completion of task_0 - taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); - // Wait for taskQueue to handle success status of task_0 - waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); - - // Manually insert task in taskStorage - // Verifies sync from storage - final String taskId_1 = "1"; - NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); - taskStorage.insert(task_1, TaskStatus.running(taskId_1)); - // Wait for task runner to run task_1 - runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); - - response = overlordResource.getRunningTasks(); - // 1 task that was manually inserted should be in running state - Assert.assertEquals(1, (((List) response.getEntity()).size())); - final OverlordResource.TaskResponseObject taskResponseObject = ((List) response - .getEntity()).get(0); - Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); - Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); - - // Simulate completion of task_1 - taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); - // Wait for taskQueue to handle success status of task_1 - waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); - - // should return number of tasks which are not in running state - response = overlordResource.getCompleteTasks(); - Assert.assertEquals(2, (((List) response.getEntity()).size())); - taskMaster.stop(); - Assert.assertFalse(taskMaster.isLeading()); - EasyMock.verify(taskLockbox, taskActionClientFactory); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ); } - /* Wait until the task with given taskId has the given Task Status - * These method will not timeout until the condition is met so calling method should ensure timeout - * This method also assumes that the task with given taskId is present - * */ - private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + @Test + public void testSecuredGetWaitingTask() throws Exception { - while (true) { - Response response = overlordResource.getTaskStatus(taskId); - if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { - break; - } - Thread.sleep(10); - } + EasyMock.expect(tsqa.getActiveTasks()).andReturn( + ImmutableList.of( + getTaskWithIdAndDatasource("id_1", "allow"), + getTaskWithIdAndDatasource("id_2", "allow"), + getTaskWithIdAndDatasource("id_3", "deny"), + getTaskWithIdAndDatasource("id_4", "deny") + ) + ).once(); + + EasyMock.>expect(taskRunner.getKnownTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem("id_1", null), + new MockTaskRunnerWorkItem("id_4", null) + ) + ); + + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + + List responseObjects = (List) overlordResource.getWaitingTasks(req) + .getEntity(); + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals("id_2", responseObjects.get(0).toJson().get("id")); + } + + @Test + public void testSecuredGetCompleteTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2", "id_3"); + EasyMock.expect(tsqa.getRecentlyFinishedTaskStatuses()).andReturn( + Lists.transform( + tasksIds, + new Function() + { + @Override + public TaskStatus apply(String input) + { + return TaskStatus.success(input); + } + } + ) + ).once(); + + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(2))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(2), "allow")) + ).once(); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + + List responseObjects = (List) overlordResource.getCompleteTasks(req) + .getEntity(); + + Assert.assertEquals(2, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); + Assert.assertEquals(tasksIds.get(2), responseObjects.get(1).toJson().get("id")); + } + + @Test + public void testSecuredGetRunningTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2"); + EasyMock.>expect(taskRunner.getRunningTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem(tasksIds.get(0), null), + new MockTaskRunnerWorkItem(tasksIds.get(1), null) + ) + ); + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); + + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + + List responseObjects = (List) overlordResource.getRunningTasks(req) + .getEntity(); + + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); + } + + @Test + public void testSecuredTaskPost() + { + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + Task task = NoopTask.create(); + Response response = overlordResource.taskPost(task, req); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); } @After - public void tearDown() throws Exception + public void tearDown() { - tearDownServerAndCurator(); + EasyMock.verify(taskRunner, taskMaster, tsqa, req); } - public static class MockTaskRunner implements TaskRunner + private Task getTaskWithIdAndDatasource(String id, String datasource) { - private CountDownLatch[] completionLatches; - private CountDownLatch[] runLatches; - private ConcurrentHashMap taskRunnerWorkItems; - private List runningTasks; - private final AtomicBoolean started = new AtomicBoolean(false); - - public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) + return new AbstractTask(id, datasource, null) { - this.runLatches = runLatches; - this.completionLatches = completionLatches; - this.taskRunnerWorkItems = new ConcurrentHashMap<>(); - this.runningTasks = new ArrayList<>(); - } - - @Override - public List>> restore() - { - return ImmutableList.of(); - } - - public void registerListener(TaskRunnerListener listener, Executor executor) - { - // Overlord doesn't call this method - throw new UnsupportedOperationException(); - } - - @Override - public synchronized ListenableFuture run(final Task task) - { - final String taskId = task.getId(); - ListenableFuture future = MoreExecutors.listeningDecorator( - Execs.singleThreaded( - "noop_test_task_exec_%s" - ) - ).submit( - new Callable() - { - @Override - public TaskStatus call() throws Exception - { - // adding of task to list of runningTasks should be done before count down as - // getRunningTasks may not include the task for which latch has been counted down - // Count down to let know that task is actually running - // this is equivalent of getting process holder to run task in ForkingTaskRunner - runningTasks.add(taskId); - runLatches[Integer.parseInt(taskId)].countDown(); - // Wait for completion count down - completionLatches[Integer.parseInt(taskId)].await(); - taskRunnerWorkItems.remove(taskId); - runningTasks.remove(taskId); - return TaskStatus.success(taskId); - } - } - ); - TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + @Override + public String getType() { - @Override - public TaskLocation getLocation() - { - return TASK_LOCATION; - } - }; - taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); - return future; + return null; + } + + @Override + public boolean isReady(TaskActionClient taskActionClient) throws Exception + { + return false; + } + + @Override + public TaskStatus run(TaskToolbox toolbox) throws Exception + { + return null; + } + }; + } + + private static class MockTaskRunnerWorkItem extends TaskRunnerWorkItem + { + public MockTaskRunnerWorkItem( + String taskId, + ListenableFuture result + ) + { + super(taskId, result); } @Override - public void shutdown(String taskid) {} - - @Override - public synchronized Collection getRunningTasks() + public TaskLocation getLocation() { - List runningTaskList = Lists.transform( - runningTasks, - new Function() - { - @Nullable - @Override - public TaskRunnerWorkItem apply(String input) - { - return taskRunnerWorkItems.get(input); - } - } - ); - return runningTaskList; - } - - @Override - public Collection getPendingTasks() - { - return ImmutableList.of(); - } - - @Override - public Collection getKnownTasks() - { - return taskRunnerWorkItems.values(); - } - - @Override - public Optional getScalingStats() - { - return Optional.absent(); - } - - @Override - public void start() - { - started.set(true); - } - - @Override - public void stop() - { - started.set(false); + return null; } } + } diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java new file mode 100644 index 00000000000..16df2895f32 --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java @@ -0,0 +1,413 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http; + +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.metamx.common.Pair; +import com.metamx.common.guava.CloseQuietly; +import com.metamx.emitter.EmittingLogger; +import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; +import io.druid.curator.PotentiallyGzippedCompressionProvider; +import io.druid.curator.discovery.NoopServiceAnnouncer; +import io.druid.indexing.common.TaskLocation; +import io.druid.indexing.common.TaskStatus; +import io.druid.indexing.common.actions.TaskActionClientFactory; +import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.HeapMemoryTaskStorage; +import io.druid.indexing.overlord.TaskLockbox; +import io.druid.indexing.overlord.TaskMaster; +import io.druid.indexing.overlord.TaskRunner; +import io.druid.indexing.overlord.TaskRunnerFactory; +import io.druid.indexing.overlord.TaskRunnerListener; +import io.druid.indexing.overlord.TaskRunnerWorkItem; +import io.druid.indexing.overlord.TaskStorage; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.config.TaskQueueConfig; +import io.druid.server.DruidNode; +import io.druid.server.initialization.IndexerZkConfig; +import io.druid.server.initialization.ZkPathsConfig; +import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.AuthConfig; +import org.apache.curator.framework.CuratorFramework; +import org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.curator.retry.RetryOneTime; +import org.apache.curator.test.TestingServer; +import org.apache.curator.test.Timing; +import org.easymock.EasyMock; +import org.joda.time.Period; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; + +public class OverlordTest +{ + private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); + + private TestingServer server; + private Timing timing; + private CuratorFramework curator; + private TaskMaster taskMaster; + private TaskLockbox taskLockbox; + private TaskStorage taskStorage; + private TaskActionClientFactory taskActionClientFactory; + private CountDownLatch announcementLatch; + private DruidNode druidNode; + private OverlordResource overlordResource; + private CountDownLatch[] taskCompletionCountDownLatches; + private CountDownLatch[] runTaskCountDownLatches; + private HttpServletRequest req; + + private void setupServerAndCurator() throws Exception + { + server = new TestingServer(); + timing = new Timing(); + curator = CuratorFrameworkFactory + .builder() + .connectString(server.getConnectString()) + .sessionTimeoutMs(timing.session()) + .connectionTimeoutMs(timing.connection()) + .retryPolicy(new RetryOneTime(1)) + .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) + .build(); + } + + private void tearDownServerAndCurator() + { + CloseQuietly.close(curator); + CloseQuietly.close(server); + } + + @Before + public void setUp() throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); + taskLockbox.syncFromStorage(); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + // for second Noop Task directly added to deep storage. + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); + EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) + .andReturn(null).anyTimes(); + EasyMock.replay(taskLockbox, taskActionClientFactory); + + taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); + runTaskCountDownLatches = new CountDownLatch[2]; + runTaskCountDownLatches[0] = new CountDownLatch(1); + runTaskCountDownLatches[1] = new CountDownLatch(1); + taskCompletionCountDownLatches = new CountDownLatch[2]; + taskCompletionCountDownLatches[0] = new CountDownLatch(1); + taskCompletionCountDownLatches[1] = new CountDownLatch(1); + announcementLatch = new CountDownLatch(1); + IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); + setupServerAndCurator(); + curator.start(); + curator.blockUntilConnected(); + curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); + druidNode = new DruidNode("hey", "what", 1234); + ServiceEmitter serviceEmitter = new NoopServiceEmitter(); + taskMaster = new TaskMaster( + new TaskQueueConfig(null, new Period(1), null, new Period(10)), + taskLockbox, + taskStorage, + taskActionClientFactory, + druidNode, + indexerZkConfig, + new TaskRunnerFactory() + { + @Override + public MockTaskRunner build() + { + return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); + } + }, + curator, + new NoopServiceAnnouncer() + { + @Override + public void announce(DruidNode node) + { + announcementLatch.countDown(); + } + }, + serviceEmitter + ); + EmittingLogger.registerEmitter(serviceEmitter); + } + + @Test(timeout = 2000L) + public void testOverlordRun() throws Exception + { + // basic task master lifecycle test + taskMaster.start(); + announcementLatch.await(); + while (!taskMaster.isLeading()) { + // I believe the control will never reach here and thread will never sleep but just to be on safe side + Thread.sleep(10); + } + Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); + // Test Overlord resource stuff + overlordResource = new OverlordResource( + taskMaster, + new TaskStorageQueryAdapter(taskStorage), + null, + null, + null, + new AuthConfig() + ); + Response response = overlordResource.getLeader(); + Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); + + final String taskId_0 = "0"; + NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); + + // Duplicate task - should fail + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(400, response.getStatus()); + + // Task payload for task_0 should be present in taskStorage + response = overlordResource.getTaskPayload(taskId_0); + Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); + + // Task not present in taskStorage - should fail + response = overlordResource.getTaskPayload("whatever"); + Assert.assertEquals(404, response.getStatus()); + + // Task status of the submitted task should be running + response = overlordResource.getTaskStatus(taskId_0); + Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); + Assert.assertEquals( + TaskStatus.running(taskId_0).getStatusCode(), + ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + ); + + // Simulate completion of task_0 + taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); + // Wait for taskQueue to handle success status of task_0 + waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); + + // Manually insert task in taskStorage + // Verifies sync from storage + final String taskId_1 = "1"; + NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); + taskStorage.insert(task_1, TaskStatus.running(taskId_1)); + // Wait for task runner to run task_1 + runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); + + response = overlordResource.getRunningTasks(req); + // 1 task that was manually inserted should be in running state + Assert.assertEquals(1, (((List) response.getEntity()).size())); + final OverlordResource.TaskResponseObject taskResponseObject = ((List) response + .getEntity()).get(0); + Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); + Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); + + // Simulate completion of task_1 + taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); + // Wait for taskQueue to handle success status of task_1 + waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); + + // should return number of tasks which are not in running state + response = overlordResource.getCompleteTasks(req); + Assert.assertEquals(2, (((List) response.getEntity()).size())); + taskMaster.stop(); + Assert.assertFalse(taskMaster.isLeading()); + EasyMock.verify(taskLockbox, taskActionClientFactory); + } + + /* Wait until the task with given taskId has the given Task Status + * These method will not timeout until the condition is met so calling method should ensure timeout + * This method also assumes that the task with given taskId is present + * */ + private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + { + while (true) { + Response response = overlordResource.getTaskStatus(taskId); + if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { + break; + } + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception + { + tearDownServerAndCurator(); + } + + public static class MockTaskRunner implements TaskRunner + { + private CountDownLatch[] completionLatches; + private CountDownLatch[] runLatches; + private ConcurrentHashMap taskRunnerWorkItems; + private List runningTasks; + + public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) + { + this.runLatches = runLatches; + this.completionLatches = completionLatches; + this.taskRunnerWorkItems = new ConcurrentHashMap<>(); + this.runningTasks = new ArrayList<>(); + } + + @Override + public List>> restore() + { + return ImmutableList.of(); + } + + @Override + public void registerListener(TaskRunnerListener listener, Executor executor) + { + // Overlord doesn't call this method + throw new UnsupportedOperationException(); + } + + @Override + public void stop() + { + // Do nothing + } + + @Override + public synchronized ListenableFuture run(final Task task) + { + final String taskId = task.getId(); + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded( + "noop_test_task_exec_%s" + ) + ).submit( + new Callable() + { + @Override + public TaskStatus call() throws Exception + { + // adding of task to list of runningTasks should be done before count down as + // getRunningTasks may not include the task for which latch has been counted down + // Count down to let know that task is actually running + // this is equivalent of getting process holder to run task in ForkingTaskRunner + runningTasks.add(taskId); + if (runLatches != null) { + runLatches[Integer.parseInt(taskId)].countDown(); + } + // Wait for completion count down + if (completionLatches != null) { + completionLatches[Integer.parseInt(taskId)].await(); + } + taskRunnerWorkItems.remove(taskId); + runningTasks.remove(taskId); + return TaskStatus.success(taskId); + } + } + ); + TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + { + @Override + public TaskLocation getLocation() + { + return TASK_LOCATION; + } + }; + taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); + return future; + } + + @Override + public void shutdown(String taskid) {} + + @Override + public synchronized Collection getRunningTasks() + { + return Lists.transform( + runningTasks, + new Function() + { + @Nullable + @Override + public TaskRunnerWorkItem apply(String input) + { + return taskRunnerWorkItems.get(input); + } + } + ); + } + + @Override + public Collection getPendingTasks() + { + return ImmutableList.of(); + } + + @Override + public Collection getKnownTasks() + { + return taskRunnerWorkItems.values(); + } + + @Override + public Optional getScalingStats() + { + return Optional.absent(); + } + + @Override + public void start() + { + //Do nothing + } + } +} diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java new file mode 100644 index 00000000000..a0aa98458cf --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.http.OverlordResource; +import io.druid.indexing.worker.http.WorkerResource; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.http.security.ResourceFilterTestHelper; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(OverlordResource.class, ImmutableList.>of(TaskStorageQueryAdapter.class)), + getRequestPaths(WorkerResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + private final Task noopTask = new NoopTask(null, 0, 0, null, null, null); + + private static boolean mockedOnce; + private TaskStorageQueryAdapter tsqa; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + if (resourceFilter instanceof TaskResourceFilter && !mockedOnce) { + // Since we are creating the mocked tsqa object only once and getting that object from Guice here therefore + // if the mockedOnce check is not done then we will call EasyMock.expect and EasyMock.replay on the mocked object + // multiple times and it will throw exceptions + tsqa = injector.getInstance(TaskStorageQueryAdapter.class); + EasyMock.expect(tsqa.getTask(EasyMock.anyString())).andReturn(Optional.of(noopTask)).anyTimes(); + EasyMock.replay(tsqa); + mockedOnce = true; + } + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + // As request object is a strict mock the ordering of expected calls matters + // therefore adding the expectation below again as getEntity is called before getMethod + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + resourceFilter.getRequestFilter().filter(request); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + EasyMock.expect(request.getPath()).andReturn(badRequestPath).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + } + + @After + public void tearDown() + { + EasyMock.verify(req, request, authorizationInfo); + if (tsqa != null) { + EasyMock.verify(tsqa); + } + } + +} diff --git a/server/src/main/java/io/druid/guice/security/DruidAuthModule.java b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java new file mode 100644 index 00000000000..e89c8ca2367 --- /dev/null +++ b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.guice.security; + +import com.fasterxml.jackson.databind.Module; +import com.google.inject.Binder; +import io.druid.guice.JsonConfigProvider; +import io.druid.initialization.DruidModule; +import io.druid.server.security.AuthConfig; + +import java.util.Collections; +import java.util.List; + +public class DruidAuthModule implements DruidModule +{ + @Override + public List getJacksonModules() + { + return Collections.emptyList(); + } + + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, "druid.auth", AuthConfig.class); + } +} diff --git a/server/src/main/java/io/druid/initialization/Initialization.java b/server/src/main/java/io/druid/initialization/Initialization.java index 0bfc8c0c7bd..0752036575e 100644 --- a/server/src/main/java/io/druid/initialization/Initialization.java +++ b/server/src/main/java/io/druid/initialization/Initialization.java @@ -57,6 +57,7 @@ import io.druid.guice.annotations.Client; import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Smile; import io.druid.guice.http.HttpClientModule; +import io.druid.guice.security.DruidAuthModule; import io.druid.metadata.storage.derby.DerbyMetadataStorageDruidModule; import io.druid.server.initialization.EmitterModule; import io.druid.server.initialization.jetty.JettyServerModule; @@ -318,7 +319,9 @@ public class Initialization { final ModuleList defaultModules = new ModuleList(baseInjector); defaultModules.addModules( + // New modules should be added after Log4jShutterDownerModule new Log4jShutterDownerModule(), + new DruidAuthModule(), new LifecycleModule(), EmitterModule.class, HttpClientModule.global(), diff --git a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java index 415c5c9101b..ff6cb39e8e2 100644 --- a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java +++ b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java @@ -143,7 +143,7 @@ public class EventReceiverFirehoseFactory implements FirehoseFactory(bufferSize); + this.buffer = new ArrayBlockingQueue<>(bufferSize); this.parser = parser; } diff --git a/server/src/main/java/io/druid/server/ClientInfoResource.java b/server/src/main/java/io/druid/server/ClientInfoResource.java index e3a653fe371..9b800b891d7 100644 --- a/server/src/main/java/io/druid/server/ClientInfoResource.java +++ b/server/src/main/java/io/druid/server/ClientInfoResource.java @@ -19,13 +19,17 @@ package io.druid.server; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.FilteredServerInventoryView; @@ -34,6 +38,13 @@ import io.druid.client.TimelineServerView; import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -41,14 +52,17 @@ import io.druid.timeline.partition.PartitionHolder; import org.joda.time.DateTime; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -67,18 +81,21 @@ public class ClientInfoResource private FilteredServerInventoryView serverInventoryView; private TimelineServerView timelineServerView; private SegmentMetadataQueryConfig segmentMetadataQueryConfig; + private final AuthConfig authConfig; @Inject public ClientInfoResource( FilteredServerInventoryView serverInventoryView, TimelineServerView timelineServerView, - SegmentMetadataQueryConfig segmentMetadataQueryConfig + SegmentMetadataQueryConfig segmentMetadataQueryConfig, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.timelineServerView = timelineServerView; this.segmentMetadataQueryConfig = (segmentMetadataQueryConfig == null) ? new SegmentMetadataQueryConfig() : segmentMetadataQueryConfig; + this.authConfig = authConfig; } private Map> getSegmentsForDatasources() @@ -98,14 +115,41 @@ public class ClientInfoResource @GET @Produces(MediaType.APPLICATION_JSON) - public Iterable getDataSources() + public Iterable getDataSources(@Context final HttpServletRequest request) { - return getSegmentsForDatasources().keySet(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + getSegmentsForDatasources().keySet(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } else { + return getSegmentsForDatasources().keySet(); + } } @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Map getDatasource( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval, @@ -193,6 +237,7 @@ public class ClientInfoResource @GET @Path("/{dataSourceName}/dimensions") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceDimensions( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval @@ -225,6 +270,7 @@ public class ClientInfoResource @GET @Path("/{dataSourceName}/metrics") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceMetrics( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval diff --git a/server/src/main/java/io/druid/server/QueryManager.java b/server/src/main/java/io/druid/server/QueryManager.java index 3e2b3b51079..49252c8c0ad 100644 --- a/server/src/main/java/io/druid/server/QueryManager.java +++ b/server/src/main/java/io/druid/server/QueryManager.java @@ -27,20 +27,28 @@ import com.google.common.util.concurrent.MoreExecutors; import io.druid.query.Query; import io.druid.query.QueryWatcher; +import java.util.List; import java.util.Set; public class QueryManager implements QueryWatcher { - final SetMultimap queries; + + private final SetMultimap queries; + private final SetMultimap queryDatasources; public QueryManager() { this.queries = Multimaps.synchronizedSetMultimap( HashMultimap.create() ); + this.queryDatasources = Multimaps.synchronizedSetMultimap( + HashMultimap.create() + ); } - public boolean cancelQuery(String id) { + public boolean cancelQuery(String id) + { + queryDatasources.removeAll(id); Set futures = queries.removeAll(id); boolean success = true; for (ListenableFuture future : futures) { @@ -52,7 +60,9 @@ public class QueryManager implements QueryWatcher public void registerQuery(Query query, final ListenableFuture future) { final String id = query.getId(); + final List datasources = query.getDataSource().getNames(); queries.put(id, future); + queryDatasources.putAll(id, datasources); future.addListener( new Runnable() { @@ -60,9 +70,17 @@ public class QueryManager implements QueryWatcher public void run() { queries.remove(id, future); + for (String datasource : datasources) { + queryDatasources.remove(id, datasource); + } } }, MoreExecutors.sameThreadExecutor() ); } + + public Set getQueryDatasources(final String queryId) + { + return queryDatasources.get(queryId); + } } diff --git a/server/src/main/java/io/druid/server/QueryResource.java b/server/src/main/java/io/druid/server/QueryResource.java index 0b9ac2b0fa5..63e37e338da 100644 --- a/server/src/main/java/io/druid/server/QueryResource.java +++ b/server/src/main/java/io/druid/server/QueryResource.java @@ -22,11 +22,13 @@ package io.druid.server; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.MapMaker; import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; +import com.metamx.common.ISE; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.common.guava.Yielder; @@ -42,6 +44,12 @@ import io.druid.query.QueryInterruptedException; import io.druid.query.QuerySegmentWalker; import io.druid.server.initialization.ServerConfig; import io.druid.server.log.RequestLogger; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import org.joda.time.DateTime; import javax.servlet.http.HttpServletRequest; @@ -61,6 +69,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Map; +import java.util.Set; import java.util.UUID; /** @@ -81,6 +90,7 @@ public class QueryResource private final ServiceEmitter emitter; private final RequestLogger requestLogger; private final QueryManager queryManager; + private final AuthConfig authConfig; @Inject public QueryResource( @@ -90,7 +100,8 @@ public class QueryResource QuerySegmentWalker texasRanger, ServiceEmitter emitter, RequestLogger requestLogger, - QueryManager queryManager + QueryManager queryManager, + AuthConfig authConfig ) { this.config = config; @@ -100,16 +111,39 @@ public class QueryResource this.emitter = emitter; this.requestLogger = requestLogger; this.queryManager = queryManager; + this.authConfig = authConfig; } @DELETE @Path("{id}") @Produces(MediaType.APPLICATION_JSON) - public Response getServer(@PathParam("id") String queryId) + public Response getServer(@PathParam("id") String queryId, @Context final HttpServletRequest req) { if (log.isDebugEnabled()) { log.debug("Received cancel request for query [%s]", queryId); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Set datasources = queryManager.getQueryDatasources(queryId); + if (datasources == null) { + log.warn("QueryId [%s] not registered with QueryManager, cannot cancel", queryId); + } else { + for (String dataSource : datasources) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } + } queryManager.cancelQuery(queryId); return Response.status(Response.Status.ACCEPTED).build(); } @@ -120,7 +154,7 @@ public class QueryResource public Response doPost( InputStream in, @QueryParam("pretty") String pretty, - @Context final HttpServletRequest req // used only to get request content-type and remote address + @Context final HttpServletRequest req // used to get request content-type, remote address and AuthorizationInfo ) throws IOException { final long start = System.currentTimeMillis(); @@ -160,6 +194,24 @@ public class QueryResource log.debug("Got query [%s]", query); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (authorizationInfo != null) { + for (String dataSource : query.getDataSource().getNames()) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.READ + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } else { + throw new ISE("WTF?! Security is enabled but no authorization info found in the request"); + } + } + final Map responseContext = new MapMaker().makeMap(); final Sequence res = query.run(texasRanger, responseContext); final Sequence results; diff --git a/server/src/main/java/io/druid/server/StatusResource.java b/server/src/main/java/io/druid/server/StatusResource.java index f5012daafec..edbd65b4fdb 100644 --- a/server/src/main/java/io/druid/server/StatusResource.java +++ b/server/src/main/java/io/druid/server/StatusResource.java @@ -21,8 +21,10 @@ package io.druid.server; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.initialization.DruidModule; import io.druid.initialization.Initialization; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -35,6 +37,7 @@ import java.util.List; /** */ @Path("/status") +@ResourceFilters(StateResourceFilter.class) public class StatusResource { @GET diff --git a/server/src/main/java/io/druid/server/http/BrokerResource.java b/server/src/main/java/io/druid/server/http/BrokerResource.java index 7e9701a39b7..7adc968e402 100644 --- a/server/src/main/java/io/druid/server/http/BrokerResource.java +++ b/server/src/main/java/io/druid/server/http/BrokerResource.java @@ -21,7 +21,9 @@ package io.druid.server.http; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.BrokerServerView; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -30,6 +32,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; @Path("/druid/broker/v1") +@ResourceFilters(StateResourceFilter.class) public class BrokerResource { private final BrokerServerView brokerServerView; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java index 0d955b915bf..c4e572a15a5 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java @@ -19,15 +19,15 @@ package io.druid.server.http; +import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; import io.druid.server.coordinator.CoordinatorDynamicConfig; - +import io.druid.server.http.security.ConfigResourceFilter; import org.joda.time.Interval; -import com.google.common.collect.ImmutableMap; - import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; @@ -45,6 +45,7 @@ import javax.ws.rs.core.Response; /** */ @Path("/druid/coordinator/v1/config") +@ResourceFilters(ConfigResourceFilter.class) public class CoordinatorDynamicConfigsResource { private final JacksonConfigManager manager; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorResource.java b/server/src/main/java/io/druid/server/http/CoordinatorResource.java index ac13e9ec22f..20f6805dae1 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorResource.java @@ -24,8 +24,10 @@ import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordinator.DruidCoordinator; import io.druid.server.coordinator.LoadQueuePeon; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -38,6 +40,7 @@ import javax.ws.rs.core.Response; /** */ @Path("/druid/coordinator/v1") +@ResourceFilters(StateResourceFilter.class) public class CoordinatorResource { private final DruidCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/DatasourcesResource.java b/server/src/main/java/io/druid/server/http/DatasourcesResource.java index 8aa035f9669..274e03492c5 100644 --- a/server/src/main/java/io/druid/server/http/DatasourcesResource.java +++ b/server/src/main/java/io/druid/server/http/DatasourcesResource.java @@ -31,6 +31,7 @@ import com.metamx.common.Pair; import com.metamx.common.guava.Comparators; import com.metamx.common.guava.FunctionalIterable; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.CoordinatorServerView; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; @@ -39,6 +40,9 @@ import io.druid.client.SegmentLoadInfo; import io.druid.client.indexing.IndexingServiceClient; import io.druid.metadata.MetadataSegmentManager; import io.druid.query.TableDataSource; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -47,6 +51,7 @@ import org.joda.time.DateTime; import org.joda.time.Interval; import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -55,6 +60,7 @@ import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -73,28 +79,38 @@ public class DatasourcesResource private final CoordinatorServerView serverInventoryView; private final MetadataSegmentManager databaseSegmentManager; private final IndexingServiceClient indexingServiceClient; + private final AuthConfig authConfig; @Inject public DatasourcesResource( CoordinatorServerView serverInventoryView, MetadataSegmentManager databaseSegmentManager, - @Nullable IndexingServiceClient indexingServiceClient + @Nullable IndexingServiceClient indexingServiceClient, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.databaseSegmentManager = databaseSegmentManager; this.indexingServiceClient = indexingServiceClient; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) public Response getQueryableDataSources( @QueryParam("full") String full, - @QueryParam("simple") String simple + @QueryParam("simple") String simple, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.ok(); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); + if (full != null) { return builder.entity(datasources).build(); } else if (simple != null) { @@ -135,12 +151,14 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getTheDataSource( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full ) { DruidDataSource dataSource = getDataSource(dataSourceName); + if (dataSource == null) { return Response.noContent().build(); } @@ -155,6 +173,7 @@ public class DatasourcesResource @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -175,6 +194,7 @@ public class DatasourcesResource @DELETE @Deprecated @Path("/{dataSourceName}") + @ResourceFilters(DatasourceResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response deleteDataSource( @PathParam("dataSourceName") final String dataSourceName, @@ -253,6 +273,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/intervals") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceIntervals( @PathParam("dataSourceName") String dataSourceName, @QueryParam("simple") String simple, @@ -313,6 +334,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/intervals/{interval}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, @@ -380,6 +402,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -413,6 +436,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -436,6 +460,7 @@ public class DatasourcesResource @DELETE @Path("/{dataSourceName}/segments/{segmentId}") + @ResourceFilters(DatasourceResourceFilter.class) public Response deleteDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -451,6 +476,7 @@ public class DatasourcesResource @POST @Path("/{dataSourceName}/segments/{segmentId}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -466,6 +492,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/tiers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceTiers( @PathParam("dataSourceName") String dataSourceName ) @@ -624,6 +651,7 @@ public class DatasourcesResource @GET @Path("/{dataSourceName}/intervals/{interval}/serverview") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, diff --git a/server/src/main/java/io/druid/server/http/HistoricalResource.java b/server/src/main/java/io/druid/server/http/HistoricalResource.java index 4680cf29c6c..bc77ce0fc05 100644 --- a/server/src/main/java/io/druid/server/http/HistoricalResource.java +++ b/server/src/main/java/io/druid/server/http/HistoricalResource.java @@ -20,7 +20,9 @@ package io.druid.server.http; import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordination.ZkCoordinator; +import io.druid.server.http.security.StateResourceFilter; import javax.inject.Inject; import javax.ws.rs.GET; @@ -30,6 +32,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; @Path("/druid/historical/v1") +@ResourceFilters(StateResourceFilter.class) public class HistoricalResource { private final ZkCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/IntervalsResource.java b/server/src/main/java/io/druid/server/http/IntervalsResource.java index 103330fc50a..29c8a1f4f86 100644 --- a/server/src/main/java/io/druid/server/http/IntervalsResource.java +++ b/server/src/main/java/io/druid/server/http/IntervalsResource.java @@ -25,14 +25,18 @@ import com.metamx.common.MapUtils; import com.metamx.common.guava.Comparators; import io.druid.client.DruidDataSource; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -45,35 +49,43 @@ import java.util.Set; public class IntervalsResource { private final InventoryView serverInventoryView; + private final AuthConfig authConfig; @Inject public IntervalsResource( - InventoryView serverInventoryView + InventoryView serverInventoryView, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) - public Response getIntervals() + public Response getIntervals(@Context final HttpServletRequest req) { - final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); - final Map>> retVal = Maps.newTreeMap(comparator); - for (DruidDataSource dataSource : datasources) { - for (DataSegment dataSegment : dataSource.getSegments()) { - Map> interval = retVal.get(dataSegment.getInterval()); - if (interval == null) { - Map> tmp = Maps.newHashMap(); - retVal.put(dataSegment.getInterval(), tmp); - } - setProperties(retVal, dataSource, dataSegment); + final Map>> retVal = Maps.newTreeMap(comparator); + for (DruidDataSource dataSource : datasources) { + for (DataSegment dataSegment : dataSource.getSegments()) { + Map> interval = retVal.get(dataSegment.getInterval()); + if (interval == null) { + Map> tmp = Maps.newHashMap(); + retVal.put(dataSegment.getInterval(), tmp); } + setProperties(retVal, dataSource, dataSegment); } + } - return Response.ok(retVal).build(); + return Response.ok(retVal).build(); } @GET @@ -82,13 +94,20 @@ public class IntervalsResource public Response getSpecificIntervals( @PathParam("interval") String interval, @QueryParam("simple") String simple, - @QueryParam("full") String full + @QueryParam("full") String full, + @Context final HttpServletRequest req ) { final Interval theInterval = new Interval(interval.replace("_", "/")); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + if (full != null) { final Map>> retVal = Maps.newTreeMap(comparator); for (DruidDataSource dataSource : datasources) { diff --git a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java index df39f5e70c1..62cb5109ead 100644 --- a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java +++ b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java @@ -20,18 +20,30 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.metamx.common.ISE; +import com.metamx.common.Pair; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.TreeSet; -public class InventoryViewUtils { +public class InventoryViewUtils +{ public static Set getDataSources(InventoryView serverInventoryView) { @@ -64,4 +76,38 @@ public class InventoryViewUtils { ); return dataSources; } + + public static Set getSecuredDataSources( + InventoryView inventoryView, + final AuthorizationInfo authorizationInfo + ) + { + if (authorizationInfo == null) { + throw new ISE("Invalid to call a secured method with null AuthorizationInfo!!"); + } else { + final Map, Access> resourceAccessMap = new HashMap<>(); + return ImmutableSet.copyOf( + Iterables.filter( + getDataSources(inventoryView), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } + } } diff --git a/server/src/main/java/io/druid/server/http/MetadataResource.java b/server/src/main/java/io/druid/server/http/MetadataResource.java index 294165402f3..e480121b8b9 100644 --- a/server/src/main/java/io/druid/server/http/MetadataResource.java +++ b/server/src/main/java/io/druid/server/http/MetadataResource.java @@ -20,26 +20,42 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.metamx.common.Pair; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.indexing.overlord.IndexerMetadataStorageCoordinator; import io.druid.metadata.MetadataSegmentManager; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** */ @@ -48,15 +64,18 @@ public class MetadataResource { private final MetadataSegmentManager metadataSegmentManager; private final IndexerMetadataStorageCoordinator metadataStorageCoordinator; + private final AuthConfig authConfig; @Inject public MetadataResource( MetadataSegmentManager metadataSegmentManager, - IndexerMetadataStorageCoordinator metadataStorageCoordinator + IndexerMetadataStorageCoordinator metadataStorageCoordinator, + AuthConfig authConfig ) { this.metadataSegmentManager = metadataSegmentManager; this.metadataStorageCoordinator = metadataStorageCoordinator; + this.authConfig = authConfig; } @GET @@ -64,20 +83,88 @@ public class MetadataResource @Produces(MediaType.APPLICATION_JSON) public Response getDatabaseDataSources( @QueryParam("full") String full, - @QueryParam("includeDisabled") String includeDisabled + @QueryParam("includeDisabled") String includeDisabled, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.status(Response.Status.OK); + + final Collection druidDataSources; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (includeDisabled != null) { + return builder.entity( + Collections2.filter( + metadataSegmentManager.getAllDatasourceNames(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + )).build(); + } else { + druidDataSources = + Collections2.filter( + metadataSegmentManager.getInventory(), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + } else { + druidDataSources = metadataSegmentManager.getInventory(); + } + if (includeDisabled != null) { - return builder.entity(metadataSegmentManager.getAllDatasourceNames()).build(); + return builder.entity( + Collections2.transform( + druidDataSources, + new Function() + { + @Override + public String apply(DruidDataSource input) + { + return input.getName(); + } + } + ) + ).build(); } if (full != null) { - return builder.entity(metadataSegmentManager.getInventory()).build(); + return builder.entity(druidDataSources).build(); } List dataSourceNames = Lists.newArrayList( Iterables.transform( - metadataSegmentManager.getInventory(), + druidDataSources, new Function() { @Override @@ -97,6 +184,7 @@ public class MetadataResource @GET @Path("/datasources/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -112,6 +200,7 @@ public class MetadataResource @GET @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -145,13 +234,14 @@ public class MetadataResource @POST @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full, List intervals ) { - List segments = null; + List segments; try { segments = metadataStorageCoordinator.getUsedSegmentsForIntervals(dataSourceName, intervals); } @@ -182,6 +272,7 @@ public class MetadataResource @GET @Path("/datasources/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId diff --git a/server/src/main/java/io/druid/server/http/RulesResource.java b/server/src/main/java/io/druid/server/http/RulesResource.java index fdacb228ea6..1d93d61df7d 100644 --- a/server/src/main/java/io/druid/server/http/RulesResource.java +++ b/server/src/main/java/io/druid/server/http/RulesResource.java @@ -21,13 +21,14 @@ package io.druid.server.http; import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; - +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; import io.druid.server.coordinator.rules.Rule; - +import io.druid.server.http.security.RulesResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import org.joda.time.Interval; import javax.servlet.http.HttpServletRequest; @@ -43,7 +44,6 @@ import javax.ws.rs.QueryParam; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; - import java.util.List; /** @@ -66,6 +66,7 @@ public class RulesResource @GET @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getRules() { return Response.ok(databaseRuleManager.getAllRules()).build(); @@ -74,6 +75,7 @@ public class RulesResource @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full @@ -91,6 +93,7 @@ public class RulesResource @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response setDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, final List rules, @@ -112,6 +115,7 @@ public class RulesResource @GET @Path("/{dataSourceName}/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRuleHistory( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("interval") final String interval, @@ -131,6 +135,7 @@ public class RulesResource @GET @Path("/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getDatasourceRuleHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count diff --git a/server/src/main/java/io/druid/server/http/ServersResource.java b/server/src/main/java/io/druid/server/http/ServersResource.java index 33665fda81d..70308eb8ebb 100644 --- a/server/src/main/java/io/druid/server/http/ServersResource.java +++ b/server/src/main/java/io/druid/server/http/ServersResource.java @@ -25,8 +25,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -41,6 +43,7 @@ import java.util.Map; /** */ @Path("/druid/coordinator/v1/servers") +@ResourceFilters(StateResourceFilter.class) public class ServersResource { private static Map makeSimpleServer(DruidServer input) diff --git a/server/src/main/java/io/druid/server/http/TiersResource.java b/server/src/main/java/io/druid/server/http/TiersResource.java index 6990dae2839..db9189e56e5 100644 --- a/server/src/main/java/io/druid/server/http/TiersResource.java +++ b/server/src/main/java/io/druid/server/http/TiersResource.java @@ -28,9 +28,11 @@ import com.google.common.collect.Sets; import com.google.common.collect.Table; import com.google.inject.Inject; import com.metamx.common.MapUtils; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import org.joda.time.Interval; @@ -47,6 +49,7 @@ import java.util.Set; /** */ @Path("/druid/coordinator/v1/tiers") +@ResourceFilters(StateResourceFilter.class) public class TiersResource { private final InventoryView serverInventoryView; diff --git a/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java new file mode 100644 index 00000000000..a8a1fb4cb4e --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java @@ -0,0 +1,89 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ContainerRequestFilter; +import com.sun.jersey.spi.container.ContainerResponseFilter; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Context; + +public abstract class AbstractResourceFilter implements ResourceFilter, ContainerRequestFilter +{ + //https://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-520005 + @Context + private HttpServletRequest req; + + private final AuthConfig authConfig; + + @Inject + public AbstractResourceFilter(AuthConfig authConfig) + { + this.authConfig = authConfig; + } + + @Override + public ContainerRequestFilter getRequestFilter() + { + return this; + } + + @Override + public ContainerResponseFilter getResponseFilter() + { + return null; + } + + public HttpServletRequest getReq() + { + return req; + } + + public AuthConfig getAuthConfig() + { + return authConfig; + } + + public AbstractResourceFilter setReq(HttpServletRequest req) + { + this.req = req; + return this; + } + + protected Action getAction(ContainerRequest request) + { + Action action; + switch (request.getMethod()) { + case "GET": + case "HEAD": + action = Action.READ; + break; + default: + action = Action.WRITE; + } + return action; + } + + public abstract boolean isApplicable(String requestPath); +} diff --git a/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java new file mode 100644 index 00000000000..61fc28f1626 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster configuration is read or written + * Here are some example paths where this filter is used - + * - druid/worker/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/config + * Note - Currently the resource name for all end points is set to "CONFIG" however if more fine grained access control + * is required the resource name can be set to specific config properties. + */ +public class ConfigResourceFilter extends AbstractResourceFilter +{ + @Inject + public ConfigResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "CONFIG"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.CONFIG), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/config"); + } +} diff --git a/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java new file mode 100644 index 00000000000..ccbeab86600 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java @@ -0,0 +1,110 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "datasources" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/datasources/{dataSourceName}/... + * - druid/coordinator/v1/metadata/datasources/{dataSourceName}/... + * - druid/v2/datasources/{dataSourceName}/... + */ +public class DatasourceResourceFilter extends AbstractResourceFilter +{ + @Inject + public DatasourceResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("datasources"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of( + "druid/coordinator/v1/datasources/", + "druid/coordinator/v1/metadata/datasources/", + "druid/v2/datasources/" + ); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java new file mode 100644 index 00000000000..0e87fab200f --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java @@ -0,0 +1,106 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + + +/** + * Use this ResourceFilter when the datasource information is present after "rules" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/rules/ + * */ + +public class RulesResourceFilter extends AbstractResourceFilter +{ + @Inject + public RulesResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("rules"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/coordinator/v1/rules/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java new file mode 100644 index 00000000000..b4d9d40195f --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java @@ -0,0 +1,97 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster State is read or written + * Here are some example paths where this filter is used - + * - druid/broker/v1 + * - druid/coordinator/v1 + * - druid/historical/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/rules + * - druid/coordinator/v1/tiers + * - druid/worker/v1 + * - druid/coordinator/v1/servers + * - status + * Note - Currently the resource name for all end points is set to "STATE" however if more fine grained access control + * is required the resource name can be set to specific state properties. + */ +public class StateResourceFilter extends AbstractResourceFilter +{ + @Inject + public StateResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "STATE"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.STATE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/broker/v1") || + requestPath.startsWith("druid/coordinator/v1") || + requestPath.startsWith("druid/historical/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/rules") || + requestPath.startsWith("druid/coordinator/v1/tiers") || + requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/coordinator/v1/servers") || + requestPath.startsWith("status"); + } +} diff --git a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java index a0ad9b765b1..66fd4c1a6fd 100644 --- a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java +++ b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java @@ -28,11 +28,9 @@ import com.metamx.metrics.AbstractMonitor; import com.metamx.metrics.KeyedDiff; import com.metamx.metrics.MonitorUtils; import io.druid.query.DruidMetrics; -import io.druid.segment.realtime.firehose.EventReceiverFirehoseFactory; import java.util.Map; import java.util.Properties; -import java.util.concurrent.atomic.AtomicLong; public class EventReceiverFirehoseMonitor extends AbstractMonitor { diff --git a/server/src/main/java/io/druid/server/security/Access.java b/server/src/main/java/io/druid/server/security/Access.java new file mode 100644 index 00000000000..a70e579f3a4 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Access.java @@ -0,0 +1,51 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Access +{ + private final boolean allowed; + private String message; + + public Access(boolean allowed) { + this(allowed, ""); + } + + public Access(boolean allowed, String message) { + this.allowed = allowed; + this.message = message; + } + + public boolean isAllowed() { + return allowed; + } + + public Access setMessage(String message) + { + this.message = message; + return this; + } + + @Override + public String toString() + { + return String.format("Allowed:%s, Message:%s", allowed, message); + } +} diff --git a/server/src/main/java/io/druid/server/security/Action.java b/server/src/main/java/io/druid/server/security/Action.java new file mode 100644 index 00000000000..2b7606b58dd --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Action.java @@ -0,0 +1,26 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum Action +{ + READ, + WRITE +} diff --git a/server/src/main/java/io/druid/server/security/AuthConfig.java b/server/src/main/java/io/druid/server/security/AuthConfig.java new file mode 100644 index 00000000000..8ade4ce6c41 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthConfig.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +public class AuthConfig +{ + /** + * Use this String as the attribute name for the request attribute to pass {@link AuthorizationInfo} + * from the servlet filter to the jersey resource + * */ + public static final String DRUID_AUTH_TOKEN = "Druid-Auth-Token"; + + public AuthConfig() { + this(false); + } + + @JsonCreator + public AuthConfig( + @JsonProperty("enabled") boolean enabled + ){ + this.enabled = enabled; + } + /** + * If druid.auth.enabled is set to true then an implementation of AuthorizationInfo + * must be provided and it must be set as a request attribute possibly inside the servlet filter + * injected in the filter chain using your own extension + * */ + @JsonProperty + private final boolean enabled; + + public boolean isEnabled() + { + return enabled; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AuthConfig that = (AuthConfig) o; + + return enabled == that.enabled; + + } + + @Override + public int hashCode() + { + return (enabled ? 1 : 0); + } + + @Override + public String toString() + { + return "AuthConfig{" + + "enabled=" + enabled + + '}'; + } +} diff --git a/server/src/main/java/io/druid/server/security/AuthorizationInfo.java b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java new file mode 100644 index 00000000000..31097a93547 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +/** + * This interface should be used to store as well as process Authorization Information + * An extension can be used to inject servlet filter which will create objects of this type + * and set it as a request attribute with attribute name as {@link AuthConfig#DRUID_AUTH_TOKEN}. + * In the jersey resources if the authorization is enabled depending on {@link AuthConfig#enabled} + * the {@link #isAuthorized(Resource, Action)} method will be used to perform authorization checks + * */ +public interface AuthorizationInfo +{ + /** + * Perform authorization checks for the given {@link Resource} and {@link Action}. + * resource and action objects should be instantiated depending on + * the specific endPoint where the check is being performed. + * Modeling Principal and specific way of performing authorization checks is + * entirely implementation dependent. + * + * @param resource information about resource that is being accessed + * @param action action to be performed on the resource + * @return a {@link Access} object having {@link Access#allowed} set to true if authorized otherwise set to false + * and optionally {@link Access#message} set to appropriate message + * */ + Access isAuthorized(Resource resource, Action action); +} diff --git a/server/src/main/java/io/druid/server/security/Resource.java b/server/src/main/java/io/druid/server/security/Resource.java new file mode 100644 index 00000000000..d3c74fb5289 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Resource.java @@ -0,0 +1,69 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Resource +{ + private final String name; + private final ResourceType type; + + public Resource(String name, ResourceType type) + { + this.name = name; + this.type = type; + } + + public String getName() + { + return name; + } + + public ResourceType getType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Resource resource = (Resource) o; + + if (!name.equals(resource.name)) { + return false; + } + return type == resource.type; + + } + + @Override + public int hashCode() + { + int result = name.hashCode(); + result = 31 * result + type.hashCode(); + return result; + } +} diff --git a/server/src/main/java/io/druid/server/security/ResourceType.java b/server/src/main/java/io/druid/server/security/ResourceType.java new file mode 100644 index 00000000000..818bf9ca947 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/ResourceType.java @@ -0,0 +1,27 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum ResourceType +{ + DATASOURCE, + CONFIG, + STATE +} diff --git a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java index a81938a7284..1436ab2534b 100644 --- a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java +++ b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java @@ -47,6 +47,7 @@ import io.druid.client.TimelineServerView; import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import io.druid.timeline.VersionedIntervalTimeline; import io.druid.timeline.partition.NumberedShardSpec; @@ -411,7 +412,7 @@ public class ClientInfoResourceTest SegmentMetadataQueryConfig segmentMetadataQueryConfig ) { - return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig) + return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig, new AuthConfig()) { @Override protected DateTime getCurrentTime() diff --git a/server/src/test/java/io/druid/server/QueryResourceTest.java b/server/src/test/java/io/druid/server/QueryResourceTest.java index ed2b3f1091f..dabd6b575e8 100644 --- a/server/src/test/java/io/druid/server/QueryResourceTest.java +++ b/server/src/test/java/io/druid/server/QueryResourceTest.java @@ -20,9 +20,13 @@ package io.druid.server; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; import io.druid.jackson.DefaultObjectMapper; import io.druid.query.Query; import io.druid.query.QueryRunner; @@ -31,9 +35,15 @@ import io.druid.query.SegmentDescriptor; import io.druid.server.initialization.ServerConfig; import io.druid.server.log.NoopRequestLogger; import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; import org.joda.time.Interval; import org.joda.time.Period; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -45,6 +55,8 @@ import javax.ws.rs.core.Response; import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; /** * @@ -97,6 +109,9 @@ public class QueryResourceTest private static final ServiceEmitter noopServiceEmitter = new NoopServiceEmitter(); + private QueryResource queryResource; + private QueryManager queryManager; + @BeforeClass public static void staticSetup() { @@ -106,9 +121,19 @@ public class QueryResourceTest @Before public void setup() { - EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON); + EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON).anyTimes(); EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); - EasyMock.replay(testServletRequest); + queryManager = new QueryManager(); + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig() + ); } private static final String simpleTimeSeriesQuery = "{\n" @@ -129,42 +154,273 @@ public class QueryResourceTest @Test public void testGoodQuery() throws IOException { - QueryResource queryResource = new QueryResource( - serverConfig, - jsonMapper, - jsonMapper, - testSegmentWalker, - new NoopServiceEmitter(), - new NoopRequestLogger(), - new QueryManager() - ); - Response respone = queryResource.doPost( + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), null /*pretty*/, testServletRequest ); - Assert.assertNotNull(respone); + Assert.assertNotNull(response); } @Test public void testBadQuery() throws IOException { + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( + new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + Assert.assertNotNull(response); + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } - QueryResource queryResource = new QueryResource( + @Test + public void testSecuredQuery() throws Exception + { + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( serverConfig, jsonMapper, jsonMapper, testSegmentWalker, new NoopServiceEmitter(), new NoopRequestLogger(), - new QueryManager() + queryManager, + new AuthConfig(true) ); - Response respone = queryResource.doPost( - new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), + + Response response = queryResource.doPost( + new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), null /*pretty*/, testServletRequest ); - Assert.assertNotNull(respone); - Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), respone.getStatus()); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + + response = queryResource.doPost( + new ByteArrayInputStream("{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"}".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + + } + + @Test(timeout = 60_000L) + public void testSecuredGetServer() throws Exception + { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + final CountDownLatch cancelledCountDownLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + // When the query is cancelled the control will reach here, + // countdown the latch and rethrow the exception so that error response is returned for the query + cancelledCountDownLatch.countDown(); + Throwables.propagate(e); + } + return new Access(true); + } else { + return new Access(true); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig(true) + ); + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } + ); + waitFinishLatch.await(); + cancelledCountDownLatch.await(); + } + + @Test(timeout = 60_000L) + public void testDenySecuredGetServer() throws Exception + { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + Throwables.propagate(e); + } + return new Access(true); + } else { + // Deny access to cancel the query + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig(true) + ); + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } + ); + waitFinishLatch.await(); + } + + @After + public void tearDown() + { + EasyMock.verify(testServletRequest); } } diff --git a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java index 51f5cbb8852..71147cdaa7b 100644 --- a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java @@ -25,6 +25,11 @@ import io.druid.client.CoordinatorServerView; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.indexing.IndexingServiceClient; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; @@ -32,6 +37,7 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.HashMap; @@ -47,10 +53,12 @@ public class DatasourcesResourceTest private DruidServer server; private List listDataSources; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { + request = EasyMock.createStrictMock(HttpServletRequest.class); inventoryView = EasyMock.createStrictMock(CoordinatorServerView.class); server = EasyMock.createStrictMock(DruidServer.class); dataSegmentList = new ArrayList<>(); @@ -94,8 +102,12 @@ public class DatasourcesResourceTest ) ); listDataSources = new ArrayList<>(); - listDataSources.add(new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0))); - listDataSources.add(new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1))); + listDataSources.add( + new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0)) + ); + listDataSources.add( + new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1)) + ); } @Test @@ -108,8 +120,8 @@ public class DatasourcesResourceTest ImmutableList.of(server) ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources("full", null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); Set result = (Set) response.getEntity(); DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; result.toArray(resultantDruidDataSources); @@ -117,7 +129,7 @@ public class DatasourcesResourceTest Assert.assertEquals(2, resultantDruidDataSources.length); Assert.assertArrayEquals(listDataSources.toArray(), resultantDruidDataSources); - response = datasourcesResource.getQueryableDataSources(null, null); + response = datasourcesResource.getQueryableDataSources(null, null, request); List result1 = (List) response.getEntity(); Assert.assertEquals(200, response.getStatus()); Assert.assertEquals(2, result1.size()); @@ -126,6 +138,53 @@ public class DatasourcesResourceTest EasyMock.verify(inventoryView, server); } + @Test + public void testSecuredGetFullQueryableDataSources() throws Exception + { + EasyMock.expect(server.getDataSources()).andReturn( + ImmutableList.of(listDataSources.get(0), listDataSources.get(1)) + ).atLeastOnce(); + EasyMock.expect(inventoryView.getInventory()).andReturn( + ImmutableList.of(server) + ).atLeastOnce(); + EasyMock.expect(request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("datasource1")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(inventoryView, server, request); + + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig(true)); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); + Set result = (Set) response.getEntity(); + DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; + result.toArray(resultantDruidDataSources); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, resultantDruidDataSources.length); + Assert.assertArrayEquals(listDataSources.subList(0, 1).toArray(), resultantDruidDataSources); + + response = datasourcesResource.getQueryableDataSources(null, null, request); + List result1 = (List) response.getEntity(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, result1.size()); + Assert.assertTrue(result1.contains("datasource1")); + + EasyMock.verify(inventoryView, server, request); + } + @Test public void testGetSimpleQueryableDataSources() throws Exception { @@ -145,8 +204,8 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources(null, "simple"); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources(null, "simple", request); Assert.assertEquals(200, response.getStatus()); List> results = (List>) response.getEntity(); int index = 0; @@ -172,7 +231,7 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", "full"); DruidDataSource result = (DruidDataSource) response.getEntity(); Assert.assertEquals(200, response.getStatus()); @@ -189,7 +248,7 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Assert.assertEquals(204, datasourcesResource.getTheDataSource("none", null).getStatus()); EasyMock.verify(inventoryView, server); } @@ -211,7 +270,7 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -250,7 +309,7 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView, server, server2, server3); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -281,7 +340,7 @@ public class DatasourcesResourceTest List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceIntervals("invalidDataSource", null, null); Assert.assertEquals(response.getEntity(), null); @@ -328,7 +387,7 @@ public class DatasourcesResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceSpecificInterval( "invalidDataSource", "2010-01-01/P1D", @@ -395,7 +454,7 @@ public class DatasourcesResourceTest EasyMock.expectLastCall().once(); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSourceSpecificInterval("datasource1", interval); Assert.assertEquals(200, response.getStatus()); @@ -407,7 +466,7 @@ public class DatasourcesResourceTest public void testDeleteDataSource() { IndexingServiceClient indexingServiceClient = EasyMock.createStrictMock(IndexingServiceClient.class); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSource("datasource", "true", "???"); Assert.assertEquals(400, response.getStatus()); Assert.assertNotNull(response.getEntity()); diff --git a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java index b77842bff8d..4fb50795c85 100644 --- a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java +++ b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java @@ -22,13 +22,16 @@ package io.druid.server.http; import com.google.common.collect.ImmutableList; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.List; @@ -40,12 +43,15 @@ public class IntervalsResourceTest private InventoryView inventoryView; private DruidServer server; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { inventoryView = EasyMock.createStrictMock(InventoryView.class); server = EasyMock.createStrictMock(DruidServer.class); + request = EasyMock.createStrictMock(HttpServletRequest.class); + dataSegmentList = new ArrayList<>(); dataSegmentList.add( new DataSegment( @@ -103,9 +109,9 @@ public class IntervalsResourceTest List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getIntervals(); + Response response = intervalsResource.getIntervals(request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(1), actualIntervals.firstKey()); @@ -117,7 +123,6 @@ public class IntervalsResourceTest Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -130,16 +135,15 @@ public class IntervalsResourceTest List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null, request); Map> actualIntervals = (Map) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertTrue(actualIntervals.containsKey(expectedIntervals.get(0))); Assert.assertEquals(25L, actualIntervals.get(expectedIntervals.get(0)).get("size")); Assert.assertEquals(2, actualIntervals.get(expectedIntervals.get(0)).get("count")); - EasyMock.verify(inventoryView); } @Test @@ -152,9 +156,9 @@ public class IntervalsResourceTest List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full"); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full", request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(0), actualIntervals.firstKey()); @@ -163,7 +167,6 @@ public class IntervalsResourceTest Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -174,14 +177,19 @@ public class IntervalsResourceTest ).atLeastOnce(); EasyMock.replay(inventoryView); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null, request); Map actualIntervals = (Map) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(25L, actualIntervals.get("size")); Assert.assertEquals(2, actualIntervals.get("count")); + } + + @After + public void tearDown() { EasyMock.verify(inventoryView); } + } diff --git a/server/src/test/java/io/druid/server/http/RulesResourceTest.java b/server/src/test/java/io/druid/server/http/RulesResourceTest.java index 283026f82cf..d153397cee9 100644 --- a/server/src/test/java/io/druid/server/http/RulesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/RulesResourceTest.java @@ -20,12 +20,10 @@ package io.druid.server.http; import com.google.common.collect.ImmutableList; - import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; - import org.easymock.EasyMock; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -34,7 +32,6 @@ import org.junit.Before; import org.junit.Test; import javax.ws.rs.core.Response; - import java.util.List; import java.util.Map; @@ -255,4 +252,5 @@ public class RulesResourceTest EasyMock.verify(auditManager); } + } diff --git a/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java new file mode 100644 index 00000000000..ae317314b21 --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java @@ -0,0 +1,245 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Binder; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ResourceFilter; +import com.sun.jersey.spi.container.ResourceFilters; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import org.easymock.EasyMock; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.PathSegment; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +public class ResourceFilterTestHelper +{ + public HttpServletRequest req; + public AuthorizationInfo authorizationInfo; + public ContainerRequest request; + + public void setUp(ResourceFilter resourceFilter) throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + request = EasyMock.createStrictMock(ContainerRequest.class); + authorizationInfo = EasyMock.createStrictMock(AuthorizationInfo.class); + + // Memory barrier + synchronized (this) { + ((AbstractResourceFilter) resourceFilter).setReq(req); + } + } + + public void setUpMockExpectations( + String requestPath, + boolean authCheckResult, + String requestMethod + ) + { + EasyMock.expect(request.getPath()).andReturn(requestPath).anyTimes(); + EasyMock.expect(request.getPathSegments()).andReturn( + ImmutableList.copyOf( + Iterables.transform( + Arrays.asList(requestPath.split("/")), + new Function() + { + @Override + public PathSegment apply(final String input) + { + return new PathSegment() + { + @Override + public String getPath() + { + return input; + } + + @Override + public MultivaluedMap getMatrixParameters() + { + return null; + } + }; + } + } + ) + ) + ).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.expect(req.getAttribute(EasyMock.anyString())).andReturn(authorizationInfo).atLeastOnce(); + EasyMock.expect(authorizationInfo.isAuthorized( + EasyMock.anyObject(Resource.class), + EasyMock.anyObject(Action.class) + )).andReturn( + new Access(authCheckResult) + ).atLeastOnce(); + + } + + public static Collection getRequestPaths(final Class clazz) + { + return getRequestPaths(clazz, ImmutableList.>of(), ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections + ) + { + return getRequestPaths(clazz, mockableInjections, ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys + ) + { + return getRequestPaths(clazz, mockableInjections, mockableKeys, ImmutableList.of()); + } + + // Feeds in an array of [ PathName, MethodName, ResourceFilter , Injector] + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys, + final Iterable injectedObjs + ) + { + final Injector injector = Guice.createInjector( + new Module() + { + @Override + public void configure(Binder binder) + { + for (Class clazz : mockableInjections) { + binder.bind(clazz).toInstance(EasyMock.createNiceMock(clazz)); + } + for (Object obj : injectedObjs) { + binder.bind((Class) obj.getClass()).toInstance(obj); + } + for (Key key : mockableKeys) { + binder.bind((Key) key).toInstance(EasyMock.createNiceMock(key.getTypeLiteral().getRawType())); + } + binder.bind(AuthConfig.class).toInstance(new AuthConfig(true)); + } + } + ); + final String basepath = ((Path) clazz.getAnnotation(Path.class)).value().substring(1); //Ignore the first "/" + final List> baseResourceFilters = + clazz.getAnnotation(ResourceFilters.class) == null ? Collections.>emptyList() : + ImmutableList.copyOf(((ResourceFilters) clazz.getAnnotation(ResourceFilters.class)).value()); + + return ImmutableList.copyOf( + Iterables.concat( + // Step 3 - Merge all the Objects arrays for each endpoints + Iterables.transform( + // Step 2 - + // For each endpoint, make an Object array containing + // - Request Path like "druid/../../.." + // - Request Method like "GET" or "POST" or "DELETE" + // - Resource Filter instance for the endpoint + Iterables.filter( + // Step 1 - + // Filter out non resource endpoint methods + // and also the endpoints that does not have any + // ResourceFilters applied to them + ImmutableList.copyOf(clazz.getDeclaredMethods()), + new Predicate() + { + @Override + public boolean apply(Method input) + { + return input.getAnnotation(GET.class) != null + || input.getAnnotation(POST.class) != null + || input.getAnnotation(DELETE.class) != null + && (input.getAnnotation(ResourceFilters.class) != null + || !baseResourceFilters.isEmpty()); + } + } + ), + new Function>() + { + @Override + public Collection apply(final Method method) + { + final List> resourceFilters = + method.getAnnotation(ResourceFilters.class) == null ? baseResourceFilters : + ImmutableList.copyOf(method.getAnnotation(ResourceFilters.class).value()); + + return Collections2.transform( + resourceFilters, + new Function, Object[]>() + { + @Override + public Object[] apply(Class input) + { + if (method.getAnnotation(Path.class) != null) { + return new Object[]{ + String.format("%s%s", basepath, method.getAnnotation(Path.class).value()), + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } else { + return new Object[]{ + basepath, + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } + } + } + ); + } + } + ) + ) + ); + } +} diff --git a/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java new file mode 100644 index 00000000000..4a7cd0de825 --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.ClientInfoResource; +import io.druid.server.QueryResource; +import io.druid.server.StatusResource; +import io.druid.server.http.BrokerResource; +import io.druid.server.http.CoordinatorDynamicConfigsResource; +import io.druid.server.http.CoordinatorResource; +import io.druid.server.http.DatasourcesResource; +import io.druid.server.http.HistoricalResource; +import io.druid.server.http.IntervalsResource; +import io.druid.server.http.MetadataResource; +import io.druid.server.http.RulesResource; +import io.druid.server.http.ServersResource; +import io.druid.server.http.TiersResource; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(CoordinatorResource.class), + getRequestPaths(DatasourcesResource.class), + getRequestPaths(BrokerResource.class), + getRequestPaths(HistoricalResource.class), + getRequestPaths(IntervalsResource.class), + getRequestPaths(MetadataResource.class), + getRequestPaths(RulesResource.class), + getRequestPaths(ServersResource.class), + getRequestPaths(TiersResource.class), + getRequestPaths(ClientInfoResource.class), + getRequestPaths(CoordinatorDynamicConfigsResource.class), + getRequestPaths(QueryResource.class), + getRequestPaths(StatusResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + resourceFilter.getRequestFilter().filter(request); + EasyMock.verify(req, request, authorizationInfo); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + //Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + EasyMock.verify(req, request, authorizationInfo); + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + EasyMock.replay(req, request, authorizationInfo); + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + EasyMock.verify(req, request, authorizationInfo); + } + +}