From b36dab0fe65557e4aaf675423d61bdaa51501a71 Mon Sep 17 00:00:00 2001 From: Clint Wylie Date: Wed, 19 Aug 2020 00:18:06 -0700 Subject: [PATCH] fix connectionId issue with JDBC prepared statement queries and router (#10272) * fix router jdbc prepared statement connectionId issue * column metadata too * style * remove tls * try tls again * add keystore stuffs * use keyManager password * add unit test * simplify --- .../druid/tests/query/ITJdbcQueryTest.java | 215 ++++++++++++++++++ server/pom.xml | 5 + .../server/AsyncQueryForwardingServlet.java | 22 +- .../AsyncQueryForwardingServletTest.java | 51 +++++ 4 files changed, 289 insertions(+), 4 deletions(-) create mode 100644 integration-tests/src/test/java/org/apache/druid/tests/query/ITJdbcQueryTest.java diff --git a/integration-tests/src/test/java/org/apache/druid/tests/query/ITJdbcQueryTest.java b/integration-tests/src/test/java/org/apache/druid/tests/query/ITJdbcQueryTest.java new file mode 100644 index 00000000000..e6b2b4762b3 --- /dev/null +++ b/integration-tests/src/test/java/org/apache/druid/tests/query/ITJdbcQueryTest.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.tests.query; + +import com.google.common.collect.ImmutableList; +import com.google.inject.Inject; +import org.apache.druid.https.SSLClientConfig; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.testing.IntegrationTestingConfig; +import org.apache.druid.testing.clients.CoordinatorResourceTestClient; +import org.apache.druid.testing.guice.DruidTestModuleFactory; +import org.apache.druid.testing.utils.ITRetryUtil; +import org.apache.druid.tests.TestNGGroup; +import org.testng.Assert; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Guice; +import org.testng.annotations.Test; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.DriverManager; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +@Test(groups = TestNGGroup.QUERY) +@Guice(moduleFactory = DruidTestModuleFactory.class) +public class ITJdbcQueryTest +{ + private static final Logger LOG = new Logger(ITJdbcQueryTest.class); + private static final String WIKIPEDIA_DATA_SOURCE = "wikipedia_editstream"; + private static final String CONNECTION_TEMPLATE = "jdbc:avatica:remote:url=%s/druid/v2/sql/avatica/"; + private static final String TLS_CONNECTION_TEMPLATE = + "jdbc:avatica:remote:url=%s/druid/v2/sql/avatica/;truststore=%s;truststore_password=%s;keystore=%s;keystore_password=%s;key_password=%s"; + + private static final String QUERY_TEMPLATE = + "SELECT \"user\", SUM(\"added\"), COUNT(*)" + + "FROM \"wikipedia\" " + + "WHERE \"__time\" >= CURRENT_TIMESTAMP - INTERVAL '10' YEAR AND \"language\" = %s" + + "GROUP BY 1 ORDER BY 3 DESC LIMIT 10"; + private static final String QUERY = StringUtils.format(QUERY_TEMPLATE, "'en'"); + + private static final String QUERY_PARAMETERIZED = StringUtils.format(QUERY_TEMPLATE, "?"); + + private String[] connections; + private Properties connectionProperties; + + @Inject + private IntegrationTestingConfig config; + + @Inject + SSLClientConfig sslConfig; + + @Inject + private CoordinatorResourceTestClient coordinatorClient; + + @BeforeMethod + public void before() + { + connectionProperties = new Properties(); + connectionProperties.setProperty("user", "admin"); + connectionProperties.setProperty("password", "priest"); + connections = new String[]{ + StringUtils.format(CONNECTION_TEMPLATE, config.getRouterUrl()), + StringUtils.format(CONNECTION_TEMPLATE, config.getBrokerUrl()), + StringUtils.format( + TLS_CONNECTION_TEMPLATE, + config.getRouterTLSUrl(), + sslConfig.getTrustStorePath(), + sslConfig.getTrustStorePasswordProvider().getPassword(), + sslConfig.getKeyStorePath(), + sslConfig.getKeyStorePasswordProvider().getPassword(), + sslConfig.getKeyManagerPasswordProvider().getPassword() + ), + StringUtils.format( + TLS_CONNECTION_TEMPLATE, + config.getBrokerTLSUrl(), + sslConfig.getTrustStorePath(), + sslConfig.getTrustStorePasswordProvider().getPassword(), + sslConfig.getKeyStorePath(), + sslConfig.getKeyStorePasswordProvider().getPassword(), + sslConfig.getKeyManagerPasswordProvider().getPassword() + ) + }; + // ensure that wikipedia segments are loaded completely + ITRetryUtil.retryUntilTrue( + () -> coordinatorClient.areSegmentsLoaded(WIKIPEDIA_DATA_SOURCE), "wikipedia segment load" + ); + } + + @Test + public void testJdbcMetadata() + { + for (String url : connections) { + try (Connection connection = DriverManager.getConnection(url, connectionProperties)) { + DatabaseMetaData metadata = connection.getMetaData(); + + List catalogs = new ArrayList<>(); + ResultSet catalogsMetadata = metadata.getCatalogs(); + while (catalogsMetadata.next()) { + final String catalog = catalogsMetadata.getString(1); + catalogs.add(catalog); + } + LOG.info("catalogs %s", catalogs); + Assert.assertEquals(ImmutableList.of("druid"), catalogs); + + Set schemas = new HashSet<>(); + ResultSet schemasMetadata = metadata.getSchemas("druid", null); + while (schemasMetadata.next()) { + final String schema = schemasMetadata.getString(1); + schemas.add(schema); + } + LOG.info("'druid' catalog schemas %s", schemas); + // maybe more schemas than this, but at least should have these + Assert.assertTrue(schemas.containsAll(ImmutableList.of("INFORMATION_SCHEMA", "druid", "lookup", "sys"))); + + Set druidTables = new HashSet<>(); + ResultSet tablesMetadata = metadata.getTables("druid", "druid", null, null); + while (tablesMetadata.next()) { + final String table = tablesMetadata.getString(3); + druidTables.add(table); + } + LOG.info("'druid' schema tables %s", druidTables); + // maybe more tables than this, but at least should have these + Assert.assertTrue( + druidTables.containsAll(ImmutableList.of("twitterstream", "wikipedia", WIKIPEDIA_DATA_SOURCE)) + ); + + Set wikiColumns = new HashSet<>(); + ResultSet columnsMetadata = metadata.getColumns("druid", "druid", WIKIPEDIA_DATA_SOURCE, null); + while (columnsMetadata.next()) { + final String column = columnsMetadata.getString(4); + wikiColumns.add(column); + } + LOG.info("'%s' columns %s", WIKIPEDIA_DATA_SOURCE, wikiColumns); + // a lot more columns than this, but at least should have these + Assert.assertTrue( + wikiColumns.containsAll(ImmutableList.of("added", "city", "delta", "language")) + ); + } + catch (SQLException throwables) { + Assert.assertFalse(true, throwables.getMessage()); + } + } + } + + @Test + public void testJdbcStatementQuery() + { + for (String url : connections) { + try (Connection connection = DriverManager.getConnection(url, connectionProperties)) { + try (Statement statement = connection.createStatement()) { + final ResultSet resultSet = statement.executeQuery(QUERY); + int resultRowCount = 0; + while (resultSet.next()) { + resultRowCount++; + LOG.info("%s,%s,%s", resultSet.getString(1), resultSet.getLong(2), resultSet.getLong(3)); + } + Assert.assertEquals(10, resultRowCount); + resultSet.close(); + } + } + catch (SQLException throwables) { + Assert.assertFalse(true, throwables.getMessage()); + } + } + } + + @Test + public void testJdbcPrepareStatementQuery() + { + for (String url : connections) { + try (Connection connection = DriverManager.getConnection(url, connectionProperties)) { + try (PreparedStatement statement = connection.prepareStatement(QUERY_PARAMETERIZED)) { + statement.setString(1, "en"); + final ResultSet resultSet = statement.executeQuery(); + int resultRowCount = 0; + while (resultSet.next()) { + resultRowCount++; + LOG.info("%s,%s,%s", resultSet.getString(1), resultSet.getLong(2), resultSet.getLong(3)); + } + Assert.assertEquals(10, resultRowCount); + resultSet.close(); + } + } + catch (SQLException throwables) { + Assert.assertFalse(true, throwables.getMessage()); + } + } + } +} diff --git a/server/pom.xml b/server/pom.xml index b23bea12049..55272e650fb 100644 --- a/server/pom.xml +++ b/server/pom.xml @@ -437,6 +437,11 @@ 1.3 test + + org.apache.calcite.avatica + avatica-core + test + diff --git a/server/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java b/server/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java index 1c119f574e4..4ecb86ceb99 100644 --- a/server/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java +++ b/server/src/main/java/org/apache/druid/server/AsyncQueryForwardingServlet.java @@ -78,6 +78,9 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu @Deprecated // use SmileMediaTypes.APPLICATION_JACKSON_SMILE private static final String APPLICATION_SMILE = "application/smile"; + private static final String AVATICA_CONNECTION_ID = "connectionId"; + private static final String AVATICA_STATEMENT_HANDLE = "statementHandle"; + private static final String HOST_ATTRIBUTE = "org.apache.druid.proxy.to.host"; private static final String SCHEME_ATTRIBUTE = "org.apache.druid.proxy.to.host.scheme"; private static final String QUERY_ATTRIBUTE = "org.apache.druid.proxy.query"; @@ -422,14 +425,25 @@ public class AsyncQueryForwardingServlet extends AsyncProxyServlet implements Qu return interruptedQueryCount.get(); } - private static String getAvaticaConnectionId(Map requestMap) + @VisibleForTesting + static String getAvaticaConnectionId(Map requestMap) { - Object connectionIdObj = requestMap.get("connectionId"); + // avatica commands always have a 'connectionId'. If commands are not part of a prepared statement, this appears at + // the top level of the request, but if it is part of a statement, then it will be nested in the 'statementHandle'. + // see https://calcite.apache.org/avatica/docs/json_reference.html#requests for more details + Object connectionIdObj = requestMap.get(AVATICA_CONNECTION_ID); if (connectionIdObj == null) { - throw new IAE("Received an Avatica request without a connectionId."); + Object statementHandle = requestMap.get(AVATICA_STATEMENT_HANDLE); + if (statementHandle != null && statementHandle instanceof Map) { + connectionIdObj = ((Map) statementHandle).get(AVATICA_CONNECTION_ID); + } + } + + if (connectionIdObj == null) { + throw new IAE("Received an Avatica request without a %s.", AVATICA_CONNECTION_ID); } if (!(connectionIdObj instanceof String)) { - throw new IAE("Received an Avatica request with a non-String connectionId."); + throw new IAE("Received an Avatica request with a non-String %s.", AVATICA_CONNECTION_ID); } return (String) connectionIdObj; diff --git a/server/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java b/server/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java index ab2a322fcf9..066088c0eac 100644 --- a/server/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java +++ b/server/src/test/java/org/apache/druid/server/AsyncQueryForwardingServletTest.java @@ -19,6 +19,7 @@ package org.apache.druid.server; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -29,6 +30,8 @@ import com.google.inject.Injector; import com.google.inject.Key; import com.google.inject.Module; import com.google.inject.servlet.GuiceFilter; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.remote.Service; import org.apache.druid.common.utils.SocketUtil; import org.apache.druid.guice.GuiceInjectors; import org.apache.druid.guice.Jerseys; @@ -41,6 +44,7 @@ import org.apache.druid.guice.http.DruidHttpClientConfig; import org.apache.druid.initialization.Initialization; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; +import org.apache.druid.java.util.common.jackson.JacksonUtils; import org.apache.druid.java.util.common.lifecycle.Lifecycle; import org.apache.druid.query.DefaultGenericQueryMetricsFactory; import org.apache.druid.query.Druids; @@ -83,6 +87,8 @@ import java.net.HttpURLConnection; import java.net.URI; import java.net.URL; import java.util.Collection; +import java.util.List; +import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import java.util.zip.Deflater; @@ -422,6 +428,51 @@ public class AsyncQueryForwardingServletTest extends BaseJettyTest ); } + @Test + public void testGetAvaticaConnectionId() throws JsonProcessingException + { + final ObjectMapper mapper = new ObjectMapper(); + final String query = "SELECT someColumn FROM druid.someTable WHERE someColumn IS NOT NULL"; + final String connectionId = "000000-0000-0000-00000000"; + final int statementId = 1337; + final int maxNumRows = 1000; + + final List jsonRequests = ImmutableList.of( + new Service.CatalogsRequest(connectionId), + new Service.SchemasRequest(connectionId, "druid", null), + new Service.TablesRequest(connectionId, "druid", "druid", null, null), + new Service.ColumnsRequest(connectionId, "druid", "druid", "someTable", null), + new Service.PrepareAndExecuteRequest( + connectionId, + statementId, + query, + maxNumRows + ), + new Service.PrepareRequest(connectionId, query, maxNumRows), + new Service.ExecuteRequest( + new Meta.StatementHandle(connectionId, statementId, null), + ImmutableList.of(), + maxNumRows + ), + new Service.CloseStatementRequest(connectionId, statementId), + new Service.CloseConnectionRequest(connectionId) + ); + + for (Service.Request request : jsonRequests) { + final String json = mapper.writeValueAsString(request); + Assert.assertEquals( + StringUtils.format("Failed %s", json), + connectionId, + AsyncQueryForwardingServlet.getAvaticaConnectionId(asMap(json, mapper)) + ); + } + } + + private static Map asMap(String json, ObjectMapper mapper) throws JsonProcessingException + { + return mapper.readValue(json, JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT); + } + private static class TestServer implements org.apache.druid.client.selector.Server {