Add support for authorizing query context params (#12396)

The query context is a way that the user gives a hint to the Druid query engine, so that they enforce a certain behavior or at least let the query engine prefer a certain plan during query planning. Today, there are 3 types of query context params as below.

Default context params. They are set via druid.query.default.context in runtime properties. Any user context params can be default params.
User context params. They are set in the user query request. See https://druid.apache.org/docs/latest/querying/query-context.html for parameters.
System context params. They are set by the Druid query engine during query processing. These params override other context params.
Today, any context params are allowed to users. This can cause 
1) a bad UX if the context param is not matured yet or 
2) even query failure or system fault in the worst case if a sensitive param is abused, ex) maxSubqueryRows.

This PR adds an ability to limit context params per user role. That means, a query will fail if you have a context param set in the query that is not allowed to you. To do that, this PR adds a new built-in resource type, QUERY_CONTEXT. The resource to authorize has a name of the context param (such as maxSubqueryRows) and the type of QUERY_CONTEXT. To allow a certain context param for a user, the user should be granted WRITE permission on the context param resource. Here is an example of the permission.

{
  "resourceAction" : {
    "resource" : {
      "name" : "maxSubqueryRows",
      "type" : "QUERY_CONTEXT"
    },
    "action" : "WRITE"
  },
  "resourceNamePattern" : "maxSubqueryRows"
}
Each role can have multiple permissions for context params. Each permission should be set for different context params.

When a query is issued with a query context X, the query will fail if the user who issued the query does not have WRITE permission on the query context X. In this case,

HTTP endpoints will return 403 response code.
JDBC will throw ForbiddenException.
Note: there is a context param called brokerService that is used only by the router. This param is used to pin your query to run it in a specific broker. Because the authorization is done not in the router, but in the broker, if you have brokerService set in your query without a proper permission, your query will fail in the broker after routing is done. Technically, this is not right because the authorization is checked after the context param takes effect. However, this should not cause any user-facing issue and thus should be OK. The query will still fail if the user doesn’t have permission for brokerService.

The context param authorization can be enabled using druid.auth.authorizeQueryContextParams. This is disabled by default to avoid any hassle when someone upgrades his cluster blindly without reading release notes.
This commit is contained in:
Jihoon Son 2022-04-21 01:51:16 -07:00 committed by GitHub
parent 4c6ba73823
commit 73ce5df22d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
48 changed files with 1623 additions and 500 deletions

View File

@ -28,6 +28,7 @@ import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter;
@ -145,6 +146,12 @@ public class MaterializedViewQuery<T> implements Query<T>
return query.getContext();
}
@Override
public QueryContext getQueryContext()
{
return query.getQueryContext();
}
@Override
public <ContextType> ContextType getContextValue(String key)
{

View File

@ -32,8 +32,8 @@ import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchAggregatorFactory;
import org.apache.druid.query.aggregation.datasketches.quantiles.DoublesSketchToQuantilePostAggregator;
@ -50,7 +50,6 @@ import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Map;
public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
{
@ -200,11 +199,12 @@ public class DoublesSketchApproxQuantileSqlAggregator implements SqlAggregator
);
}
@Nullable
static Long getMaxStreamLengthFromQueryContext(Map<String, Object> queryContext)
static long getMaxStreamLengthFromQueryContext(QueryContext queryContext)
{
final Object val = queryContext.get(CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH);
return val == null ? null : Numbers.parseLong(val);
return queryContext.getAsLong(
CTX_APPROX_QUANTILE_DS_MAX_STREAM_LENGTH,
DoublesSketchAggregatorFactory.DEFAULT_MAX_STREAM_LENGTH
);
}
private static class DoublesSketchApproxQuantileSqlAggFunction extends SqlAggFunction

View File

@ -36,6 +36,7 @@ druid_auth_authenticator_basic_type=basic
druid_auth_authenticatorChain=["basic"]
druid_auth_authorizer_basic_type=basic
druid_auth_authorizers=["basic"]
druid_auth_authorizeQueryContextParams=true
druid_client_https_certAlias=druid
druid_client_https_keyManagerPassword=druid123
druid_client_https_keyStorePassword=druid123

View File

@ -46,6 +46,7 @@ druid_auth_authorizer_ldapauth_initialAdminUser=admin
druid_auth_authorizer_ldapauth_initialAdminRole=admin
druid_auth_authorizer_ldapauth_roleProvider_type=ldap
druid_auth_authorizers=["ldapauth"]
druid_auth_authorizeQueryContextParams=true
druid_client_https_certAlias=druid
druid_client_https_keyManagerPassword=druid123
druid_client_https_keyStorePassword=druid123

View File

@ -154,3 +154,21 @@ objectClass: groupOfUniqueNames
cn: datasourceWithSysGroup
description: datasourceWithSysGroup users
uniqueMember: uid=datasourceAndSysUser,ou=Users,dc=example,dc=org
dn: uid=datasourceAndContextParamsUser,ou=Users,dc=example,dc=org
uid: datasourceAndContextParamsUser
cn: datasourceAndContextParamsUser
sn: datasourceAndContextParamsUser
objectClass: top
objectClass: posixAccount
objectClass: inetOrgPerson
homeDirectory: /home/datasourceAndContextParamsUser
uidNumber: 9
gidNumber: 9
userPassword: helloworld
dn: cn=datasourceAndContextParamsGroup,ou=Groups,dc=example,dc=org
objectClass: groupOfUniqueNames
cn: datasourceAndContextParamsGroup
description: datasourceAndContextParamsGroup users
uniqueMember: uid=datasourceAndContextParamsUser,ou=Users,dc=example,dc=org

View File

@ -26,7 +26,6 @@ import org.apache.druid.java.util.http.client.HttpClient;
import org.apache.druid.java.util.http.client.Request;
import org.apache.druid.java.util.http.client.response.StatusResponseHandler;
import org.apache.druid.java.util.http.client.response.StatusResponseHolder;
import org.apache.druid.testing.clients.AbstractQueryResourceTestClient;
import org.jboss.netty.handler.codec.http.HttpMethod;
import org.jboss.netty.handler.codec.http.HttpResponseStatus;
@ -36,7 +35,7 @@ import java.net.URL;
public class HttpUtil
{
private static final Logger LOG = new Logger(AbstractQueryResourceTestClient.class);
private static final Logger LOG = new Logger(HttpUtil.class);
private static final StatusResponseHandler RESPONSE_HANDLER = StatusResponseHandler.getInstance();
static final int NUM_RETRIES = 30;

View File

@ -21,6 +21,7 @@ package org.apache.druid.tests.security;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
@ -28,6 +29,7 @@ import com.google.inject.Inject;
import org.apache.calcite.avatica.AvaticaSqlException;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.guice.annotations.Client;
import org.apache.druid.guice.annotations.ExtensionPoint;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.java.util.common.logger.Logger;
@ -63,6 +65,7 @@ import java.util.Map;
import java.util.Properties;
import java.util.stream.Collectors;
@ExtensionPoint
public abstract class AbstractAuthConfigurationTest
{
private static final Logger LOG = new Logger(AbstractAuthConfigurationTest.class);
@ -105,6 +108,17 @@ public abstract class AbstractAuthConfigurationTest
)
);
protected static final List<ResourceAction> DATASOURCE_QUERY_CONTEXT_PERMISSIONS = ImmutableList.of(
new ResourceAction(
new Resource("auth_test", ResourceType.DATASOURCE),
Action.READ
),
new ResourceAction(
new Resource("auth_test_ctx", ResourceType.QUERY_CONTEXT),
Action.WRITE
)
);
/**
* create a ResourceAction set of permissions that can only read 'auth_test' + partial SYSTEM_TABLE, for Authorizer
* implementations which use ResourceAction pattern matching
@ -168,6 +182,36 @@ public abstract class AbstractAuthConfigurationTest
)
);
protected enum User
{
ADMIN("admin", "priest"),
DATASOURCE_ONLY_USER("datasourceOnlyUser", "helloworld"),
DATASOURCE_AND_CONTEXT_PARAMS_USER("datasourceAndContextParamsUser", "helloworld"),
DATASOURCE_AND_SYS_USER("datasourceAndSysUser", "helloworld"),
DATASOURCE_WITH_STATE_USER("datasourceWithStateUser", "helloworld"),
STATE_ONLY_USER("stateOnlyUser", "helloworld"),
INTERNAL_SYSTEM("druid_system", "warlock");
private final String name;
private final String password;
User(String name, String password)
{
this.name = name;
this.password = password;
}
public String getName()
{
return name;
}
public String getPassword()
{
return password;
}
}
protected List<Map<String, Object>> adminSegments;
protected List<Map<String, Object>> adminTasks;
protected List<Map<String, Object>> adminServers;
@ -186,15 +230,10 @@ public abstract class AbstractAuthConfigurationTest
@Inject
protected CoordinatorResourceTestClient coordinatorClient;
protected HttpClient adminClient;
protected HttpClient datasourceOnlyUserClient;
protected HttpClient datasourceAndSysUserClient;
protected HttpClient datasourceWithStateUserClient;
protected HttpClient stateOnlyUserClient;
protected HttpClient internalSystemClient;
protected Map<User, HttpClient> httpClients;
protected abstract void setupDatasourceOnlyUser() throws Exception;
protected abstract void setupDatasourceAndContextParamsUser() throws Exception;
protected abstract void setupDatasourceAndSysTableUser() throws Exception;
protected abstract void setupDatasourceAndSysAndStateUser() throws Exception;
protected abstract void setupSysTableAndStateOnlyUser() throws Exception;
@ -202,12 +241,25 @@ public abstract class AbstractAuthConfigurationTest
protected abstract String getAuthenticatorName();
protected abstract String getAuthorizerName();
protected abstract String getExpectedAvaticaAuthError();
protected abstract Properties getAvaticaConnectionProperties();
protected abstract Properties getAvaticaConnectionPropertiesFailure();
protected abstract String getExpectedAvaticaAuthzError();
/**
* Returns properties for the admin with an invalid password.
* Implementations can set any properties for authentication as they need.
*/
protected abstract Properties getAvaticaConnectionPropertiesForInvalidAdmin();
/**
* Returns properties for the given user.
* Implementations can set any properties for authentication as they need.
*
* @see User
*/
protected abstract Properties getAvaticaConnectionPropertiesForUser(User user);
@Test
public void test_systemSchemaAccess_admin() throws Exception
{
final HttpClient adminClient = getHttpClient(User.ADMIN);
// check that admin access works on all nodes
checkNodeAccess(adminClient);
@ -244,6 +296,7 @@ public abstract class AbstractAuthConfigurationTest
@Test
public void test_systemSchemaAccess_datasourceOnlyUser() throws Exception
{
final HttpClient datasourceOnlyUserClient = getHttpClient(User.DATASOURCE_ONLY_USER);
// check that we can access a datasource-permission restricted resource on the broker
HttpUtil.makeRequest(
datasourceOnlyUserClient,
@ -289,6 +342,7 @@ public abstract class AbstractAuthConfigurationTest
@Test
public void test_systemSchemaAccess_datasourceAndSysUser() throws Exception
{
final HttpClient datasourceAndSysUserClient = getHttpClient(User.DATASOURCE_AND_SYS_USER);
// check that we can access a datasource-permission restricted resource on the broker
HttpUtil.makeRequest(
datasourceAndSysUserClient,
@ -336,6 +390,7 @@ public abstract class AbstractAuthConfigurationTest
@Test
public void test_systemSchemaAccess_datasourceAndSysWithStateUser() throws Exception
{
final HttpClient datasourceWithStateUserClient = getHttpClient(User.DATASOURCE_WITH_STATE_USER);
// check that we can access a state-permission restricted resource on the broker
HttpUtil.makeRequest(
datasourceWithStateUserClient,
@ -384,6 +439,7 @@ public abstract class AbstractAuthConfigurationTest
@Test
public void test_systemSchemaAccess_stateOnlyUser() throws Exception
{
final HttpClient stateOnlyUserClient = getHttpClient(User.STATE_ONLY_USER);
HttpUtil.makeRequest(stateOnlyUserClient, HttpMethod.GET, config.getBrokerUrl() + "/status", null);
// as user that can only read STATE
@ -426,45 +482,89 @@ public abstract class AbstractAuthConfigurationTest
@Test
public void test_admin_loadStatus() throws Exception
{
checkLoadStatus(adminClient);
checkLoadStatus(getHttpClient(User.ADMIN));
}
@Test
public void test_admin_hasNodeAccess()
{
checkNodeAccess(adminClient);
checkNodeAccess(getHttpClient(User.ADMIN));
}
@Test
public void test_internalSystemUser_hasNodeAccess()
{
checkNodeAccess(internalSystemClient);
checkNodeAccess(getHttpClient(User.INTERNAL_SYSTEM));
}
@Test
public void test_avaticaQuery_broker()
{
testAvaticaQuery(getBrokerAvacticaUrl());
testAvaticaQuery(StringUtils.maybeRemoveTrailingSlash(getBrokerAvacticaUrl()));
final Properties properties = getAvaticaConnectionPropertiesForAdmin();
testAvaticaQuery(properties, getBrokerAvacticaUrl());
testAvaticaQuery(properties, StringUtils.maybeRemoveTrailingSlash(getBrokerAvacticaUrl()));
}
@Test
public void test_avaticaQuery_router()
{
testAvaticaQuery(getRouterAvacticaUrl());
testAvaticaQuery(StringUtils.maybeRemoveTrailingSlash(getRouterAvacticaUrl()));
final Properties properties = getAvaticaConnectionPropertiesForAdmin();
testAvaticaQuery(properties, getRouterAvacticaUrl());
testAvaticaQuery(properties, StringUtils.maybeRemoveTrailingSlash(getRouterAvacticaUrl()));
}
@Test
public void test_avaticaQueryAuthFailure_broker() throws Exception
{
testAvaticaAuthFailure(getBrokerAvacticaUrl());
final Properties properties = getAvaticaConnectionPropertiesForInvalidAdmin();
testAvaticaAuthFailure(properties, getBrokerAvacticaUrl());
}
@Test
public void test_avaticaQueryAuthFailure_router() throws Exception
{
testAvaticaAuthFailure(getRouterAvacticaUrl());
final Properties properties = getAvaticaConnectionPropertiesForInvalidAdmin();
testAvaticaAuthFailure(properties, getRouterAvacticaUrl());
}
@Test
public void test_avaticaQueryWithContext_datasourceOnlyUser_fail() throws Exception
{
final Properties properties = getAvaticaConnectionPropertiesForUser(User.DATASOURCE_ONLY_USER);
properties.setProperty("auth_test_ctx", "should-be-denied");
testAvaticaAuthzFailure(properties, getRouterAvacticaUrl());
}
@Test
public void test_avaticaQueryWithContext_datasourceAndContextParamsUser_succeed()
{
final Properties properties = getAvaticaConnectionPropertiesForUser(User.DATASOURCE_AND_CONTEXT_PARAMS_USER);
properties.setProperty("auth_test_ctx", "should-be-allowed");
testAvaticaQuery(properties, getRouterAvacticaUrl());
}
@Test
public void test_sqlQueryWithContext_datasourceOnlyUser_fail() throws Exception
{
final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_ONLY_USER),
query,
ImmutableMap.of("auth_test_ctx", "should-be-denied"),
HttpResponseStatus.FORBIDDEN
);
}
@Test
public void test_sqlQueryWithContext_datasourceAndContextParamsUser_succeed() throws Exception
{
final String query = "select count(*) from auth_test";
StatusResponseHolder responseHolder = makeSQLQueryRequest(
getHttpClient(User.DATASOURCE_AND_CONTEXT_PARAMS_USER),
query,
ImmutableMap.of("auth_test_ctx", "should-be-allowed"),
HttpResponseStatus.OK
);
}
@Test
@ -497,10 +597,16 @@ public abstract class AbstractAuthConfigurationTest
verifyMaliciousUser();
}
protected HttpClient getHttpClient(User user)
{
return Preconditions.checkNotNull(httpClients.get(user), "http client for user[%s]", user.getName());
}
protected void setupHttpClientsAndUsers() throws Exception
{
setupHttpClients();
setupDatasourceOnlyUser();
setupDatasourceAndContextParamsUser();
setupDatasourceAndSysTableUser();
setupDatasourceAndSysAndStateUser();
setupSysTableAndStateOnlyUser();
@ -538,11 +644,15 @@ public abstract class AbstractAuthConfigurationTest
HttpUtil.makeRequest(client, HttpMethod.GET, config.getCoordinatorUrl() + "/druid/coordinator/v1/loadqueue", null);
}
protected void testAvaticaQuery(String url)
private Properties getAvaticaConnectionPropertiesForAdmin()
{
return getAvaticaConnectionPropertiesForUser(User.ADMIN);
}
protected void testAvaticaQuery(Properties connectionProperties, String url)
{
LOG.info("URL: " + url);
try {
Properties connectionProperties = getAvaticaConnectionProperties();
Connection connection = DriverManager.getConnection(url, connectionProperties);
Statement statement = connection.createStatement();
statement.setMaxRows(450);
@ -557,11 +667,21 @@ public abstract class AbstractAuthConfigurationTest
}
}
protected void testAvaticaAuthFailure(String url) throws Exception
protected void testAvaticaAuthFailure(Properties connectionProperties, String url) throws Exception
{
testAvaticaAuthFailure(connectionProperties, url, getExpectedAvaticaAuthError());
}
protected void testAvaticaAuthzFailure(Properties connectionProperties, String url) throws Exception
{
testAvaticaAuthFailure(connectionProperties, url, getExpectedAvaticaAuthzError());
}
protected void testAvaticaAuthFailure(Properties connectionProperties, String url, String expectedError)
throws Exception
{
LOG.info("URL: " + url);
try {
Properties connectionProperties = getAvaticaConnectionPropertiesFailure();
Connection connection = DriverManager.getConnection(url, connectionProperties);
Statement statement = connection.createStatement();
statement.setMaxRows(450);
@ -571,7 +691,7 @@ public abstract class AbstractAuthConfigurationTest
catch (AvaticaSqlException ase) {
Assert.assertEquals(
ase.getErrorMessage(),
getExpectedAvaticaAuthError()
expectedError
);
return;
}
@ -612,9 +732,20 @@ public abstract class AbstractAuthConfigurationTest
String query,
HttpResponseStatus expectedStatus
) throws Exception
{
return makeSQLQueryRequest(httpClient, query, ImmutableMap.of(), expectedStatus);
}
protected StatusResponseHolder makeSQLQueryRequest(
HttpClient httpClient,
String query,
Map<String, Object> context,
HttpResponseStatus expectedStatus
) throws Exception
{
Map<String, Object> queryMap = ImmutableMap.of(
"query", query
"query", query,
"context", context
);
return HttpUtil.makeRequestWithExpectedStatus(
httpClient,
@ -683,11 +814,7 @@ public abstract class AbstractAuthConfigurationTest
protected void verifyAdminOptionsRequest()
{
HttpClient adminClient = new CredentialedHttpClient(
new BasicCredentials("admin", "priest"),
httpClient
);
testOptionsRequests(adminClient);
testOptionsRequests(getHttpClient(User.ADMIN));
}
protected void verifyAuthenticationInvalidAuthNameFails()
@ -725,7 +852,7 @@ public abstract class AbstractAuthConfigurationTest
);
HttpUtil.makeRequestWithExpectedStatus(
adminClient,
getHttpClient(User.ADMIN),
HttpMethod.POST,
endpoint,
"SERIALIZED_DATA".getBytes(StandardCharsets.UTF_8),
@ -758,36 +885,23 @@ public abstract class AbstractAuthConfigurationTest
setupTestSpecificHttpClients();
}
protected void setupCommonHttpClients()
{
adminClient = new CredentialedHttpClient(
new BasicCredentials("admin", "priest"),
httpClient
);
httpClients = new HashMap<>();
for (User user : User.values()) {
httpClients.put(user, setupHttpClientForUser(user.getName(), user.getPassword()));
}
}
datasourceOnlyUserClient = new CredentialedHttpClient(
new BasicCredentials("datasourceOnlyUser", "helloworld"),
httpClient
);
datasourceAndSysUserClient = new CredentialedHttpClient(
new BasicCredentials("datasourceAndSysUser", "helloworld"),
httpClient
);
datasourceWithStateUserClient = new CredentialedHttpClient(
new BasicCredentials("datasourceWithStateUser", "helloworld"),
httpClient
);
stateOnlyUserClient = new CredentialedHttpClient(
new BasicCredentials("stateOnlyUser", "helloworld"),
httpClient
);
internalSystemClient = new CredentialedHttpClient(
new BasicCredentials("druid_system", "warlock"),
/**
* Creates a HttpClient with the given user credentials.
* Implementations can override this method to return a different implementation of HttpClient
* than the basic CredentialedHttpClient.
*/
protected HttpClient setupHttpClientForUser(String username, String password)
{
return new CredentialedHttpClient(
new BasicCredentials(username, password),
httpClient
);
}

View File

@ -48,6 +48,7 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
private static final String BASIC_AUTHORIZER = "basic";
private static final String EXPECTED_AVATICA_AUTH_ERROR = "Error while executing SQL \"SELECT * FROM INFORMATION_SCHEMA.COLUMNS\": Remote driver error: QueryInterruptedException: User metadata store authentication failed. -> BasicSecurityAuthenticationException: User metadata store authentication failed.";
private static final String EXPECTED_AVATICA_AUTHZ_ERROR = "Error while executing SQL \"SELECT * FROM INFORMATION_SCHEMA.COLUMNS\": Remote driver error: RuntimeException: org.apache.druid.server.security.ForbiddenException: Allowed:false, Message: -> ForbiddenException: Allowed:false, Message:";
private HttpClient druid99;
@ -73,7 +74,7 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
protected void setupDatasourceOnlyUser() throws Exception
{
createUserAndRoleWithPermissions(
adminClient,
getHttpClient(User.ADMIN),
"datasourceOnlyUser",
"helloworld",
"datasourceOnlyRole",
@ -81,11 +82,23 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
);
}
@Override
protected void setupDatasourceAndContextParamsUser() throws Exception
{
createUserAndRoleWithPermissions(
getHttpClient(User.ADMIN),
"datasourceAndContextParamsUser",
"helloworld",
"datasourceAndContextParamsRole",
DATASOURCE_QUERY_CONTEXT_PERMISSIONS
);
}
@Override
protected void setupDatasourceAndSysTableUser() throws Exception
{
createUserAndRoleWithPermissions(
adminClient,
getHttpClient(User.ADMIN),
"datasourceAndSysUser",
"helloworld",
"datasourceAndSysRole",
@ -97,7 +110,7 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
protected void setupDatasourceAndSysAndStateUser() throws Exception
{
createUserAndRoleWithPermissions(
adminClient,
getHttpClient(User.ADMIN),
"datasourceWithStateUser",
"helloworld",
"datasourceWithStateRole",
@ -109,7 +122,7 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
protected void setupSysTableAndStateOnlyUser() throws Exception
{
createUserAndRoleWithPermissions(
adminClient,
getHttpClient(User.ADMIN),
"stateOnlyUser",
"helloworld",
"stateOnlyRole",
@ -122,7 +135,7 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
{
// create a new user+role that can read /status
createUserAndRoleWithPermissions(
adminClient,
getHttpClient(User.ADMIN),
"druid",
"helloworld",
"druidrole",
@ -132,14 +145,14 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
// create 100 users
for (int i = 0; i < 100; i++) {
HttpUtil.makeRequest(
adminClient,
getHttpClient(User.ADMIN),
HttpMethod.POST,
config.getCoordinatorUrl() + "/druid-ext/basic-security/authentication/db/basic/users/druid" + i,
null
);
HttpUtil.makeRequest(
adminClient,
getHttpClient(User.ADMIN),
HttpMethod.POST,
config.getCoordinatorUrl() + "/druid-ext/basic-security/authorization/db/basic/users/druid" + i,
null
@ -150,14 +163,14 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
// setup the last of 100 users and check that it works
HttpUtil.makeRequest(
adminClient,
getHttpClient(User.ADMIN),
HttpMethod.POST,
config.getCoordinatorUrl() + "/druid-ext/basic-security/authentication/db/basic/users/druid99/credentials",
jsonMapper.writeValueAsBytes(new BasicAuthenticatorCredentialUpdate("helloworld", 5000))
);
HttpUtil.makeRequest(
adminClient,
getHttpClient(User.ADMIN),
HttpMethod.POST,
config.getCoordinatorUrl() + "/druid-ext/basic-security/authorization/db/basic/users/druid99/roles/druidrole",
null
@ -188,20 +201,26 @@ public class ITBasicAuthConfigurationTest extends AbstractAuthConfigurationTest
}
@Override
protected Properties getAvaticaConnectionProperties()
protected String getExpectedAvaticaAuthzError()
{
return EXPECTED_AVATICA_AUTHZ_ERROR;
}
@Override
protected Properties getAvaticaConnectionPropertiesForInvalidAdmin()
{
Properties connectionProperties = new Properties();
connectionProperties.setProperty("user", "admin");
connectionProperties.setProperty("password", "priest");
connectionProperties.setProperty("password", "invalid_password");
return connectionProperties;
}
@Override
protected Properties getAvaticaConnectionPropertiesFailure()
protected Properties getAvaticaConnectionPropertiesForUser(User user)
{
Properties connectionProperties = new Properties();
connectionProperties.setProperty("user", "admin");
connectionProperties.setProperty("password", "wrongpassword");
connectionProperties.setProperty("user", user.getName());
connectionProperties.setProperty("password", user.getPassword());
return connectionProperties;
}

View File

@ -54,6 +54,7 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
private static final String LDAP_AUTHORIZER = "ldapauth";
private static final String EXPECTED_AVATICA_AUTH_ERROR = "Error while executing SQL \"SELECT * FROM INFORMATION_SCHEMA.COLUMNS\": Remote driver error: QueryInterruptedException: User LDAP authentication failed. -> BasicSecurityAuthenticationException: User LDAP authentication failed.";
private static final String EXPECTED_AVATICA_AUTHZ_ERROR = "Error while executing SQL \"SELECT * FROM INFORMATION_SCHEMA.COLUMNS\": Remote driver error: RuntimeException: org.apache.druid.server.security.ForbiddenException: Allowed:false, Message: -> ForbiddenException: Allowed:false, Message:";
@Inject
IntegrationTestingConfig config;
@ -80,7 +81,7 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
@Test
public void test_systemSchemaAccess_stateOnlyNoLdapGroupUser() throws Exception
{
HttpUtil.makeRequest(stateOnlyUserClient, HttpMethod.GET, config.getBrokerUrl() + "/status", null);
HttpUtil.makeRequest(getHttpClient(User.STATE_ONLY_USER), HttpMethod.GET, config.getBrokerUrl() + "/status", null);
// as user that can only read STATE
LOG.info("Checking sys.segments query as stateOnlyNoLdapGroupUser...");
@ -128,6 +129,15 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
);
}
@Override
protected void setupDatasourceAndContextParamsUser() throws Exception
{
createRoleWithPermissionsAndGroupMapping(
"datasourceAndContextParamsGroup",
ImmutableMap.of("datasourceAndContextParamsRole", DATASOURCE_QUERY_CONTEXT_PERMISSIONS)
);
}
@Override
protected void setupDatasourceAndSysTableUser() throws Exception
{
@ -196,20 +206,26 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
}
@Override
protected Properties getAvaticaConnectionProperties()
protected String getExpectedAvaticaAuthzError()
{
return EXPECTED_AVATICA_AUTHZ_ERROR;
}
@Override
protected Properties getAvaticaConnectionPropertiesForInvalidAdmin()
{
Properties connectionProperties = new Properties();
connectionProperties.setProperty("user", "admin");
connectionProperties.setProperty("password", "priest");
connectionProperties.setProperty("password", "invalid_password");
return connectionProperties;
}
@Override
protected Properties getAvaticaConnectionPropertiesFailure()
protected Properties getAvaticaConnectionPropertiesForUser(User user)
{
Properties connectionProperties = new Properties();
connectionProperties.setProperty("user", "admin");
connectionProperties.setProperty("password", "wrongpassword");
connectionProperties.setProperty("user", user.getName());
connectionProperties.setProperty("password", user.getPassword());
return connectionProperties;
}
@ -218,6 +234,7 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
Map<String, List<ResourceAction>> roleTopermissions
) throws Exception
{
final HttpClient adminClient = getHttpClient(User.ADMIN);
roleTopermissions.keySet().forEach(role -> HttpUtil.makeRequest(
adminClient,
HttpMethod.POST,
@ -269,6 +286,7 @@ public class ITBasicAuthLdapConfigurationTest extends AbstractAuthConfigurationT
String role
)
{
final HttpClient adminClient = getHttpClient(User.ADMIN);
HttpUtil.makeRequest(
adminClient,
HttpMethod.POST,

View File

@ -58,7 +58,7 @@ public abstract class BaseQuery<T> implements Query<T>
public static final String SQL_QUERY_ID = "sqlQueryId";
private final DataSource dataSource;
private final boolean descending;
private final Map<String, Object> context;
private final QueryContext context;
private final QuerySegmentSpec querySegmentSpec;
private volatile Duration duration;
private final Granularity granularity;
@ -86,7 +86,7 @@ public abstract class BaseQuery<T> implements Query<T>
Preconditions.checkNotNull(granularity, "Must specify a granularity");
this.dataSource = dataSource;
this.context = context;
this.context = new QueryContext(context);
this.querySegmentSpec = querySegmentSpec;
this.descending = descending;
this.granularity = granularity;
@ -166,6 +166,12 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
@JsonProperty
public Map<String, Object> getContext()
{
return context.getMergedParams();
}
@Override
public QueryContext getQueryContext()
{
return context;
}
@ -173,7 +179,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public <ContextType> ContextType getContextValue(String key)
{
return context == null ? null : (ContextType) context.get(key);
return (ContextType) context.get(key);
}
@Override
@ -186,7 +192,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public boolean getContextBoolean(String key, boolean defaultValue)
{
return QueryContexts.parseBoolean(this, key, defaultValue);
return context.getAsBoolean(key, defaultValue);
}
/**
@ -230,7 +236,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public String getId()
{
return (String) getContextValue(QUERY_ID);
return context.getAsString(QUERY_ID);
}
@Override
@ -243,7 +249,7 @@ public abstract class BaseQuery<T> implements Query<T>
@Override
public String getSubQueryId()
{
return (String) getContextValue(SUB_QUERY_ID);
return context.getAsString(SUB_QUERY_ID);
}
@Override
@ -252,13 +258,6 @@ public abstract class BaseQuery<T> implements Query<T>
return withOverriddenContext(ImmutableMap.of(QUERY_ID, id));
}
@Nullable
@Override
public String getSqlQueryId()
{
return (String) getContextValue(SQL_QUERY_ID);
}
@Override
public Query<T> withSqlQueryId(String sqlQueryId)
{

View File

@ -61,7 +61,6 @@ import java.util.UUID;
@JsonSubTypes.Type(name = Query.SELECT, value = SelectQuery.class),
@JsonSubTypes.Type(name = Query.TOPN, value = TopNQuery.class),
@JsonSubTypes.Type(name = Query.DATASOURCE_METADATA, value = DataSourceMetadataQuery.class)
})
public interface Query<T>
{
@ -95,8 +94,24 @@ public interface Query<T>
DateTimeZone getTimezone();
/**
* Use {@link #getQueryContext()} instead.
*/
@Deprecated
Map<String, Object> getContext();
/**
* Returns QueryContext for this query.
*
* Note for query context serialization and deserialization.
* Currently, once a query is serialized, its queryContext can be different from the original queryContext
* after the query is deserialized back. If the queryContext has any {@link QueryContext#defaultParams} or
* {@link QueryContext#systemParams} in it, those will be found in {@link QueryContext#userParams}
* after it is deserialized. This is because {@link BaseQuery#getContext()} uses
* {@link QueryContext#getMergedParams()} for serialization, and queries accept a map for deserialization.
*/
QueryContext getQueryContext();
<ContextType> ContextType getContextValue(String key);
<ContextType> ContextType getContextValue(String key, ContextType defaultValue);
@ -159,7 +174,7 @@ public interface Query<T>
@Nullable
default String getSqlQueryId()
{
return null;
return getContextValue(BaseQuery.SQL_QUERY_ID);
}
/**

View File

@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Numbers;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
/**
* Holder for query context parameters. There are 3 ways to set context params today.
*
* - Default parameters. These are set mostly via {@link DefaultQueryConfig#context}.
* Auto-generated queryId or sqlQueryId are also set as default parameters. These default parameters can
* be overridden by user or system parameters.
* - User parameters. These are the params set by the user. User params override default parameters but
* are overridden by system parameters.
* - System parameters. These are the params set by the Druid query engine for internal use only.
*
* You can use {@code getX} methods or {@link #getMergedParams()} to compute the context params
* merging 3 types of params above.
*
* Currently, this class is mainly used for query context parameter authorization,
* such as HTTP query endpoints or JDBC endpoint. Its usage can be expanded in the future if we
* want to track user parameters and separate them from others during query processing.
*/
public class QueryContext
{
private final Map<String, Object> defaultParams;
private final Map<String, Object> userParams;
private final Map<String, Object> systemParams;
/**
* Cache of params merged.
*/
@Nullable
private Map<String, Object> mergedParams;
public QueryContext()
{
this(null);
}
public QueryContext(@Nullable Map<String, Object> userParams)
{
this.defaultParams = new TreeMap<>();
this.userParams = userParams == null ? new TreeMap<>() : new TreeMap<>(userParams);
this.systemParams = new TreeMap<>();
invalidateMergedParams();
}
private void invalidateMergedParams()
{
this.mergedParams = null;
}
public boolean isEmpty()
{
return defaultParams.isEmpty() && userParams.isEmpty() && systemParams.isEmpty();
}
public void addDefaultParam(String key, Object val)
{
invalidateMergedParams();
defaultParams.put(key, val);
}
public void addDefaultParams(Map<String, Object> defaultParams)
{
invalidateMergedParams();
this.defaultParams.putAll(defaultParams);
}
public void addSystemParam(String key, Object val)
{
invalidateMergedParams();
this.systemParams.put(key, val);
}
public Object removeUserParam(String key)
{
invalidateMergedParams();
return userParams.remove(key);
}
/**
* Returns only the context parameters the user sets.
* The returned map does not include the parameters that have been removed via {@link #removeUserParam}.
*
* Callers should use {@code getX} methods or {@link #getMergedParams()} instead to use the whole context params.
*/
public Map<String, Object> getUserParams()
{
return userParams;
}
public boolean isDebug()
{
return getAsBoolean(QueryContexts.ENABLE_DEBUG, QueryContexts.DEFAULT_ENABLE_DEBUG);
}
public boolean isEnableJoinLeftScanDirect()
{
return getAsBoolean(
QueryContexts.SQL_JOIN_LEFT_SCAN_DIRECT,
QueryContexts.DEFAULT_ENABLE_SQL_JOIN_LEFT_SCAN_DIRECT
);
}
@SuppressWarnings("unused")
public boolean containsKey(String key)
{
return get(key) != null;
}
@Nullable
public Object get(String key)
{
Object val = systemParams.get(key);
if (val != null) {
return val;
}
val = userParams.get(key);
return val == null ? defaultParams.get(key) : val;
}
@SuppressWarnings("unused")
public Object getOrDefault(String key, Object defaultValue)
{
final Object val = get(key);
return val == null ? defaultValue : val;
}
@Nullable
public String getAsString(String key)
{
return (String) get(key);
}
public boolean getAsBoolean(
final String parameter,
final boolean defaultValue
)
{
final Object value = get(parameter);
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) {
return (Boolean) value;
} else {
throw new IAE("Expected parameter[%s] to be boolean", parameter);
}
}
public int getAsInt(
final String parameter,
final int defaultValue
)
{
final Object value = get(parameter);
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Numbers.parseInt(value);
} else if (value instanceof Number) {
return ((Number) value).intValue();
} else {
throw new IAE("Expected parameter[%s] to be integer", parameter);
}
}
public long getAsLong(final String parameter, final long defaultValue)
{
final Object value = get(parameter);
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Numbers.parseLong(value);
} else if (value instanceof Number) {
return ((Number) value).longValue();
} else {
throw new IAE("Expected parameter[%s] to be long", parameter);
}
}
public Map<String, Object> getMergedParams()
{
if (mergedParams == null) {
final Map<String, Object> merged = new TreeMap<>(defaultParams);
merged.putAll(userParams);
merged.putAll(systemParams);
mergedParams = Collections.unmodifiableMap(merged);
}
return mergedParams;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
QueryContext context = (QueryContext) o;
return getMergedParams().equals(context.getMergedParams());
}
@Override
public int hashCode()
{
return Objects.hash(getMergedParams());
}
@Override
public String toString()
{
return "QueryContext{" +
"defaultParams=" + defaultParams +
", userParams=" + userParams +
", systemParams=" + systemParams +
'}';
}
}

View File

@ -24,6 +24,7 @@ import com.google.common.collect.Ordering;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunner;
import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.query.filter.DimFilter;
@ -109,6 +110,12 @@ public class SelectQuery implements Query<Object>
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public QueryContext getQueryContext()
{
throw new RuntimeException(REMOVED_ERROR_MESSAGE);
}
@Override
public <ContextType> ContextType getContextValue(String key)
{

View File

@ -0,0 +1,235 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.query;
import com.google.common.collect.ImmutableMap;
import nl.jqno.equalsverifier.EqualsVerifier;
import nl.jqno.equalsverifier.Warning;
import org.junit.Assert;
import org.junit.Test;
public class QueryContextTest
{
@Test
public void testEquals()
{
EqualsVerifier.configure()
.suppress(Warning.NONFINAL_FIELDS, Warning.ALL_FIELDS_SHOULD_BE_USED)
.usingGetClass()
.forClass(QueryContext.class)
.withNonnullFields("defaultParams", "userParams", "systemParams")
.verify();
}
@Test
public void testEmptyParam()
{
final QueryContext context = new QueryContext();
Assert.assertEquals(ImmutableMap.of(), context.getMergedParams());
}
@Test
public void testIsEmpty()
{
Assert.assertTrue(new QueryContext().isEmpty());
Assert.assertFalse(new QueryContext(ImmutableMap.of("k", "v")).isEmpty());
QueryContext context = new QueryContext();
context.addDefaultParam("k", "v");
Assert.assertFalse(context.isEmpty());
context = new QueryContext();
context.addSystemParam("k", "v");
Assert.assertFalse(context.isEmpty());
}
@Test
public void testGetString()
{
final QueryContext context = new QueryContext(
ImmutableMap.of("key", "val")
);
Assert.assertEquals("val", context.get("key"));
Assert.assertEquals("val", context.getAsString("key"));
Assert.assertNull(context.getAsString("non-exist"));
}
@Test
public void testGetBoolean()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"key1", "true",
"key2", true
)
);
Assert.assertTrue(context.getAsBoolean("key1", false));
Assert.assertTrue(context.getAsBoolean("key2", false));
Assert.assertFalse(context.getAsBoolean("non-exist", false));
}
@Test
public void testGetInt()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"key1", "100",
"key2", 100
)
);
Assert.assertEquals(100, context.getAsInt("key1", 0));
Assert.assertEquals(100, context.getAsInt("key2", 0));
Assert.assertEquals(0, context.getAsInt("non-exist", 0));
}
@Test
public void testGetLong()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"key1", "100",
"key2", 100
)
);
Assert.assertEquals(100L, context.getAsLong("key1", 0));
Assert.assertEquals(100L, context.getAsLong("key2", 0));
Assert.assertEquals(0L, context.getAsLong("non-exist", 0));
}
@Test
public void testAddSystemParamOverrideUserParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addSystemParam("sys1", "sysVal1");
context.addSystemParam("conflict", "sysVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"sys1", "sysVal1",
"conflict", "sysVal2"
),
context.getMergedParams()
);
}
@Test
public void testUserParamOverrideDefaultParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1"
)
);
context.addDefaultParam("conflict", "defaultVal2");
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
),
context.getUserParams()
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
}
@Test
public void testRemoveUserParam()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "userVal2"
),
context.getMergedParams()
);
Assert.assertEquals("userVal2", context.removeUserParam("conflict"));
Assert.assertEquals(
ImmutableMap.of(
"user1", "userVal1",
"default1", "defaultVal1",
"conflict", "defaultVal2"
),
context.getMergedParams()
);
}
@Test
public void testGetMergedParams()
{
final QueryContext context = new QueryContext(
ImmutableMap.of(
"user1", "userVal1",
"conflict", "userVal2"
)
);
context.addDefaultParams(
ImmutableMap.of(
"default1", "defaultVal1",
"conflict", "defaultVal2"
)
);
Assert.assertSame(context.getMergedParams(), context.getMergedParams());
}
}

View File

@ -57,7 +57,7 @@ public class ScanQuerySpecTest
+ "\"limit\":3,"
+ "\"filter\":null,"
+ "\"columns\":[\"market\",\"quality\",\"index\"],"
+ "\"context\":null,"
+ "\"context\":{},"
+ "\"descending\":false,"
+ "\"granularity\":{\"type\":\"all\"}}";
@ -96,7 +96,7 @@ public class ScanQuerySpecTest
+ "\"order\":\"ascending\","
+ "\"filter\":null,"
+ "\"columns\":[\"market\",\"quality\",\"index\",\"__time\"],"
+ "\"context\":null,"
+ "\"context\":{},"
+ "\"descending\":false,"
+ "\"granularity\":{\"type\":\"all\"}}";
@ -139,7 +139,7 @@ public class ScanQuerySpecTest
+ "\"orderBy\":[{\"columnName\":\"quality\",\"order\":\"ascending\"}],"
+ "\"filter\":null,"
+ "\"columns\":[\"market\",\"quality\",\"index\",\"__time\"],"
+ "\"context\":null,"
+ "\"context\":{},"
+ "\"descending\":false,"
+ "\"granularity\":{\"type\":\"all\"}}";

View File

@ -19,8 +19,8 @@
package org.apache.druid.server;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
import org.apache.druid.client.DirectDruidClient;
import org.apache.druid.java.util.common.DateTimes;
@ -45,14 +45,22 @@ import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.QueryToolChest;
import org.apache.druid.query.QueryToolChestWarehouse;
import org.apache.druid.query.context.ResponseContext;
import org.apache.druid.server.QueryResource.ResourceIOReaderWriter;
import org.apache.druid.server.log.RequestLogger;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizationUtils;
import org.apache.druid.server.security.AuthorizerMapper;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.server.security.ResourceType;
import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.UUID;
@ -82,13 +90,16 @@ public class QueryLifecycle
private final RequestLogger requestLogger;
private final AuthorizerMapper authorizerMapper;
private final DefaultQueryConfig defaultQueryConfig;
private final AuthConfig authConfig;
private final long startMs;
private final long startNs;
private State state = State.NEW;
private AuthenticationResult authenticationResult;
private QueryToolChest toolChest;
private Query baseQuery;
@MonotonicNonNull
private Query<?> baseQuery;
public QueryLifecycle(
final QueryToolChestWarehouse warehouse,
@ -98,6 +109,7 @@ public class QueryLifecycle
final RequestLogger requestLogger,
final AuthorizerMapper authorizerMapper,
final DefaultQueryConfig defaultQueryConfig,
final AuthConfig authConfig,
final long startMs,
final long startNs
)
@ -109,6 +121,7 @@ public class QueryLifecycle
this.requestLogger = requestLogger;
this.authorizerMapper = authorizerMapper;
this.defaultQueryConfig = defaultQueryConfig;
this.authConfig = authConfig;
this.startMs = startMs;
this.startNs = startNs;
}
@ -173,19 +186,10 @@ public class QueryLifecycle
{
transition(State.NEW, State.INITIALIZED);
String queryId = baseQuery.getId();
if (Strings.isNullOrEmpty(queryId)) {
queryId = UUID.randomUUID().toString();
}
baseQuery.getQueryContext().addDefaultParam(BaseQuery.QUERY_ID, UUID.randomUUID().toString());
baseQuery.getQueryContext().addDefaultParams(defaultQueryConfig.getContext());
Map<String, Object> mergedUserAndConfigContext;
if (baseQuery.getContext() != null) {
mergedUserAndConfigContext = BaseQuery.computeOverriddenContext(defaultQueryConfig.getContext(), baseQuery.getContext());
} else {
mergedUserAndConfigContext = defaultQueryConfig.getContext();
}
this.baseQuery = baseQuery.withOverriddenContext(mergedUserAndConfigContext).withId(queryId);
this.baseQuery = baseQuery;
this.toolChest = warehouse.getToolChest(baseQuery);
}
@ -200,14 +204,23 @@ public class QueryLifecycle
public Access authorize(HttpServletRequest req)
{
transition(State.INITIALIZED, State.AUTHORIZING);
final Iterable<ResourceAction> resourcesToAuthorize = Iterables.concat(
Iterables.transform(
baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
),
authConfig.authorizeQueryContextParams()
? Iterables.transform(
baseQuery.getQueryContext().getUserParams().keySet(),
contextParam -> new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
)
: Collections.emptyList()
);
return doAuthorize(
AuthorizationUtils.authenticationResultFromRequest(req),
AuthorizationUtils.authorizeAllResourceActions(
req,
Iterables.transform(
baseQuery.getDataSource().getTableNames(),
AuthorizationUtils.DATASOURCE_READ_RA_GENERATOR
),
resourcesToAuthorize,
authorizerMapper
)
);
@ -343,11 +356,44 @@ public class QueryLifecycle
}
}
public Query getQuery()
@Nullable
public Query<?> getQuery()
{
return baseQuery;
}
public String getQueryId()
{
return baseQuery.getId();
}
public String threadName(String currThreadName)
{
return StringUtils.format(
"%s[%s_%s_%s]",
currThreadName,
baseQuery.getType(),
baseQuery.getDataSource().getTableNames(),
getQueryId()
);
}
private boolean isSerializeDateTimeAsLong()
{
final boolean shouldFinalize = QueryContexts.isFinalize(baseQuery, true);
return QueryContexts.isSerializeDateTimeAsLong(baseQuery, false)
|| (!shouldFinalize && QueryContexts.isSerializeDateTimeAsLongInner(baseQuery, false));
}
public ObjectWriter newOutputWriter(ResourceIOReaderWriter ioReaderWriter)
{
return ioReaderWriter.getResponseWriter().newOutputWriter(
getToolChest(),
baseQuery,
isSerializeDateTimeAsLong()
);
}
public QueryToolChest getToolChest()
{
if (state.compareTo(State.INITIALIZED) < 0) {

View File

@ -41,6 +41,7 @@ public class QueryLifecycleFactory
private final RequestLogger requestLogger;
private final AuthorizerMapper authorizerMapper;
private final DefaultQueryConfig defaultQueryConfig;
private final AuthConfig authConfig;
@Inject
public QueryLifecycleFactory(
@ -61,6 +62,7 @@ public class QueryLifecycleFactory
this.requestLogger = requestLogger;
this.authorizerMapper = authorizerMapper;
this.defaultQueryConfig = queryConfigSupplier.get();
this.authConfig = authConfig;
}
public QueryLifecycle factorize()
@ -73,6 +75,7 @@ public class QueryLifecycleFactory
requestLogger,
authorizerMapper,
defaultQueryConfig,
authConfig,
System.currentTimeMillis(),
System.nanoTime()
);

View File

@ -29,7 +29,6 @@ import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.io.CountingOutputStream;
import com.google.inject.Inject;
@ -47,7 +46,6 @@ import org.apache.druid.query.BadJsonQueryException;
import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryException;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
@ -80,7 +78,6 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@ -109,7 +106,6 @@ public class QueryResource implements QueryCountStatsProvider
protected final ObjectMapper serializeDateTimeAsLongJsonMapper;
protected final ObjectMapper serializeDateTimeAsLongSmileMapper;
protected final QueryScheduler queryScheduler;
protected final AuthConfig authConfig;
protected final AuthorizerMapper authorizerMapper;
private final ResponseContextConfig responseContextConfig;
@ -138,7 +134,6 @@ public class QueryResource implements QueryCountStatsProvider
this.serializeDateTimeAsLongJsonMapper = serializeDataTimeAsLong(jsonMapper);
this.serializeDateTimeAsLongSmileMapper = serializeDataTimeAsLong(smileMapper);
this.queryScheduler = queryScheduler;
this.authConfig = authConfig;
this.authorizerMapper = authorizerMapper;
this.responseContextConfig = responseContextConfig;
this.selfNode = selfNode;
@ -184,28 +179,19 @@ public class QueryResource implements QueryCountStatsProvider
) throws IOException
{
final QueryLifecycle queryLifecycle = queryLifecycleFactory.factorize();
Query<?> query = null;
final ResourceIOReaderWriter ioReaderWriter = createResourceIOReaderWriter(req, pretty != null);
final String currThreadName = Thread.currentThread().getName();
try {
queryLifecycle.initialize(readQuery(req, in, ioReaderWriter));
query = queryLifecycle.getQuery();
final String queryId = query.getId();
final String queryThreadName = StringUtils.format(
"%s[%s_%s_%s]",
currThreadName,
query.getType(),
query.getDataSource().getTableNames(),
queryId
);
final Query<?> query = readQuery(req, in, ioReaderWriter);
queryLifecycle.initialize(query);
final String queryId = queryLifecycle.getQueryId();
final String queryThreadName = queryLifecycle.threadName(currThreadName);
Thread.currentThread().setName(queryThreadName);
if (log.isDebugEnabled()) {
log.debug("Got query [%s]", query);
log.debug("Got query [%s]", queryLifecycle.getQuery());
}
final Access authResult = queryLifecycle.authorize(req);
@ -227,16 +213,7 @@ public class QueryResource implements QueryCountStatsProvider
final Yielder<?> yielder = Yielders.each(results);
try {
boolean shouldFinalize = QueryContexts.isFinalize(query, true);
boolean serializeDateTimeAsLong =
QueryContexts.isSerializeDateTimeAsLong(query, false)
|| (!shouldFinalize && QueryContexts.isSerializeDateTimeAsLongInner(query, false));
final ObjectWriter jsonWriter = ioReaderWriter.getResponseWriter().newOutputWriter(
queryLifecycle.getToolChest(),
queryLifecycle.getQuery(),
serializeDateTimeAsLong
);
final ObjectWriter jsonWriter = queryLifecycle.newOutputWriter(ioReaderWriter);
Response.ResponseBuilder responseBuilder = Response
.ok(
@ -364,7 +341,12 @@ public class QueryResource implements QueryCountStatsProvider
log.noStackTrace()
.makeAlert(e, "Exception handling request")
.addData("query", query != null ? jsonMapper.writeValueAsString(query) : "unparseable query")
.addData(
"query",
queryLifecycle.getQuery() != null
? jsonMapper.writeValueAsString(queryLifecycle.getQuery())
: "unparseable query"
)
.addData("peer", req.getRemoteAddr())
.emit();
@ -381,7 +363,7 @@ public class QueryResource implements QueryCountStatsProvider
final ResourceIOReaderWriter ioReaderWriter
) throws IOException
{
Query baseQuery;
final Query<?> baseQuery;
try {
baseQuery = ioReaderWriter.getRequestMapper().readValue(in, Query.class);
}
@ -391,9 +373,7 @@ public class QueryResource implements QueryCountStatsProvider
String prevEtag = getPreviousEtag(req);
if (prevEtag != null) {
baseQuery = baseQuery.withOverriddenContext(
ImmutableMap.of(HEADER_IF_NONE_MATCH, prevEtag)
);
baseQuery.getQueryContext().addSystemParam(HEADER_IF_NONE_MATCH, prevEtag);
}
return baseQuery;

View File

@ -48,7 +48,7 @@ public class AuthConfig
public AuthConfig()
{
this(null, null, null, false);
this(null, null, null, false, false);
}
@JsonCreator
@ -56,20 +56,22 @@ public class AuthConfig
@JsonProperty("authenticatorChain") List<String> authenticatorChain,
@JsonProperty("authorizers") List<String> authorizers,
@JsonProperty("unsecuredPaths") List<String> unsecuredPaths,
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions
@JsonProperty("allowUnauthenticatedHttpOptions") boolean allowUnauthenticatedHttpOptions,
@JsonProperty("authorizeQueryContextParams") boolean authorizeQueryContextParams
)
{
this.authenticatorChain = authenticatorChain;
this.authorizers = authorizers;
this.unsecuredPaths = unsecuredPaths == null ? Collections.emptyList() : unsecuredPaths;
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
this.authorizeQueryContextParams = authorizeQueryContextParams;
}
@JsonProperty
private final List<String> authenticatorChain;
@JsonProperty
private List<String> authorizers;
private final List<String> authorizers;
@JsonProperty
private final List<String> unsecuredPaths;
@ -77,6 +79,9 @@ public class AuthConfig
@JsonProperty
private final boolean allowUnauthenticatedHttpOptions;
@JsonProperty
private final boolean authorizeQueryContextParams;
public List<String> getAuthenticatorChain()
{
return authenticatorChain;
@ -97,6 +102,11 @@ public class AuthConfig
return allowUnauthenticatedHttpOptions;
}
public boolean authorizeQueryContextParams()
{
return authorizeQueryContextParams;
}
@Override
public boolean equals(Object o)
{
@ -107,20 +117,22 @@ public class AuthConfig
return false;
}
AuthConfig that = (AuthConfig) o;
return isAllowUnauthenticatedHttpOptions() == that.isAllowUnauthenticatedHttpOptions() &&
Objects.equals(getAuthenticatorChain(), that.getAuthenticatorChain()) &&
Objects.equals(getAuthorizers(), that.getAuthorizers()) &&
Objects.equals(getUnsecuredPaths(), that.getUnsecuredPaths());
return allowUnauthenticatedHttpOptions == that.allowUnauthenticatedHttpOptions
&& authorizeQueryContextParams == that.authorizeQueryContextParams
&& Objects.equals(authenticatorChain, that.authenticatorChain)
&& Objects.equals(authorizers, that.authorizers)
&& Objects.equals(unsecuredPaths, that.unsecuredPaths);
}
@Override
public int hashCode()
{
return Objects.hash(
getAuthenticatorChain(),
getAuthorizers(),
getUnsecuredPaths(),
isAllowUnauthenticatedHttpOptions()
authenticatorChain,
authorizers,
unsecuredPaths,
allowUnauthenticatedHttpOptions,
authorizeQueryContextParams
);
}
@ -132,6 +144,65 @@ public class AuthConfig
", authorizers=" + authorizers +
", unsecuredPaths=" + unsecuredPaths +
", allowUnauthenticatedHttpOptions=" + allowUnauthenticatedHttpOptions +
", enableQueryContextAuthorization=" + authorizeQueryContextParams +
'}';
}
public static Builder newBuilder()
{
return new Builder();
}
/**
* AuthConfig object is created via Jackson in production. This builder is for easier code maintenance in unit tests.
*/
public static class Builder
{
private List<String> authenticatorChain;
private List<String> authorizers;
private List<String> unsecuredPaths;
private boolean allowUnauthenticatedHttpOptions;
private boolean authorizeQueryContextParams;
public Builder setAuthenticatorChain(List<String> authenticatorChain)
{
this.authenticatorChain = authenticatorChain;
return this;
}
public Builder setAuthorizers(List<String> authorizers)
{
this.authorizers = authorizers;
return this;
}
public Builder setUnsecuredPaths(List<String> unsecuredPaths)
{
this.unsecuredPaths = unsecuredPaths;
return this;
}
public Builder setAllowUnauthenticatedHttpOptions(boolean allowUnauthenticatedHttpOptions)
{
this.allowUnauthenticatedHttpOptions = allowUnauthenticatedHttpOptions;
return this;
}
public Builder setAuthorizeQueryContextParams(boolean authorizeQueryContextParams)
{
this.authorizeQueryContextParams = authorizeQueryContextParams;
return this;
}
public AuthConfig build()
{
return new AuthConfig(
authenticatorChain,
authorizers,
unsecuredPaths,
allowUnauthenticatedHttpOptions,
authorizeQueryContextParams
);
}
}
}

View File

@ -33,6 +33,7 @@ public class ResourceType
public static final String CONFIG = "CONFIG";
public static final String STATE = "STATE";
public static final String SYSTEM_TABLE = "SYSTEM_TABLE";
public static final String QUERY_CONTEXT = "QUERY_CONTEXT";
private static final Set<String> KNOWN_TYPES = Sets.newConcurrentHashSet();
@ -43,6 +44,7 @@ public class ResourceType
registerResourceType(CONFIG);
registerResourceType(STATE);
registerResourceType(SYSTEM_TABLE);
registerResourceType(QUERY_CONTEXT);
}
/**

View File

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.guice.security;
import com.google.common.collect.ImmutableList;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Scopes;
import org.apache.druid.guice.JsonConfigProvider;
import org.apache.druid.guice.JsonConfigurator;
import org.apache.druid.guice.LazySingleton;
import org.apache.druid.jackson.JacksonModule;
import org.apache.druid.server.security.AuthConfig;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import javax.validation.Validation;
import javax.validation.Validator;
import java.util.Properties;
public class DruidAuthModuleTest
{
private Injector injector;
private DruidAuthModule authModule;
@Before
public void setup()
{
authModule = new DruidAuthModule();
injector = Guice.createInjector(
new JacksonModule(),
binder -> {
binder.bind(Validator.class).toInstance(Validation.buildDefaultValidatorFactory().getValidator());
binder.bindScope(LazySingleton.class, Scopes.SINGLETON);
},
authModule
);
}
@Test
public void testAuthConfigSingleton()
{
AuthConfig config1 = injector.getInstance(AuthConfig.class);
AuthConfig config2 = injector.getInstance(AuthConfig.class);
Assert.assertNotNull(config1);
Assert.assertSame(config1, config2);
}
@Test
public void testAuthConfigDefault()
{
Properties properties = new Properties();
final AuthConfig authConfig = injectProperties(properties);
Assert.assertNotNull(authConfig);
Assert.assertNull(authConfig.getAuthenticatorChain());
Assert.assertNull(authConfig.getAuthorizers());
Assert.assertTrue(authConfig.getUnsecuredPaths().isEmpty());
Assert.assertFalse(authConfig.isAllowUnauthenticatedHttpOptions());
Assert.assertFalse(authConfig.authorizeQueryContextParams());
}
@Test
public void testAuthConfigSet()
{
Properties properties = new Properties();
properties.setProperty("druid.auth.authenticatorChain", "[\"chain\", \"of\", \"authenticators\"]");
properties.setProperty("druid.auth.authorizers", "[\"authorizers\", \"list\"]");
properties.setProperty("druid.auth.unsecuredPaths", "[\"path1\", \"path2\"]");
properties.setProperty("druid.auth.allowUnauthenticatedHttpOptions", "true");
properties.setProperty("druid.auth.authorizeQueryContextParams", "true");
final AuthConfig authConfig = injectProperties(properties);
Assert.assertNotNull(authConfig);
Assert.assertEquals(ImmutableList.of("chain", "of", "authenticators"), authConfig.getAuthenticatorChain());
Assert.assertEquals(ImmutableList.of("authorizers", "list"), authConfig.getAuthorizers());
Assert.assertEquals(ImmutableList.of("path1", "path2"), authConfig.getUnsecuredPaths());
Assert.assertTrue(authConfig.authorizeQueryContextParams());
}
private AuthConfig injectProperties(Properties properties)
{
final JsonConfigProvider<AuthConfig> provider = JsonConfigProvider.of(
"druid.auth",
AuthConfig.class
);
provider.inject(properties, injector.getInstance(JsonConfigurator.class));
return provider.get().get();
}
}

View File

@ -37,20 +37,28 @@ import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.timeseries.TimeseriesQuery;
import org.apache.druid.server.log.RequestLogger;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.Authorizer;
import org.apache.druid.server.security.AuthorizerMapper;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceType;
import org.easymock.EasyMock;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import javax.servlet.http.HttpServletRequest;
public class QueryLifecycleTest
{
private static final String DATASOURCE = "some_datasource";
private static final String IDENTITY = "some_identity";
private static final String AUTHORIZER = "some_authorizer";
private final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
@ -64,6 +72,7 @@ public class QueryLifecycleTest
RequestLogger requestLogger;
AuthorizerMapper authzMapper;
DefaultQueryConfig queryConfig;
AuthConfig authConfig;
QueryLifecycle lifecycle;
@ -84,8 +93,10 @@ public class QueryLifecycleTest
metricsFactory = EasyMock.createMock(GenericQueryMetricsFactory.class);
emitter = EasyMock.createMock(ServiceEmitter.class);
requestLogger = EasyMock.createNiceMock(RequestLogger.class);
authzMapper = EasyMock.createMock(AuthorizerMapper.class);
authorizer = EasyMock.createMock(Authorizer.class);
authzMapper = new AuthorizerMapper(ImmutableMap.of(AUTHORIZER, authorizer));
queryConfig = EasyMock.createMock(DefaultQueryConfig.class);
authConfig = EasyMock.createMock(AuthConfig.class);
long nanos = System.nanoTime();
long millis = System.currentTimeMillis();
@ -97,6 +108,7 @@ public class QueryLifecycleTest
requestLogger,
authzMapper,
queryConfig,
authConfig,
millis,
nanos
);
@ -105,7 +117,6 @@ public class QueryLifecycleTest
runner = EasyMock.createMock(QueryRunner.class);
metrics = EasyMock.createNiceMock(QueryMetrics.class);
authenticationResult = EasyMock.createMock(AuthenticationResult.class);
authorizer = EasyMock.createMock(Authorizer.class);
}
@After
@ -117,7 +128,6 @@ public class QueryLifecycleTest
metricsFactory,
emitter,
requestLogger,
authzMapper,
queryConfig,
toolChest,
runner,
@ -166,6 +176,87 @@ public class QueryLifecycleTest
lifecycle.runSimple(query, authenticationResult, new Access(false));
}
@Test
public void testAuthorizeQueryContext_authorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(Access.OK);
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(ImmutableMap.of("foo", "bar", "baz", "qux"))
.build();
lifecycle.initialize(query);
Assert.assertEquals(
ImmutableMap.of("foo", "bar", "baz", "qux"),
lifecycle.getQuery().getQueryContext().getUserParams()
);
Assert.assertTrue(lifecycle.getQuery().getQueryContext().getMergedParams().containsKey("queryId"));
Assert.assertTrue(lifecycle.getQuery().getContext().containsKey("queryId"));
Assert.assertTrue(lifecycle.authorize(mockRequest()).isAllowed());
}
@Test
public void testAuthorizeQueryContext_notAuthorized()
{
EasyMock.expect(queryConfig.getContext()).andReturn(ImmutableMap.of()).anyTimes();
EasyMock.expect(authConfig.authorizeQueryContextParams()).andReturn(true).anyTimes();
EasyMock.expect(authenticationResult.getIdentity()).andReturn(IDENTITY).anyTimes();
EasyMock.expect(authenticationResult.getAuthorizerName()).andReturn(AUTHORIZER).anyTimes();
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource(DATASOURCE, ResourceType.DATASOURCE), Action.READ))
.andReturn(Access.OK);
EasyMock.expect(authorizer.authorize(authenticationResult, new Resource("foo", ResourceType.QUERY_CONTEXT), Action.WRITE))
.andReturn(new Access(false));
EasyMock.expect(toolChestWarehouse.getToolChest(EasyMock.anyObject()))
.andReturn(toolChest)
.once();
replayAll();
final TimeseriesQuery query = Druids.newTimeseriesQueryBuilder()
.dataSource(DATASOURCE)
.intervals(ImmutableList.of(Intervals.ETERNITY))
.aggregators(new CountAggregatorFactory("chocula"))
.context(ImmutableMap.of("foo", "bar"))
.build();
lifecycle.initialize(query);
Assert.assertFalse(lifecycle.authorize(mockRequest()).isAllowed());
}
private HttpServletRequest mockRequest()
{
HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
EasyMock.expect(request.getAttribute(EasyMock.eq(AuthConfig.DRUID_AUTHENTICATION_RESULT)))
.andReturn(authenticationResult).anyTimes();
EasyMock.expect(request.getAttribute(EasyMock.eq(AuthConfig.DRUID_ALLOW_UNSECURED_PATH)))
.andReturn(null).anyTimes();
EasyMock.expect(request.getAttribute(EasyMock.eq(AuthConfig.DRUID_AUTHORIZATION_CHECKED)))
.andReturn(null).anyTimes();
EasyMock.replay(request);
return request;
}
private void replayAll()
{
EasyMock.replay(
@ -174,8 +265,8 @@ public class QueryLifecycleTest
metricsFactory,
emitter,
requestLogger,
authzMapper,
queryConfig,
authConfig,
toolChest,
runner,
metrics,

View File

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.server.security;
import nl.jqno.equalsverifier.EqualsVerifier;
import org.junit.Test;
public class AuthConfigTest
{
@Test
public void testEquals()
{
EqualsVerifier.configure().usingGetClass().forClass(AuthConfig.class).verify();
}
}

View File

@ -35,6 +35,7 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.emitter.service.ServiceMetricEvent;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
@ -43,6 +44,7 @@ import org.apache.druid.server.QueryStats;
import org.apache.druid.server.RequestLogLine;
import org.apache.druid.server.log.RequestLogger;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizationUtils;
import org.apache.druid.server.security.ForbiddenException;
@ -59,7 +61,6 @@ import org.apache.druid.sql.http.SqlQuery;
import javax.annotation.Nullable;
import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@ -74,7 +75,7 @@ import java.util.stream.Collectors;
* It ensures that a SQL query goes through the following stages, in the proper order:
*
* <ol>
* <li>Initialization ({@link #initialize(String, Map)})</li>
* <li>Initialization ({@link #initialize(String, QueryContext)})</li>
* <li>Validation and Authorization ({@link #validateAndAuthorize(HttpServletRequest)} or {@link #validateAndAuthorize(AuthenticationResult)})</li>
* <li>Planning ({@link #plan()})</li>
* <li>Execution ({@link #execute()})</li>
@ -91,6 +92,7 @@ public class SqlLifecycle
private final ServiceEmitter emitter;
private final RequestLogger requestLogger;
private final QueryScheduler queryScheduler;
private final AuthConfig authConfig;
private final long startMs;
private final long startNs;
@ -104,7 +106,7 @@ public class SqlLifecycle
// init during intialize
private String sql;
private Map<String, Object> queryContext;
private QueryContext queryContext;
private List<TypedValue> parameters;
// init during plan
private PlannerContext plannerContext;
@ -117,6 +119,7 @@ public class SqlLifecycle
ServiceEmitter emitter,
RequestLogger requestLogger,
QueryScheduler queryScheduler,
AuthConfig authConfig,
long startMs,
long startNs
)
@ -125,6 +128,7 @@ public class SqlLifecycle
this.emitter = emitter;
this.requestLogger = requestLogger;
this.queryScheduler = queryScheduler;
this.authConfig = authConfig;
this.startMs = startMs;
this.startNs = startNs;
this.parameters = Collections.emptyList();
@ -135,7 +139,7 @@ public class SqlLifecycle
*
* If successful (it will be), it will transition the lifecycle to {@link State#INITIALIZED}.
*/
public String initialize(String sql, Map<String, Object> queryContext)
public String initialize(String sql, QueryContext queryContext)
{
transition(State.NEW, State.INITIALIZED);
this.sql = sql;
@ -143,24 +147,21 @@ public class SqlLifecycle
return sqlQueryId();
}
private Map<String, Object> contextWithSqlId(Map<String, Object> queryContext)
private QueryContext contextWithSqlId(QueryContext queryContext)
{
Map<String, Object> newContext = new HashMap<>();
if (queryContext != null) {
newContext.putAll(queryContext);
}
// "bySegment" results are never valid to use with SQL because the result format is incompatible
// so, overwrite any user specified context to avoid exceptions down the line
if (newContext.remove(QueryContexts.BY_SEGMENT_KEY) != null) {
if (queryContext.removeUserParam(QueryContexts.BY_SEGMENT_KEY) != null) {
log.warn("'bySegment' results are not supported for SQL queries, ignoring query context parameter");
}
newContext.computeIfAbsent(PlannerContext.CTX_SQL_QUERY_ID, k -> UUID.randomUUID().toString());
return newContext;
queryContext.addDefaultParam(PlannerContext.CTX_SQL_QUERY_ID, UUID.randomUUID().toString());
return queryContext;
}
private String sqlQueryId()
{
return (String) this.queryContext.get(PlannerContext.CTX_SQL_QUERY_ID);
return queryContext.getAsString(PlannerContext.CTX_SQL_QUERY_ID);
}
/**
@ -230,7 +231,7 @@ public class SqlLifecycle
this.plannerContext.setAuthenticationResult(authenticationResult);
// set parameters on planner context, if parameters have already been set
this.plannerContext.setParameters(parameters);
this.validationResult = planner.validate();
this.validationResult = planner.validate(authConfig.authorizeQueryContextParams());
return validationResult;
}
// we can't collapse catch clauses since SqlPlanningException has type-sensitive constructors.
@ -346,7 +347,7 @@ public class SqlLifecycle
{
Sequence<Object[]> result;
initialize(sql, queryContext);
initialize(sql, new QueryContext(queryContext));
try {
setParameters(SqlQuery.getParameterList(parameters));
validateAndAuthorize(authenticationResult);
@ -417,7 +418,7 @@ public class SqlLifecycle
final long bytesWritten
)
{
if (sql == null) {
if (queryContext == null) {
// Never initialized, don't log or emit anything.
return;
}
@ -464,11 +465,12 @@ public class SqlLifecycle
statsMap.put("sqlQuery/time", TimeUnit.NANOSECONDS.toMillis(queryTimeNs));
statsMap.put("sqlQuery/bytes", bytesWritten);
statsMap.put("success", success);
statsMap.put("context", queryContext);
if (plannerContext != null) {
statsMap.put("identity", plannerContext.getAuthenticationResult().getIdentity());
queryContext.put("nativeQueryIds", plannerContext.getNativeQueryIds().toString());
queryContext.addSystemParam("nativeQueryIds", plannerContext.getNativeQueryIds().toString());
}
final Map<String, Object> context = queryContext.getMergedParams();
statsMap.put("context", context);
if (e != null) {
statsMap.put("exception", e.toString());
@ -481,7 +483,7 @@ public class SqlLifecycle
requestLogger.logSqlQuery(
RequestLogLine.forSql(
sql,
queryContext,
context,
DateTimes.utc(startMs),
remoteAddress,
new QueryStats(statsMap)
@ -502,7 +504,7 @@ public class SqlLifecycle
}
@VisibleForTesting
Map<String, Object> getQueryContext()
QueryContext getQueryContext()
{
return queryContext;
}

View File

@ -24,6 +24,7 @@ import org.apache.druid.guice.LazySingleton;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.server.QueryScheduler;
import org.apache.druid.server.log.RequestLogger;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
@LazySingleton
@ -33,19 +34,22 @@ public class SqlLifecycleFactory
private final ServiceEmitter emitter;
private final RequestLogger requestLogger;
private final QueryScheduler queryScheduler;
private final AuthConfig authConfig;
@Inject
public SqlLifecycleFactory(
PlannerFactory plannerFactory,
ServiceEmitter emitter,
RequestLogger requestLogger,
QueryScheduler queryScheduler
QueryScheduler queryScheduler,
AuthConfig authConfig
)
{
this.plannerFactory = plannerFactory;
this.emitter = emitter;
this.requestLogger = requestLogger;
this.queryScheduler = queryScheduler;
this.authConfig = authConfig;
}
public SqlLifecycle factorize()
@ -55,6 +59,7 @@ public class SqlLifecycleFactory
emitter,
requestLogger,
queryScheduler,
authConfig,
System.currentTimeMillis(),
System.nanoTime()
);

View File

@ -22,16 +22,13 @@ package org.apache.druid.sql.avatica;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.sql.SqlLifecycleFactory;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Future;
@ -44,13 +41,11 @@ import java.util.concurrent.atomic.AtomicReference;
public class DruidConnection
{
private static final Logger LOG = new Logger(DruidConnection.class);
private static final Set<String> SENSITIVE_CONTEXT_FIELDS = Sets.newHashSet(
"user", "password"
);
private final String connectionId;
private final int maxStatements;
private final ImmutableMap<String, Object> context;
private final ImmutableMap<String, Object> userSecret;
private final QueryContext context;
private final AtomicInteger statementCounter = new AtomicInteger();
private final AtomicReference<Future<?>> timeoutFuture = new AtomicReference<>();
@ -66,12 +61,14 @@ public class DruidConnection
public DruidConnection(
final String connectionId,
final int maxStatements,
final Map<String, Object> context
final Map<String, Object> userSecret,
final QueryContext context
)
{
this.connectionId = Preconditions.checkNotNull(connectionId);
this.maxStatements = maxStatements;
this.context = ImmutableMap.copyOf(context);
this.userSecret = ImmutableMap.copyOf(userSecret);
this.context = context;
this.statements = new ConcurrentHashMap<>();
}
@ -95,18 +92,11 @@ public class DruidConnection
throw DruidMeta.logFailure(new ISE("Too many open statements, limit is[%,d]", maxStatements));
}
// remove sensitive fields from the context, only the connection's context needs to have authentication
// credentials
Map<String, Object> sanitizedContext = Maps.filterEntries(
context,
e -> !SENSITIVE_CONTEXT_FIELDS.contains(e.getKey())
);
@SuppressWarnings("GuardedBy")
final DruidStatement statement = new DruidStatement(
connectionId,
statementId,
ImmutableSortedMap.copyOf(sanitizedContext),
context,
sqlLifecycleFactory.factorize(),
() -> {
// onClose function for the statement
@ -173,8 +163,8 @@ public class DruidConnection
return this;
}
public Map<String, Object> context()
public Map<String, Object> userSecret()
{
return context;
return userSecret;
}
}

View File

@ -40,6 +40,7 @@ import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.QueryContext;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.Authenticator;
import org.apache.druid.server.security.AuthenticatorMapper;
@ -53,9 +54,11 @@ import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executors;
@ -94,6 +97,9 @@ public class DruidMeta extends MetaImpl
}
private static final Logger LOG = new Logger(DruidMeta.class);
private static final Set<String> SENSITIVE_CONTEXT_FIELDS = ImmutableSet.of(
"user", "password"
);
private final SqlLifecycleFactory sqlLifecycleFactory;
private final ScheduledExecutorService exec;
@ -140,15 +146,21 @@ public class DruidMeta extends MetaImpl
{
try {
// Build connection context.
final ImmutableMap.Builder<String, Object> context = ImmutableMap.builder();
final Map<String, Object> secret = new HashMap<>();
final Map<String, Object> contextMap = new HashMap<>();
if (info != null) {
for (Map.Entry<String, String> entry : info.entrySet()) {
context.put(entry);
if (SENSITIVE_CONTEXT_FIELDS.contains(entry.getKey())) {
secret.put(entry.getKey(), entry.getValue());
} else {
contextMap.put(entry.getKey(), entry.getValue());
}
}
}
// we don't want to stringify arrays for JDBC ever because avatica needs to handle this
context.put(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false);
openDruidConnection(ch.id, context.build());
final QueryContext context = new QueryContext(contextMap);
context.addSystemParam(PlannerContext.CTX_SQL_STRINGIFY_ARRAYS, false);
openDruidConnection(ch.id, secret, context);
}
catch (NoSuchConnectionException e) {
throw e;
@ -697,7 +709,7 @@ public class DruidMeta extends MetaImpl
@Nullable
private AuthenticationResult authenticateConnection(final DruidConnection connection)
{
Map<String, Object> context = connection.context();
Map<String, Object> context = connection.userSecret();
for (Authenticator authenticator : authenticators) {
LOG.debug("Attempting authentication with authenticator[%s]", authenticator.getClass());
AuthenticationResult authenticationResult = authenticator.authenticateJDBCContext(context);
@ -714,7 +726,11 @@ public class DruidMeta extends MetaImpl
return null;
}
private DruidConnection openDruidConnection(final String connectionId, final Map<String, Object> context)
private DruidConnection openDruidConnection(
final String connectionId,
final Map<String, Object> userSecret,
final QueryContext context
)
{
if (connectionCount.incrementAndGet() > config.getMaxConnections()) {
// O(connections) but we don't expect this to happen often (it's a last-ditch effort to clear out
@ -744,7 +760,7 @@ public class DruidMeta extends MetaImpl
final DruidConnection putResult = connections.putIfAbsent(
connectionId,
new DruidConnection(connectionId, config.getMaxStatementsPerConnection(), context)
new DruidConnection(connectionId, config.getMaxStatementsPerConnection(), userSecret, context)
);
if (putResult != null) {

View File

@ -20,7 +20,6 @@
package org.apache.druid.sql.avatica;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import org.apache.calcite.avatica.AvaticaParameter;
import org.apache.calcite.avatica.ColumnMetaData;
@ -35,6 +34,7 @@ import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.query.QueryContext;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.ForbiddenException;
import org.apache.druid.sql.SqlLifecycle;
@ -46,7 +46,6 @@ import java.sql.Array;
import java.sql.DatabaseMetaData;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
/**
@ -57,7 +56,7 @@ public class DruidStatement implements Closeable
public static final long START_OFFSET = 0;
private final String connectionId;
private final int statementId;
private final Map<String, Object> queryContext;
private final QueryContext queryContext;
@GuardedBy("lock")
private final SqlLifecycle sqlLifecycle;
private final Runnable onClose;
@ -90,14 +89,14 @@ public class DruidStatement implements Closeable
public DruidStatement(
final String connectionId,
final int statementId,
final Map<String, Object> queryContext,
final QueryContext queryContext,
final SqlLifecycle sqlLifecycle,
final Runnable onClose
)
{
this.connectionId = Preconditions.checkNotNull(connectionId, "connectionId");
this.statementId = statementId;
this.queryContext = queryContext == null ? ImmutableMap.of() : queryContext;
this.queryContext = queryContext;
this.sqlLifecycle = Preconditions.checkNotNull(sqlLifecycle, "sqlLifecycle");
this.onClose = Preconditions.checkNotNull(onClose, "onClose");
this.yielderOpenCloseExecutor = Execs.singleThreaded(

View File

@ -77,7 +77,6 @@ import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.Resource;
@ -93,7 +92,6 @@ import org.apache.druid.sql.calcite.run.QueryMakerFactory;
import org.apache.druid.utils.Throwables;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.util.ArrayList;
import java.util.HashSet;
@ -131,7 +129,7 @@ public class DruidPlanner implements Closeable
*
* @return set of {@link Resource} corresponding to any Druid datasources or views which are taking part in the query.
*/
public ValidationResult validate() throws SqlParseException, ValidationException
public ValidationResult validate(boolean authorizeContextParams) throws SqlParseException, ValidationException
{
resetPlanner();
final ParsedNodes parsed = ParsedNodes.create(planner.parse(plannerContext.getSql()));
@ -154,6 +152,11 @@ public class DruidPlanner implements Closeable
final String targetDataSource = validateAndGetDataSourceForInsert(parsed.getInsertNode());
resourceActions.add(new ResourceAction(new Resource(targetDataSource, ResourceType.DATASOURCE), Action.WRITE));
}
if (authorizeContextParams) {
plannerContext.getQueryContext().getUserParams().keySet().forEach(contextParam -> resourceActions.add(
new ResourceAction(new Resource(contextParam, ResourceType.QUERY_CONTEXT), Action.WRITE)
));
}
plannerContext.setResourceActions(resourceActions);
return new ValidationResult(resourceActions);
@ -163,7 +166,7 @@ public class DruidPlanner implements Closeable
* Prepare an SQL query for execution, including some initial parsing and validation and any dynamic parameter type
* resolution, to support prepared statements via JDBC.
*
* In some future this could perhaps re-use some of the work done by {@link #validate()}
* In some future this could perhaps re-use some of the work done by {@link #validate(boolean)}
* instead of repeating it, but that day is not today.
*/
public PrepareResult prepare() throws SqlParseException, ValidationException, RelConversionException
@ -194,7 +197,7 @@ public class DruidPlanner implements Closeable
* Ideally, the query can be planned into a native Druid query, using {@link #planWithDruidConvention}, but will
* fall-back to {@link #planWithBindableConvention} if this is not possible.
*
* In some future this could perhaps re-use some of the work done by {@link #validate()}
* In some future this could perhaps re-use some of the work done by {@link #validate(boolean)}
* instead of repeating it, but that day is not today.
*/
public PlannerResult plan() throws SqlParseException, ValidationException, RelConversionException
@ -205,7 +208,7 @@ public class DruidPlanner implements Closeable
try {
if (parsed.getIngestionGranularity() != null) {
plannerContext.getQueryContext().put(
plannerContext.getQueryContext().addSystemParam(
DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY,
plannerContext.getJsonMapper().writeValueAsString(parsed.getIngestionGranularity())
);
@ -243,7 +246,7 @@ public class DruidPlanner implements Closeable
}
}
Logger logger = log;
if (!QueryContexts.isDebug(plannerContext.getQueryContext())) {
if (!plannerContext.getQueryContext().isDebug()) {
logger = log.noStackTrace();
}
String errorMessage = buildSQLPlanningErrorMessage(cannotPlanException);

View File

@ -20,13 +20,11 @@
package org.apache.druid.sql.calcite.planner;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.query.QueryContext;
import org.joda.time.DateTimeZone;
import org.joda.time.Period;
import java.util.Map;
import java.util.Objects;
public class PlannerConfig
@ -158,43 +156,37 @@ public class PlannerConfig
return useNativeQueryExplain;
}
public PlannerConfig withOverrides(final Map<String, Object> context)
public PlannerConfig withOverrides(final QueryContext queryContext)
{
if (context == null) {
if (queryContext.isEmpty()) {
return this;
}
final PlannerConfig newConfig = new PlannerConfig();
newConfig.metadataRefreshPeriod = getMetadataRefreshPeriod();
newConfig.maxTopNLimit = getMaxTopNLimit();
newConfig.useApproximateCountDistinct = getContextBoolean(
context,
newConfig.useApproximateCountDistinct = queryContext.getAsBoolean(
CTX_KEY_USE_APPROXIMATE_COUNT_DISTINCT,
isUseApproximateCountDistinct()
);
newConfig.useGroupingSetForExactDistinct = getContextBoolean(
context,
newConfig.useGroupingSetForExactDistinct = queryContext.getAsBoolean(
CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
isUseGroupingSetForExactDistinct()
);
newConfig.useApproximateTopN = getContextBoolean(
context,
newConfig.useApproximateTopN = queryContext.getAsBoolean(
CTX_KEY_USE_APPROXIMATE_TOPN,
isUseApproximateTopN()
);
newConfig.computeInnerJoinCostAsFilter = getContextBoolean(
context,
newConfig.computeInnerJoinCostAsFilter = queryContext.getAsBoolean(
CTX_COMPUTE_INNER_JOIN_COST_AS_FILTER,
computeInnerJoinCostAsFilter
);
newConfig.useNativeQueryExplain = getContextBoolean(
context,
newConfig.useNativeQueryExplain = queryContext.getAsBoolean(
CTX_KEY_USE_NATIVE_QUERY_EXPLAIN,
isUseNativeQueryExplain()
);
final int systemConfigMaxNumericInFilters = getMaxNumericInFilters();
final int queryContextMaxNumericInFilters = getContextInt(
context,
final int queryContextMaxNumericInFilters = queryContext.getAsInt(
CTX_MAX_NUMERIC_IN_FILTERS,
getMaxNumericInFilters()
);
@ -232,42 +224,6 @@ public class PlannerConfig
return queryContextMaxNumericInFilters;
}
private static int getContextInt(
final Map<String, Object> context,
final String parameter,
final int defaultValue
)
{
final Object value = context.get(parameter);
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Numbers.parseInt(value);
} else if (value instanceof Integer) {
return (Integer) value;
} else {
throw new IAE("Expected parameter[%s] to be integer", parameter);
}
}
private static boolean getContextBoolean(
final Map<String, Object> context,
final String parameter,
final boolean defaultValue
)
{
final Object value = context.get(parameter);
if (value == null) {
return defaultValue;
} else if (value instanceof String) {
return Boolean.parseBoolean((String) value);
} else if (value instanceof Boolean) {
return (Boolean) value;
} else {
throw new IAE("Expected parameter[%s] to be boolean", parameter);
}
}
@Override
public boolean equals(final Object o)
{

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.Numbers;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.BaseQuery;
import org.apache.druid.query.QueryContext;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.ResourceAction;
@ -45,9 +46,7 @@ import org.joda.time.DateTimeZone;
import org.joda.time.Interval;
import javax.annotation.Nullable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
@ -80,7 +79,7 @@ public class PlannerContext
private final PlannerConfig plannerConfig;
private final DateTime localNow;
private final DruidSchemaCatalog rootSchema;
private final Map<String, Object> queryContext;
private final QueryContext queryContext;
private final String sqlQueryId;
private final boolean stringifyArrays;
private final CopyOnWriteArrayList<String> nativeQueryIds = new CopyOnWriteArrayList<>();
@ -107,7 +106,7 @@ public class PlannerContext
final DateTime localNow,
final boolean stringifyArrays,
final DruidSchemaCatalog rootSchema,
final Map<String, Object> queryContext
final QueryContext queryContext
)
{
this.sql = sql;
@ -116,7 +115,7 @@ public class PlannerContext
this.jsonMapper = jsonMapper;
this.plannerConfig = Preconditions.checkNotNull(plannerConfig, "plannerConfig");
this.rootSchema = rootSchema;
this.queryContext = queryContext != null ? new HashMap<>(queryContext) : new HashMap<>();
this.queryContext = queryContext;
this.localNow = Preconditions.checkNotNull(localNow, "localNow");
this.stringifyArrays = stringifyArrays;
@ -135,38 +134,32 @@ public class PlannerContext
final ObjectMapper jsonMapper,
final PlannerConfig plannerConfig,
final DruidSchemaCatalog rootSchema,
final Map<String, Object> queryContext
final QueryContext queryContext
)
{
final DateTime utcNow;
final DateTimeZone timeZone;
final boolean stringifyArrays;
if (queryContext != null) {
final Object stringifyParam = queryContext.get(CTX_SQL_STRINGIFY_ARRAYS);
final Object tsParam = queryContext.get(CTX_SQL_CURRENT_TIMESTAMP);
final Object tzParam = queryContext.get(CTX_SQL_TIME_ZONE);
final Object stringifyParam = queryContext.get(CTX_SQL_STRINGIFY_ARRAYS);
final Object tsParam = queryContext.get(CTX_SQL_CURRENT_TIMESTAMP);
final Object tzParam = queryContext.get(CTX_SQL_TIME_ZONE);
if (tsParam != null) {
utcNow = new DateTime(tsParam, DateTimeZone.UTC);
} else {
utcNow = new DateTime(DateTimeZone.UTC);
}
if (tzParam != null) {
timeZone = DateTimes.inferTzFromString(String.valueOf(tzParam));
} else {
timeZone = plannerConfig.getSqlTimeZone();
}
if (stringifyParam != null) {
stringifyArrays = Numbers.parseBoolean(stringifyParam);
} else {
stringifyArrays = true;
}
if (tsParam != null) {
utcNow = new DateTime(tsParam, DateTimeZone.UTC);
} else {
utcNow = new DateTime(DateTimeZone.UTC);
}
if (tzParam != null) {
timeZone = DateTimes.inferTzFromString(String.valueOf(tzParam));
} else {
timeZone = plannerConfig.getSqlTimeZone();
}
if (stringifyParam != null) {
stringifyArrays = Numbers.parseBoolean(stringifyParam);
} else {
stringifyArrays = true;
}
@ -219,7 +212,7 @@ public class PlannerContext
return rootSchema.getResourceType(schema, resourceName);
}
public Map<String, Object> getQueryContext()
public QueryContext getQueryContext()
{
return queryContext;
}

View File

@ -38,6 +38,7 @@ import org.apache.calcite.tools.Frameworks;
import org.apache.calcite.tools.ValidationException;
import org.apache.druid.guice.annotations.Json;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.server.security.Access;
import org.apache.druid.server.security.AuthorizerMapper;
@ -96,7 +97,7 @@ public class PlannerFactory
/**
* Create a Druid query planner from an initial query context
*/
public DruidPlanner createPlanner(final String sql, final Map<String, Object> queryContext)
public DruidPlanner createPlanner(final String sql, final QueryContext queryContext)
{
final PlannerContext context = PlannerContext.create(
sql,
@ -126,11 +127,11 @@ public class PlannerFactory
@VisibleForTesting
public DruidPlanner createPlannerForTesting(final Map<String, Object> queryContext, String query)
{
final DruidPlanner thePlanner = createPlanner(query, queryContext);
final DruidPlanner thePlanner = createPlanner(query, new QueryContext(queryContext));
thePlanner.getPlannerContext()
.setAuthenticationResult(NoopEscalator.getInstance().createEscalatedAuthenticationResult());
try {
thePlanner.validate();
thePlanner.validate(false);
}
catch (SqlParseException | ValidationException e) {
throw new RuntimeException(e);
@ -151,7 +152,9 @@ public class PlannerFactory
.withExpand(false)
.withDecorrelationEnabled(false)
.withTrimUnusedFields(false)
.withInSubQueryThreshold(QueryContexts.getInSubQueryThreshold(plannerContext.getQueryContext()))
.withInSubQueryThreshold(
QueryContexts.getInSubQueryThreshold(plannerContext.getQueryContext().getMergedParams())
)
.build();
return Frameworks
.newConfigBuilder()

View File

@ -905,7 +905,7 @@ public class DruidQuery
if (!Granularities.ALL.equals(queryGranularity) || grouping.hasGroupingDimensionsDropped()) {
theContext.put(TimeseriesQuery.SKIP_EMPTY_BUCKETS, true);
}
theContext.putAll(plannerContext.getQueryContext());
theContext.putAll(plannerContext.getQueryContext().getMergedParams());
final Pair<DataSource, Filtration> dataSourceFiltrationPair = getFiltration(
dataSource,
@ -1025,7 +1025,7 @@ public class DruidQuery
Granularities.ALL,
grouping.getAggregatorFactories(),
postAggregators,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
ImmutableSortedMap.copyOf(plannerContext.getQueryContext().getMergedParams())
);
}
@ -1082,7 +1082,7 @@ public class DruidQuery
havingSpec,
Optional.ofNullable(sorting).orElse(Sorting.none()).limitSpec(),
grouping.getSubtotals().toSubtotalsSpec(grouping.getDimensionSpecs()),
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
ImmutableSortedMap.copyOf(plannerContext.getQueryContext().getMergedParams())
);
// We don't apply timestamp computation optimization yet when limit is pushed down. Maybe someday.
if (query.getLimitSpec() instanceof DefaultLimitSpec && query.isApplyLimitPushDown()) {
@ -1237,7 +1237,7 @@ public class DruidQuery
filtration.getDimFilter(),
ImmutableList.copyOf(scanColumns),
false,
ImmutableSortedMap.copyOf(plannerContext.getQueryContext())
ImmutableSortedMap.copyOf(plannerContext.getQueryContext().getMergedParams())
);
}
}

View File

@ -43,7 +43,6 @@ import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.rel.DruidJoinQueryRel;
import org.apache.druid.sql.calcite.rel.DruidQueryRel;
@ -74,7 +73,7 @@ public class DruidJoinRule extends RelOptRule
operand(DruidRel.class, any())
)
);
this.enableLeftScanDirect = QueryContexts.getEnableJoinLeftScanDirect(plannerContext.getQueryContext());
this.enableLeftScanDirect = plannerContext.getQueryContext().isEnableJoinLeftScanDirect();
this.plannerContext = plannerContext;
}

View File

@ -28,6 +28,7 @@ import org.apache.calcite.schema.FunctionParameter;
import org.apache.calcite.schema.TableMacro;
import org.apache.calcite.schema.TranslatableTable;
import org.apache.calcite.schema.impl.ViewTable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.sql.calcite.planner.DruidPlanner;
import org.apache.druid.sql.calcite.planner.PlannerFactory;
import org.apache.druid.sql.calcite.schema.DruidSchemaName;
@ -56,7 +57,7 @@ public class DruidViewMacro implements TableMacro
public TranslatableTable apply(final List<Object> arguments)
{
final RelDataType rowType;
try (final DruidPlanner planner = plannerFactory.createPlanner(viewSql, null)) {
try (final DruidPlanner planner = plannerFactory.createPlanner(viewSql, new QueryContext())) {
rowType = planner.plan().rowType();
}
catch (Exception e) {

View File

@ -34,6 +34,7 @@ import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.query.BadQueryException;
import org.apache.druid.query.QueryCapacityExceededException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryInterruptedException;
import org.apache.druid.query.QueryTimeoutException;
import org.apache.druid.query.QueryUnsupportedException;
@ -63,7 +64,6 @@ import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import javax.ws.rs.core.Response.Status;
import javax.ws.rs.core.StreamingOutput;
import java.io.IOException;
import java.util.List;
import java.util.Set;
@ -108,7 +108,7 @@ public class SqlResource
) throws IOException
{
final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
final String sqlQueryId = lifecycle.initialize(sqlQuery.getQuery(), sqlQuery.getContext());
final String sqlQueryId = lifecycle.initialize(sqlQuery.getQuery(), new QueryContext(sqlQuery.getContext()));
final String remoteAddr = req.getRemoteAddr();
final String currThreadName = Thread.currentThread().getName();

View File

@ -29,6 +29,7 @@ import org.apache.calcite.tools.ValidationException;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.emitter.service.ServiceEventBuilder;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.log.RequestLogger;
@ -71,7 +72,8 @@ public class SqlLifecycleTest
plannerFactory,
serviceEmitter,
requestLogger,
QueryStackTests.DEFAULT_NOOP_SCHEDULER
QueryStackTests.DEFAULT_NOOP_SCHEDULER,
new AuthConfig()
);
}
@ -81,11 +83,11 @@ public class SqlLifecycleTest
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
final String sql = "select 1 + ?";
final Map<String, Object> queryContext = ImmutableMap.of(QueryContexts.BY_SEGMENT_KEY, "true");
lifecycle.initialize(sql, queryContext);
lifecycle.initialize(sql, new QueryContext(queryContext));
Assert.assertEquals(SqlLifecycle.State.INITIALIZED, lifecycle.getState());
Assert.assertEquals(1, lifecycle.getQueryContext().size());
Assert.assertEquals(1, lifecycle.getQueryContext().getMergedParams().size());
// should contain only query id, not bySegment since it is not valid for SQL
Assert.assertTrue(lifecycle.getQueryContext().containsKey(PlannerContext.CTX_SQL_QUERY_ID));
Assert.assertTrue(lifecycle.getQueryContext().getMergedParams().containsKey(PlannerContext.CTX_SQL_QUERY_ID));
}
@Test
@ -94,11 +96,10 @@ public class SqlLifecycleTest
{
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
final String sql = "select 1 + ?";
final Map<String, Object> queryContext = Collections.emptyMap();
Assert.assertEquals(SqlLifecycle.State.NEW, lifecycle.getState());
// test initialize
lifecycle.initialize(sql, queryContext);
lifecycle.initialize(sql, new QueryContext());
Assert.assertEquals(SqlLifecycle.State.INITIALIZED, lifecycle.getState());
List<TypedValue> parameters = ImmutableList.of(new SqlParameter(SqlType.BIGINT, 1L).getTypedValue());
lifecycle.setParameters(parameters);
@ -118,7 +119,7 @@ public class SqlLifecycleTest
EasyMock.expect(plannerFactory.getAuthorizerMapper()).andReturn(CalciteTests.TEST_AUTHORIZER_MAPPER).once();
mockPlannerContext.setAuthorizationResult(Access.OK);
EasyMock.expectLastCall();
EasyMock.expect(mockPlanner.validate()).andReturn(validationResult).once();
EasyMock.expect(mockPlanner.validate(false)).andReturn(validationResult).once();
mockPlanner.close();
EasyMock.expectLastCall();
@ -191,11 +192,10 @@ public class SqlLifecycleTest
// is run
SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
final String sql = "select 1 + ?";
final Map<String, Object> queryContext = Collections.emptyMap();
Assert.assertEquals(SqlLifecycle.State.NEW, lifecycle.getState());
// test initialize
lifecycle.initialize(sql, queryContext);
lifecycle.initialize(sql, new QueryContext());
Assert.assertEquals(SqlLifecycle.State.INITIALIZED, lifecycle.getState());
List<TypedValue> parameters = ImmutableList.of(new SqlParameter(SqlType.BIGINT, 1L).getTypedValue());
lifecycle.setParameters(parameters);
@ -215,7 +215,7 @@ public class SqlLifecycleTest
EasyMock.expect(plannerFactory.getAuthorizerMapper()).andReturn(CalciteTests.TEST_AUTHORIZER_MAPPER).once();
mockPlannerContext.setAuthorizationResult(Access.OK);
EasyMock.expectLastCall();
EasyMock.expect(mockPlanner.validate()).andReturn(validationResult).once();
EasyMock.expect(mockPlanner.validate(false)).andReturn(validationResult).once();
mockPlanner.close();
EasyMock.expectLastCall();

View File

@ -27,6 +27,7 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AllowAllAuthenticator;
@ -113,147 +114,155 @@ public class DruidStatementTest extends CalciteTestBase
public void testSignature()
{
final String sql = "SELECT * FROM druid.foo";
final DruidStatement statement = new DruidStatement("", 0, null, sqlLifecycleFactory.factorize(), () -> {
}).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT);
// Check signature.
final Meta.Signature signature = statement.getSignature();
Assert.assertEquals(Meta.CursorFactory.ARRAY, signature.cursorFactory);
Assert.assertEquals(Meta.StatementType.SELECT, signature.statementType);
Assert.assertEquals(sql, signature.sql);
Assert.assertEquals(
Lists.newArrayList(
Lists.newArrayList("__time", "TIMESTAMP", "java.lang.Long"),
Lists.newArrayList("cnt", "BIGINT", "java.lang.Number"),
Lists.newArrayList("dim1", "VARCHAR", "java.lang.String"),
Lists.newArrayList("dim2", "VARCHAR", "java.lang.String"),
Lists.newArrayList("dim3", "VARCHAR", "java.lang.String"),
Lists.newArrayList("m1", "FLOAT", "java.lang.Float"),
Lists.newArrayList("m2", "DOUBLE", "java.lang.Double"),
Lists.newArrayList("unique_dim1", "OTHER", "java.lang.Object")
),
Lists.transform(
signature.columns,
new Function<ColumnMetaData, List<String>>()
{
@Override
public List<String> apply(final ColumnMetaData columnMetaData)
try (final DruidStatement statement = statement(sql)) {
// Check signature.
final Meta.Signature signature = statement.getSignature();
Assert.assertEquals(Meta.CursorFactory.ARRAY, signature.cursorFactory);
Assert.assertEquals(Meta.StatementType.SELECT, signature.statementType);
Assert.assertEquals(sql, signature.sql);
Assert.assertEquals(
Lists.newArrayList(
Lists.newArrayList("__time", "TIMESTAMP", "java.lang.Long"),
Lists.newArrayList("cnt", "BIGINT", "java.lang.Number"),
Lists.newArrayList("dim1", "VARCHAR", "java.lang.String"),
Lists.newArrayList("dim2", "VARCHAR", "java.lang.String"),
Lists.newArrayList("dim3", "VARCHAR", "java.lang.String"),
Lists.newArrayList("m1", "FLOAT", "java.lang.Float"),
Lists.newArrayList("m2", "DOUBLE", "java.lang.Double"),
Lists.newArrayList("unique_dim1", "OTHER", "java.lang.Object")
),
Lists.transform(
signature.columns,
new Function<ColumnMetaData, List<String>>()
{
return Lists.newArrayList(
columnMetaData.label,
columnMetaData.type.name,
columnMetaData.type.rep.clazz.getName()
);
@Override
public List<String> apply(final ColumnMetaData columnMetaData)
{
return Lists.newArrayList(
columnMetaData.label,
columnMetaData.type.name,
columnMetaData.type.rep.clazz.getName()
);
}
}
}
)
);
)
);
}
}
@Test
public void testSubQueryWithOrderBy()
{
final String sql = "select T20.F13 as F22 from (SELECT DISTINCT dim1 as F13 FROM druid.foo T10) T20 order by T20.F13 ASC";
final DruidStatement statement = new DruidStatement("", 0, null, sqlLifecycleFactory.factorize(), () -> {
}).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT);
// First frame, ask for all rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 6);
Assert.assertEquals(
Meta.Frame.create(
0,
true,
Lists.newArrayList(
new Object[]{""},
new Object[]{
"1"
},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
),
frame
);
Assert.assertTrue(statement.isDone());
try (final DruidStatement statement = statement(sql)) {
// First frame, ask for all rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 6);
Assert.assertEquals(
Meta.Frame.create(
0,
true,
Lists.newArrayList(
new Object[]{""},
new Object[]{
"1"
},
new Object[]{"10.1"},
new Object[]{"2"},
new Object[]{"abc"},
new Object[]{"def"}
)
),
frame
);
Assert.assertTrue(statement.isDone());
}
}
@Test
public void testSelectAllInFirstFrame()
{
final String sql = "SELECT __time, cnt, dim1, dim2, m1 FROM druid.foo";
final DruidStatement statement = new DruidStatement("", 0, null, sqlLifecycleFactory.factorize(), () -> {
}).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT);
// First frame, ask for all rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 6);
Assert.assertEquals(
Meta.Frame.create(
0,
true,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-01").getMillis(), 1L, "", "a", 1.0f},
new Object[]{
DateTimes.of("2000-01-02").getMillis(),
1L,
"10.1",
NullHandling.defaultStringValue(),
2.0f
},
new Object[]{DateTimes.of("2000-01-03").getMillis(), 1L, "2", "", 3.0f},
new Object[]{DateTimes.of("2001-01-01").getMillis(), 1L, "1", "a", 4.0f},
new Object[]{DateTimes.of("2001-01-02").getMillis(), 1L, "def", "abc", 5.0f},
new Object[]{DateTimes.of("2001-01-03").getMillis(), 1L, "abc", NullHandling.defaultStringValue(), 6.0f}
)
),
frame
);
Assert.assertTrue(statement.isDone());
try (final DruidStatement statement = statement(sql)) {
// First frame, ask for all rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 6);
Assert.assertEquals(
Meta.Frame.create(
0,
true,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-01").getMillis(), 1L, "", "a", 1.0f},
new Object[]{
DateTimes.of("2000-01-02").getMillis(),
1L,
"10.1",
NullHandling.defaultStringValue(),
2.0f
},
new Object[]{DateTimes.of("2000-01-03").getMillis(), 1L, "2", "", 3.0f},
new Object[]{DateTimes.of("2001-01-01").getMillis(), 1L, "1", "a", 4.0f},
new Object[]{DateTimes.of("2001-01-02").getMillis(), 1L, "def", "abc", 5.0f},
new Object[]{DateTimes.of("2001-01-03").getMillis(), 1L, "abc", NullHandling.defaultStringValue(), 6.0f}
)
),
frame
);
Assert.assertTrue(statement.isDone());
}
}
@Test
public void testSelectSplitOverTwoFrames()
{
final String sql = "SELECT __time, cnt, dim1, dim2, m1 FROM druid.foo";
final DruidStatement statement = new DruidStatement("", 0, null, sqlLifecycleFactory.factorize(), () -> {
}).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT);
try (final DruidStatement statement = statement(sql)) {
// First frame, ask for 2 rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 2);
Assert.assertEquals(
Meta.Frame.create(
0,
false,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-01").getMillis(), 1L, "", "a", 1.0f},
new Object[]{
DateTimes.of("2000-01-02").getMillis(),
1L,
"10.1",
NullHandling.defaultStringValue(),
2.0f
}
)
),
frame
);
Assert.assertFalse(statement.isDone());
// First frame, ask for 2 rows.
Meta.Frame frame = statement.execute(Collections.emptyList()).nextFrame(DruidStatement.START_OFFSET, 2);
Assert.assertEquals(
Meta.Frame.create(
0,
false,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-01").getMillis(), 1L, "", "a", 1.0f},
new Object[]{
DateTimes.of("2000-01-02").getMillis(),
1L,
"10.1",
NullHandling.defaultStringValue(),
2.0f
}
)
),
frame
);
Assert.assertFalse(statement.isDone());
// Last frame, ask for all remaining rows.
frame = statement.nextFrame(2, 10);
Assert.assertEquals(
Meta.Frame.create(
2,
true,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-03").getMillis(), 1L, "2", "", 3.0f},
new Object[]{DateTimes.of("2001-01-01").getMillis(), 1L, "1", "a", 4.0f},
new Object[]{DateTimes.of("2001-01-02").getMillis(), 1L, "def", "abc", 5.0f},
new Object[]{DateTimes.of("2001-01-03").getMillis(), 1L, "abc", NullHandling.defaultStringValue(), 6.0f}
)
),
frame
);
Assert.assertTrue(statement.isDone());
}
}
// Last frame, ask for all remaining rows.
frame = statement.nextFrame(2, 10);
Assert.assertEquals(
Meta.Frame.create(
2,
true,
Lists.newArrayList(
new Object[]{DateTimes.of("2000-01-03").getMillis(), 1L, "2", "", 3.0f},
new Object[]{DateTimes.of("2001-01-01").getMillis(), 1L, "1", "a", 4.0f},
new Object[]{DateTimes.of("2001-01-02").getMillis(), 1L, "def", "abc", 5.0f},
new Object[]{DateTimes.of("2001-01-03").getMillis(), 1L, "abc", NullHandling.defaultStringValue(), 6.0f}
)
),
frame
);
Assert.assertTrue(statement.isDone());
private DruidStatement statement(String sql)
{
return new DruidStatement(
"",
0,
new QueryContext(),
sqlLifecycleFactory.factorize(),
() -> {}
).prepare(sql, -1, AllowAllAuthenticator.ALLOW_ALL_RESULT);
}
}

View File

@ -42,6 +42,7 @@ import org.apache.druid.query.DataSource;
import org.apache.druid.query.Druids;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
@ -74,6 +75,7 @@ import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.join.JoinType;
import org.apache.druid.segment.virtual.ExpressionVirtualColumn;
import org.apache.druid.server.QueryStackTests;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.AuthorizerMapper;
import org.apache.druid.server.security.ForbiddenException;
@ -848,6 +850,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase
{
final SqlLifecycleFactory sqlLifecycleFactory = getSqlLifecycleFactory(
plannerConfig,
new AuthConfig(),
operatorTable,
macroTable,
authorizerMapper,
@ -961,9 +964,21 @@ public class BaseCalciteQueryTest extends CalciteTestBase
String sql,
AuthenticationResult authenticationResult
)
{
return analyzeResources(plannerConfig, new AuthConfig(), sql, ImmutableMap.of(), authenticationResult);
}
public Set<ResourceAction> analyzeResources(
PlannerConfig plannerConfig,
AuthConfig authConfig,
String sql,
Map<String, Object> contexts,
AuthenticationResult authenticationResult
)
{
SqlLifecycleFactory lifecycleFactory = getSqlLifecycleFactory(
plannerConfig,
authConfig,
createOperatorTable(),
createMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
@ -971,12 +986,13 @@ public class BaseCalciteQueryTest extends CalciteTestBase
);
SqlLifecycle lifecycle = lifecycleFactory.factorize();
lifecycle.initialize(sql, ImmutableMap.of());
lifecycle.initialize(sql, new QueryContext(contexts));
return lifecycle.runAnalyzeResources(authenticationResult).getResourceActions();
}
public SqlLifecycleFactory getSqlLifecycleFactory(
PlannerConfig plannerConfig,
AuthConfig authConfig,
DruidOperatorTable operatorTable,
ExprMacroTable macroTable,
AuthorizerMapper authorizerMapper,
@ -1006,7 +1022,7 @@ public class BaseCalciteQueryTest extends CalciteTestBase
objectMapper,
CalciteTests.DRUID_SCHEMA_NAME
);
final SqlLifecycleFactory sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(plannerFactory);
final SqlLifecycleFactory sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(plannerFactory, authConfig);
viewManager.createView(
plannerFactory,

View File

@ -32,6 +32,7 @@ import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.query.Query;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFactory;
@ -41,6 +42,7 @@ import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.AuthenticationResult;
import org.apache.druid.server.security.ForbiddenException;
import org.apache.druid.server.security.Resource;
@ -908,6 +910,7 @@ public class CalciteInsertDmlTest extends BaseCalciteQueryTest
final SqlLifecycleFactory sqlLifecycleFactory = getSqlLifecycleFactory(
plannerConfig,
new AuthConfig(),
createOperatorTable(),
createMacroTable(),
CalciteTests.TEST_AUTHORIZER_MAPPER,
@ -915,7 +918,7 @@ public class CalciteInsertDmlTest extends BaseCalciteQueryTest
);
final SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
sqlLifecycle.initialize(sql, queryContext);
sqlLifecycle.initialize(sql, new QueryContext(queryContext));
final Throwable e = Assert.assertThrows(
Throwable.class,

View File

@ -38,6 +38,7 @@ import org.apache.druid.query.Druids;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
import org.apache.druid.query.LookupDataSource;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.ResourceLimitExceededException;
@ -132,8 +133,6 @@ import java.util.stream.Collectors;
public class CalciteQueryTest extends BaseCalciteQueryTest
{
@Test
public void testGroupByWithPostAggregatorReferencingTimeFloorColumnOnTimeseries() throws Exception
{
@ -2162,10 +2161,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
// to a bug in the Calcite's rule (AggregateExpandDistinctAggregatesRule)
try {
testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides(ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"false"
)), // Enable exact count distinct
PLANNER_CONFIG_NO_HLL.withOverrides(
new QueryContext(
ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"false"
)
)
), // Enable exact count distinct
sqlQuery,
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(),
@ -2179,10 +2182,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
requireMergeBuffers(3);
testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides(ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)),
PLANNER_CONFIG_NO_HLL.withOverrides(
new QueryContext(
ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)
)
),
sqlQuery,
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(
@ -6243,10 +6250,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
{
requireMergeBuffers(4);
testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides(ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)),
PLANNER_CONFIG_NO_HLL.withOverrides(
new QueryContext(
ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)
)
),
"SELECT FLOOR(__time to day), COUNT(distinct city), COUNT(distinct user) FROM druid.visits GROUP BY 1",
CalciteTests.REGULAR_USER_AUTH_RESULT,
ImmutableList.of(
@ -7008,7 +7019,7 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
+ " )\n"
+ ")";
final String legacyExplanation =
"DruidOuterQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"query\",\"query\":{\"queryType\":\"scan\",\"dataSource\":{\"type\":\"table\",\"name\":\"__subquery__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"resultFormat\":\"list\",\"batchSize\":20480,\"filter\":null,\"context\":null,\"descending\":false,\"granularity\":{\"type\":\"all\"}}},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[],\"aggregations\":[{\"type\":\"count\",\"name\":\"a0\"}],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\",\"sqlQueryId\":\"dummy\",\"vectorize\":\"false\",\"vectorizeVirtualColumns\":\"false\"},\"descending\":false}], signature=[{a0:LONG}])\n"
"DruidOuterQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"query\",\"query\":{\"queryType\":\"scan\",\"dataSource\":{\"type\":\"table\",\"name\":\"__subquery__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"resultFormat\":\"list\",\"batchSize\":20480,\"filter\":null,\"context\":{},\"descending\":false,\"granularity\":{\"type\":\"all\"}}},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[],\"aggregations\":[{\"type\":\"count\",\"name\":\"a0\"}],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\",\"sqlQueryId\":\"dummy\",\"vectorize\":\"false\",\"vectorizeVirtualColumns\":\"false\"},\"descending\":false}], signature=[{a0:LONG}])\n"
+ " DruidJoinQueryRel(condition=[=(SUBSTRING($3, 1, 1), $8)], joinType=[inner], query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"table\",\"name\":\"__join__\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":null,\"granularity\":{\"type\":\"all\"},\"dimensions\":[{\"type\":\"default\",\"dimension\":\"dim2\",\"outputName\":\"d0\",\"outputType\":\"STRING\"}],\"aggregations\":[],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\",\"sqlQueryId\":\"dummy\",\"vectorize\":\"false\",\"vectorizeVirtualColumns\":\"false\"},\"descending\":false}], signature=[{d0:STRING}])\n"
+ " DruidQueryRel(query=[{\"queryType\":\"scan\",\"dataSource\":{\"type\":\"table\",\"name\":\"foo\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"resultFormat\":\"compactedList\",\"batchSize\":20480,\"filter\":null,\"columns\":[\"__time\",\"cnt\",\"dim1\",\"dim2\",\"dim3\",\"m1\",\"m2\",\"unique_dim1\"],\"legacy\":false,\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\",\"sqlQueryId\":\"dummy\",\"vectorize\":\"false\",\"vectorizeVirtualColumns\":\"false\"},\"descending\":false,\"granularity\":{\"type\":\"all\"}}], signature=[{__time:LONG, cnt:LONG, dim1:STRING, dim2:STRING, dim3:STRING, m1:FLOAT, m2:DOUBLE, unique_dim1:COMPLEX<hyperUnique>}])\n"
+ " DruidQueryRel(query=[{\"queryType\":\"groupBy\",\"dataSource\":{\"type\":\"table\",\"name\":\"foo\"},\"intervals\":{\"type\":\"intervals\",\"intervals\":[\"-146136543-09-08T08:23:32.096Z/146140482-04-24T15:36:27.903Z\"]},\"virtualColumns\":[],\"filter\":{\"type\":\"not\",\"field\":{\"type\":\"selector\",\"dimension\":\"dim1\",\"value\":null,\"extractionFn\":null}},\"granularity\":{\"type\":\"all\"},\"dimensions\":[{\"type\":\"extraction\",\"dimension\":\"dim1\",\"outputName\":\"d0\",\"outputType\":\"STRING\",\"extractionFn\":{\"type\":\"substring\",\"index\":0,\"length\":1}}],\"aggregations\":[],\"postAggregations\":[],\"having\":null,\"limitSpec\":{\"type\":\"NoopLimitSpec\"},\"context\":{\"defaultTimeout\":300000,\"maxScatterGatherBytes\":9223372036854775807,\"sqlCurrentTimestamp\":\"2000-01-01T00:00:00Z\",\"sqlQueryId\":\"dummy\",\"vectorize\":\"false\",\"vectorizeVirtualColumns\":\"false\"},\"descending\":false}], signature=[{d0:STRING}])\n";
@ -8269,10 +8280,14 @@ public class CalciteQueryTest extends BaseCalciteQueryTest
cannotVectorize();
requireMergeBuffers(3);
testQuery(
PLANNER_CONFIG_NO_HLL.withOverrides(ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)),
PLANNER_CONFIG_NO_HLL.withOverrides(
new QueryContext(
ImmutableMap.of(
PlannerConfig.CTX_KEY_USE_GROUPING_SET_FOR_EXACT_DISTINCT,
"true"
)
)
),
"SELECT\n"
+ "(SUM(CASE WHEN (TIMESTAMP '2000-01-04 17:00:00'<=__time AND __time<TIMESTAMP '2022-01-05 17:00:00') THEN 1 ELSE 0 END)*1.0/COUNT(DISTINCT CASE WHEN (TIMESTAMP '2000-01-04 17:00:00'<=__time AND __time<TIMESTAMP '2022-01-05 17:00:00') THEN dim1 END))\n"
+ "FROM druid.foo\n"

View File

@ -19,8 +19,10 @@
package org.apache.druid.sql.calcite;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.server.security.Action;
import org.apache.druid.server.security.AuthConfig;
import org.apache.druid.server.security.Resource;
import org.apache.druid.server.security.ResourceAction;
import org.apache.druid.server.security.ResourceType;
@ -29,6 +31,8 @@ import org.apache.druid.sql.calcite.util.CalciteTests;
import org.junit.Assert;
import org.junit.Test;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class DruidPlannerResourceAnalyzeTest extends BaseCalciteQueryTest
@ -264,21 +268,91 @@ public class DruidPlannerResourceAnalyzeTest extends BaseCalciteQueryTest
}
private void testSysTable(String sql, String name, PlannerConfig plannerConfig)
{
testSysTable(sql, name, ImmutableMap.of(), plannerConfig, new AuthConfig());
}
private void testSysTable(
String sql,
String name,
Map<String, Object> context,
PlannerConfig plannerConfig,
AuthConfig authConfig
)
{
Set<ResourceAction> requiredResources = analyzeResources(
plannerConfig,
authConfig,
sql,
context,
CalciteTests.REGULAR_USER_AUTH_RESULT
);
if (name == null) {
Assert.assertEquals(0, requiredResources.size());
} else {
Assert.assertEquals(
ImmutableSet.of(
new ResourceAction(new Resource(name, ResourceType.SYSTEM_TABLE), Action.READ)
),
requiredResources
);
final Set<ResourceAction> expectedResources = new HashSet<>();
if (name != null) {
expectedResources.add(new ResourceAction(new Resource(name, ResourceType.SYSTEM_TABLE), Action.READ));
}
if (context != null && !context.isEmpty()) {
context.forEach((k, v) -> expectedResources.add(
new ResourceAction(new Resource(k, ResourceType.QUERY_CONTEXT), Action.WRITE)
));
}
Assert.assertEquals(expectedResources, requiredResources);
}
@Test
public void testSysTableWithQueryContext()
{
final AuthConfig authConfig = AuthConfig.newBuilder().setAuthorizeQueryContextParams(true).build();
final Map<String, Object> context = ImmutableMap.of(
"baz", "fo",
"nested-bar", ImmutableMap.of("nested-key", "nested-val")
);
testSysTable("SELECT * FROM sys.segments", null, context, PLANNER_CONFIG_DEFAULT, authConfig);
testSysTable("SELECT * FROM sys.servers", null, context, PLANNER_CONFIG_DEFAULT, authConfig);
testSysTable("SELECT * FROM sys.server_segments", null, context, PLANNER_CONFIG_DEFAULT, authConfig);
testSysTable("SELECT * FROM sys.tasks", null, context, PLANNER_CONFIG_DEFAULT, authConfig);
testSysTable("SELECT * FROM sys.supervisors", null, context, PLANNER_CONFIG_DEFAULT, authConfig);
testSysTable("SELECT * FROM sys.segments", "segments", context, PLANNER_CONFIG_AUTHORIZE_SYS_TABLES, authConfig);
testSysTable("SELECT * FROM sys.servers", "servers", context, PLANNER_CONFIG_AUTHORIZE_SYS_TABLES, authConfig);
testSysTable(
"SELECT * FROM sys.server_segments",
"server_segments",
context,
PLANNER_CONFIG_AUTHORIZE_SYS_TABLES,
authConfig
);
testSysTable("SELECT * FROM sys.tasks", "tasks", context, PLANNER_CONFIG_AUTHORIZE_SYS_TABLES, authConfig);
testSysTable(
"SELECT * FROM sys.supervisors",
"supervisors",
context,
PLANNER_CONFIG_AUTHORIZE_SYS_TABLES,
authConfig
);
}
@Test
public void testQueryContext()
{
final String sql = "SELECT COUNT(*) FROM foo WHERE foo.dim1 <> 'z'";
Set<ResourceAction> requiredResources = analyzeResources(
PLANNER_CONFIG_DEFAULT,
AuthConfig.newBuilder().setAuthorizeQueryContextParams(true).build(),
sql,
ImmutableMap.of("baz", "fo", "nested-bar", ImmutableMap.of("nested-key", "nested-val")),
CalciteTests.REGULAR_USER_AUTH_RESULT
);
Assert.assertEquals(
ImmutableSet.of(
new ResourceAction(new Resource("foo", ResourceType.DATASOURCE), Action.READ),
new ResourceAction(new Resource("baz", ResourceType.QUERY_CONTEXT), Action.WRITE),
new ResourceAction(new Resource("nested-bar", ResourceType.QUERY_CONTEXT), Action.WRITE)
),
requiredResources
);
}
}

View File

@ -34,6 +34,7 @@ import org.apache.druid.data.input.MapBasedRow;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.Parser;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.expression.TestExprMacroTable;
import org.apache.druid.query.filter.DimFilter;
import org.apache.druid.query.filter.ValueMatcher;
@ -85,7 +86,7 @@ class ExpressionTestHelper
NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class))
)
),
ImmutableMap.of()
new QueryContext()
);
private final RowSignature rowSignature;

View File

@ -24,6 +24,7 @@ import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.calcite.tools.ValidationException;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryRunnerFactoryConglomerate;
import org.apache.druid.query.QuerySegmentWalker;
import org.apache.druid.sql.calcite.planner.PlannerConfig;
@ -57,7 +58,7 @@ public class ExternalTableScanRuleTest
NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class))
)
),
ImmutableMap.of()
new QueryContext()
);
plannerContext.setQueryMaker(
CalciteTests.createMockQueryMakerFactory(

View File

@ -39,6 +39,7 @@ import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.QueryContext;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.expression.DirectOperatorConversion;
@ -94,7 +95,7 @@ public class DruidRexExecutorTest extends InitializedNullHandlingTest
NamedViewSchema.NAME, new NamedViewSchema(EasyMock.createMock(ViewSchema.class))
)
),
ImmutableMap.of()
new QueryContext()
);
private final RexBuilder rexBuilder = new RexBuilder(new JavaTypeFactoryImpl());

View File

@ -29,6 +29,7 @@ import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.druid.query.QueryContext;
import org.apache.druid.sql.calcite.planner.DruidTypeSystem;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.junit.Assert;
@ -66,6 +67,7 @@ public class DruidJoinRuleTest
public void setup()
{
PlannerContext plannerContext = Mockito.mock(PlannerContext.class);
Mockito.when(plannerContext.getQueryContext()).thenReturn(Mockito.mock(QueryContext.class));
druidJoinRule = DruidJoinRule.instance(plannerContext);
}

View File

@ -195,14 +195,23 @@ public class CalciteTests
return Access.OK;
}
if (ResourceType.DATASOURCE.equals(resource.getType()) && resource.getName().equals(FORBIDDEN_DATASOURCE)) {
return new Access(false);
} else if (ResourceType.VIEW.equals(resource.getType()) && resource.getName().equals("forbiddenView")) {
return new Access(false);
} else if (ResourceType.DATASOURCE.equals(resource.getType()) || ResourceType.VIEW.equals(resource.getType())) {
return Access.OK;
} else {
return new Access(false);
switch (resource.getType()) {
case ResourceType.DATASOURCE:
if (resource.getName().equals(FORBIDDEN_DATASOURCE)) {
return new Access(false);
} else {
return Access.OK;
}
case ResourceType.VIEW:
if (resource.getName().equals("forbiddenView")) {
return new Access(false);
} else {
return Access.OK;
}
case ResourceType.QUERY_CONTEXT:
return Access.OK;
default:
return new Access(false);
}
};
}
@ -822,12 +831,21 @@ public class CalciteTests
}
public static SqlLifecycleFactory createSqlLifecycleFactory(final PlannerFactory plannerFactory)
{
return createSqlLifecycleFactory(plannerFactory, new AuthConfig());
}
public static SqlLifecycleFactory createSqlLifecycleFactory(
final PlannerFactory plannerFactory,
final AuthConfig authConfig
)
{
return new SqlLifecycleFactory(
plannerFactory,
new ServiceEmitter("dummy", "dummy", new NoopEmitter()),
new NoopRequestLogger(),
QueryStackTests.DEFAULT_NOOP_SCHEDULER
QueryStackTests.DEFAULT_NOOP_SCHEDULER,
authConfig
);
}

View File

@ -254,11 +254,13 @@ public class SqlResourceTest extends CalciteTestBase
}
};
final ServiceEmitter emitter = new NoopServiceEmitter();
final AuthConfig authConfig = new AuthConfig();
sqlLifecycleFactory = new SqlLifecycleFactory(
plannerFactory,
emitter,
testRequestLogger,
scheduler
scheduler,
authConfig
)
{
@Override
@ -269,6 +271,7 @@ public class SqlResourceTest extends CalciteTestBase
emitter,
testRequestLogger,
scheduler,
authConfig,
System.currentTimeMillis(),
System.nanoTime(),
validateAndAuthorizeLatchSupplier,
@ -1764,6 +1767,7 @@ public class SqlResourceTest extends CalciteTestBase
ServiceEmitter emitter,
RequestLogger requestLogger,
QueryScheduler queryScheduler,
AuthConfig authConfig,
long startMs,
long startNs,
SettableSupplier<NonnullPair<CountDownLatch, Boolean>> validateAndAuthorizeLatchSupplier,
@ -1772,7 +1776,7 @@ public class SqlResourceTest extends CalciteTestBase
SettableSupplier<Function<Sequence<Object[]>, Sequence<Object[]>>> sequenceMapFnSupplier
)
{
super(plannerFactory, emitter, requestLogger, queryScheduler, startMs, startNs);
super(plannerFactory, emitter, requestLogger, queryScheduler, authConfig, startMs, startNs);
this.validateAndAuthorizeLatchSupplier = validateAndAuthorizeLatchSupplier;
this.planLatchSupplier = planLatchSupplier;
this.executeLatchSupplier = executeLatchSupplier;