fix connectionId issue with JDBC prepared statement queries and router (#10272)

* fix router jdbc prepared statement connectionId issue

* column metadata too

* style

* remove tls

* try tls again

* add keystore stuffs

* use keyManager password

* add unit test

* simplify
This commit is contained in:
Clint Wylie 2020-08-19 00:18:06 -07:00 committed by GitHub
parent 9a81740281
commit b36dab0fe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 289 additions and 4 deletions

View File

@ -0,0 +1,215 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.tests.query;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import org.apache.druid.https.SSLClientConfig;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.testing.IntegrationTestingConfig;
import org.apache.druid.testing.clients.CoordinatorResourceTestClient;
import org.apache.druid.testing.guice.DruidTestModuleFactory;
import org.apache.druid.testing.utils.ITRetryUtil;
import org.apache.druid.tests.TestNGGroup;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Guice;
import org.testng.annotations.Test;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;
@Test(groups = TestNGGroup.QUERY)
@Guice(moduleFactory = DruidTestModuleFactory.class)
public class ITJdbcQueryTest
{
private static final Logger LOG = new Logger(ITJdbcQueryTest.class);
private static final String WIKIPEDIA_DATA_SOURCE = "wikipedia_editstream";
private static final String CONNECTION_TEMPLATE = "jdbc:avatica:remote:url=%s/druid/v2/sql/avatica/";
private static final String TLS_CONNECTION_TEMPLATE =
"jdbc:avatica:remote:url=%s/druid/v2/sql/avatica/;truststore=%s;truststore_password=%s;keystore=%s;keystore_password=%s;key_password=%s";
private static final String QUERY_TEMPLATE =
"SELECT \"user\", SUM(\"added\"), COUNT(*)" +
"FROM \"wikipedia\" " +
"WHERE \"__time\" >= CURRENT_TIMESTAMP - INTERVAL '10' YEAR AND \"language\" = %s" +
"GROUP BY 1 ORDER BY 3 DESC LIMIT 10";
private static final String QUERY = StringUtils.format(QUERY_TEMPLATE, "'en'");
private static final String QUERY_PARAMETERIZED = StringUtils.format(QUERY_TEMPLATE, "?");
private String[] connections;
private Properties connectionProperties;
@Inject
private IntegrationTestingConfig config;
@Inject
SSLClientConfig sslConfig;
@Inject
private CoordinatorResourceTestClient coordinatorClient;
@BeforeMethod
public void before()
{
connectionProperties = new Properties();
connectionProperties.setProperty("user", "admin");
connectionProperties.setProperty("password", "priest");
connections = new String[]{
StringUtils.format(CONNECTION_TEMPLATE, config.getRouterUrl()),
StringUtils.format(CONNECTION_TEMPLATE, config.getBrokerUrl()),
StringUtils.format(
TLS_CONNECTION_TEMPLATE,
config.getRouterTLSUrl(),
sslConfig.getTrustStorePath(),
sslConfig.getTrustStorePasswordProvider().getPassword(),
sslConfig.getKeyStorePath(),
sslConfig.getKeyStorePasswordProvider().getPassword(),
sslConfig.getKeyManagerPasswordProvider().getPassword()
),
StringUtils.format(
TLS_CONNECTION_TEMPLATE,
config.getBrokerTLSUrl(),
sslConfig.getTrustStorePath(),
sslConfig.getTrustStorePasswordProvider().getPassword(),
sslConfig.getKeyStorePath(),
sslConfig.getKeyStorePasswordProvider().getPassword(),
sslConfig.getKeyManagerPasswordProvider().getPassword()
)
};
// ensure that wikipedia segments are loaded completely
ITRetryUtil.retryUntilTrue(
() -> coordinatorClient.areSegmentsLoaded(WIKIPEDIA_DATA_SOURCE), "wikipedia segment load"
);
}
@Test
public void testJdbcMetadata()
{
for (String url : connections) {
try (Connection connection = DriverManager.getConnection(url, connectionProperties)) {
DatabaseMetaData metadata = connection.getMetaData();
List<String> catalogs = new ArrayList<>();
ResultSet catalogsMetadata = metadata.getCatalogs();
while (catalogsMetadata.next()) {
final String catalog = catalogsMetadata.getString(1);
catalogs.add(catalog);
}
LOG.info("catalogs %s", catalogs);
Assert.assertEquals(ImmutableList.of("druid"), catalogs);
Set<String> schemas = new HashSet<>();
ResultSet schemasMetadata = metadata.getSchemas("druid", null);
while (schemasMetadata.next()) {
final String schema = schemasMetadata.getString(1);
schemas.add(schema);
}
LOG.info("'druid' catalog schemas %s", schemas);
// maybe more schemas than this, but at least should have these
Assert.assertTrue(schemas.containsAll(ImmutableList.of("INFORMATION_SCHEMA", "druid", "lookup", "sys")));
Set<String> druidTables = new HashSet<>();
ResultSet tablesMetadata = metadata.getTables("druid", "druid", null, null);
while (tablesMetadata.next()) {
final String table = tablesMetadata.getString(3);
druidTables.add(table);
}
LOG.info("'druid' schema tables %s", druidTables);
// maybe more tables than this, but at least should have these
Assert.assertTrue(
druidTables.containsAll(ImmutableList.of("twitterstream", "wikipedia", WIKIPEDIA_DATA_SOURCE))
);
Set<String> wikiColumns = new HashSet<>();
ResultSet columnsMetadata = metadata.getColumns("druid", "druid", WIKIPEDIA_DATA_SOURCE, null);
while (columnsMetadata.next()) {
final String column = columnsMetadata.getString(4);
wikiColumns.add(column);
}
LOG.info("'%s' columns %s", WIKIPEDIA_DATA_SOURCE, wikiColumns);
// a lot more columns than this, but at least should have these
Assert.assertTrue(
wikiColumns.containsAll(ImmutableList.of("added", "city", "delta", "language"))
);
}
catch (SQLException throwables) {
Assert.assertFalse(true, throwables.getMessage());
}
}
}
@Test
public void testJdbcStatementQuery()
{
for (String url : connections) {
try (Connection connection = DriverManager.getConnection(url, connectionProperties)) {
try (Statement statement = connection.createStatement()) {
final ResultSet resultSet = statement.executeQuery(QUERY);
int resultRowCount = 0;
while (resultSet.next()) {
resultRowCount++;
LOG.info("%s,%s,%s", resultSet.getString(1), resultSet.getLong(2), resultSet.getLong(3));
}
Assert.assertEquals(10, resultRowCount);
resultSet.close();
}
}
catch (SQLException throwables) {
Assert.assertFalse(true, throwables.getMessage());
}
}
}
@Test
public void testJdbcPrepareStatementQuery()
{
for (String url : connections) {
try (Connection connection = DriverManager.getConnection(url, connectionProperties)) {
try (PreparedStatement statement = connection.prepareStatement(QUERY_PARAMETERIZED)) {
statement.setString(1, "en");
final ResultSet resultSet = statement.executeQuery();
int resultRowCount = 0;
while (resultSet.next()) {
resultRowCount++;
LOG.info("%s,%s,%s", resultSet.getString(1), resultSet.getLong(2), resultSet.getLong(3));
}
Assert.assertEquals(10, resultRowCount);
resultSet.close();
}
}
catch (SQLException throwables) {
Assert.assertFalse(true, throwables.getMessage());
}
}
}
}

View File

@ -437,6 +437,11 @@
<version>1.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.calcite.avatica</groupId>
<artifactId>avatica-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>

View File

@ -78,6 +78,9 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
@Deprecated // use SmileMediaTypes.APPLICATION_JACKSON_SMILE
private static final String APPLICATION_SMILE = "application/smile";
private static final String AVATICA_CONNECTION_ID = "connectionId";
private static final String AVATICA_STATEMENT_HANDLE = "statementHandle";
private static final String HOST_ATTRIBUTE = "org.apache.druid.proxy.to.host";
private static final String SCHEME_ATTRIBUTE = "org.apache.druid.proxy.to.host.scheme";
private static final String QUERY_ATTRIBUTE = "org.apache.druid.proxy.query";
@ -422,14 +425,25 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu
return interruptedQueryCount.get();
}
private static String getAvaticaConnectionId(Map<String, Object> requestMap)
@VisibleForTesting
static String getAvaticaConnectionId(Map<String, Object> requestMap)
{
Object connectionIdObj = requestMap.get("connectionId");
// avatica commands always have a 'connectionId'. If commands are not part of a prepared statement, this appears at
// the top level of the request, but if it is part of a statement, then it will be nested in the 'statementHandle'.
// see https://calcite.apache.org/avatica/docs/json_reference.html#requests for more details
Object connectionIdObj = requestMap.get(AVATICA_CONNECTION_ID);
if (connectionIdObj == null) {
throw new IAE("Received an Avatica request without a connectionId.");
Object statementHandle = requestMap.get(AVATICA_STATEMENT_HANDLE);
if (statementHandle != null && statementHandle instanceof Map) {
connectionIdObj = ((Map) statementHandle).get(AVATICA_CONNECTION_ID);
}
}
if (connectionIdObj == null) {
throw new IAE("Received an Avatica request without a %s.", AVATICA_CONNECTION_ID);
}
if (!(connectionIdObj instanceof String)) {
throw new IAE("Received an Avatica request with a non-String connectionId.");
throw new IAE("Received an Avatica request with a non-String %s.", AVATICA_CONNECTION_ID);
}
return (String) connectionIdObj;

View File

@ -19,6 +19,7 @@
package org.apache.druid.server;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
@ -29,6 +30,8 @@ import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.servlet.GuiceFilter;
import org.apache.calcite.avatica.Meta;
import org.apache.calcite.avatica.remote.Service;
import org.apache.druid.common.utils.SocketUtil;
import org.apache.druid.guice.GuiceInjectors;
import org.apache.druid.guice.Jerseys;
@ -41,6 +44,7 @@ import org.apache.druid.guice.http.DruidHttpClientConfig;
import org.apache.druid.initialization.Initialization;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.jackson.JacksonUtils;
import org.apache.druid.java.util.common.lifecycle.Lifecycle;
import org.apache.druid.query.DefaultGenericQueryMetricsFactory;
import org.apache.druid.query.Druids;
@ -83,6 +87,8 @@ import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
import java.util.zip.Deflater;
@ -422,6 +428,51 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest
);
}
@Test
public void testGetAvaticaConnectionId() throws JsonProcessingException
{
final ObjectMapper mapper = new ObjectMapper();
final String query = "SELECT someColumn FROM druid.someTable WHERE someColumn IS NOT NULL";
final String connectionId = "000000-0000-0000-00000000";
final int statementId = 1337;
final int maxNumRows = 1000;
final List<? extends Service.Request> jsonRequests = ImmutableList.of(
new Service.CatalogsRequest(connectionId),
new Service.SchemasRequest(connectionId, "druid", null),
new Service.TablesRequest(connectionId, "druid", "druid", null, null),
new Service.ColumnsRequest(connectionId, "druid", "druid", "someTable", null),
new Service.PrepareAndExecuteRequest(
connectionId,
statementId,
query,
maxNumRows
),
new Service.PrepareRequest(connectionId, query, maxNumRows),
new Service.ExecuteRequest(
new Meta.StatementHandle(connectionId, statementId, null),
ImmutableList.of(),
maxNumRows
),
new Service.CloseStatementRequest(connectionId, statementId),
new Service.CloseConnectionRequest(connectionId)
);
for (Service.Request request : jsonRequests) {
final String json = mapper.writeValueAsString(request);
Assert.assertEquals(
StringUtils.format("Failed %s", json),
connectionId,
AsyncQueryForwardingServlet.getAvaticaConnectionId(asMap(json, mapper))
);
}
}
private static Map<String, Object> asMap(String json, ObjectMapper mapper) throws JsonProcessingException
{
return mapper.readValue(json, JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT);
}
private static class TestServer implements org.apache.druid.client.selector.Server
{