Mask SQL String in the MSQTaskQueryMaker for secrets (#13231)

* add test

* add masking code

* fix test

* oops

* refactor json usage

* refactor, variable update

* add test cases

* Trigger Build

* add comment to the regex

* address review comment
This commit is contained in:
Laksh Singla 2022-11-03 15:27:28 +05:30 committed by GitHub
parent ae638e338c
commit ccc55ef899
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 280 additions and 75 deletions

View File

@ -41,6 +41,7 @@ import org.apache.druid.msq.indexing.MSQDestination;
import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.TaskReportMSQDestination; import org.apache.druid.msq.indexing.TaskReportMSQDestination;
import org.apache.druid.msq.util.MSQTaskQueryMakerUtils;
import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
@ -48,7 +49,6 @@ import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.rpc.indexing.OverlordClient; import org.apache.druid.rpc.indexing.OverlordClient;
import org.apache.druid.segment.DimensionHandlerUtils; import org.apache.druid.segment.DimensionHandlerUtils;
import org.apache.druid.segment.IndexSpec; import org.apache.druid.segment.IndexSpec;
import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.server.QueryResponse; import org.apache.druid.server.QueryResponse;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
@ -63,14 +63,11 @@ import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public class MSQTaskQueryMaker implements QueryMaker public class MSQTaskQueryMaker implements QueryMaker
@ -92,6 +89,7 @@ public class MSQTaskQueryMaker implements QueryMaker
private final ObjectMapper jsonMapper; private final ObjectMapper jsonMapper;
private final List<Pair<Integer, String>> fieldMapping; private final List<Pair<Integer, String>> fieldMapping;
MSQTaskQueryMaker( MSQTaskQueryMaker(
@Nullable final String targetDataSource, @Nullable final String targetDataSource,
final OverlordClient overlordClient, final OverlordClient overlordClient,
@ -220,7 +218,7 @@ public class MSQTaskQueryMaker implements QueryMaker
final List<String> segmentSortOrder = MultiStageQueryContext.getSortOrder(queryContext); final List<String> segmentSortOrder = MultiStageQueryContext.getSortOrder(queryContext);
validateSegmentSortOrder( MSQTaskQueryMakerUtils.validateSegmentSortOrder(
segmentSortOrder, segmentSortOrder,
fieldMapping.stream().map(f -> f.right).collect(Collectors.toList()) fieldMapping.stream().map(f -> f.right).collect(Collectors.toList())
); );
@ -256,7 +254,7 @@ public class MSQTaskQueryMaker implements QueryMaker
final MSQControllerTask controllerTask = new MSQControllerTask( final MSQControllerTask controllerTask = new MSQControllerTask(
taskId, taskId,
querySpec, querySpec,
plannerContext.getSql(), MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(plannerContext.getSql()),
plannerContext.queryContextMap(), plannerContext.queryContextMap(),
sqlTypeNames, sqlTypeNames,
null null
@ -283,20 +281,4 @@ public class MSQTaskQueryMaker implements QueryMaker
return retVal; return retVal;
} }
static void validateSegmentSortOrder(final List<String> sortOrder, final Collection<String> allOutputColumns)
{
final Set<String> allOutputColumnsSet = new HashSet<>(allOutputColumns);
for (final String column : sortOrder) {
if (!allOutputColumnsSet.contains(column)) {
throw new IAE("Column [%s] in segment sort order does not appear in the query output", column);
}
}
if (sortOrder.size() > 0
&& allOutputColumns.contains(ColumnHolder.TIME_COLUMN_NAME)
&& !ColumnHolder.TIME_COLUMN_NAME.equals(sortOrder.get(0))) {
throw new IAE("Segment sort order must begin with column [%s]", ColumnHolder.TIME_COLUMN_NAME);
}
}
} }

View File

@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.util;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ColumnHolder;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
public class MSQTaskQueryMakerUtils
{
public static final Set<String> SENSISTIVE_JSON_KEYS = ImmutableSet.of("accessKeyId", "secretAccessKey");
public static final Set<Pattern> SENSITIVE_KEYS_REGEX_PATTERNS = SENSISTIVE_JSON_KEYS.stream()
.map(sensitiveKey ->
Pattern.compile(
StringUtils.format(
"\\\\\"%s\\\\\"(\\s)*:(\\s)*(?<sensitive>\\{(\\s)*(\\S)+?(\\s)*\\})",
sensitiveKey
),
Pattern.CASE_INSENSITIVE
))
.collect(Collectors.toSet());
/**
* This method masks the sensitive json keys that might be present in the SQL query matching the regex
* {@code key(\s)+:(\s)+{sensitive_data}}
* The regex pattern matches a json entry of form "key":{value} and replaces it with "key":\<masked\>
* It checks the sensitive keys for the match, greedily matches the first occuring brace pair ("{" and "}")
* into a regex group named "sensitive" and performs a string replace on the group. The whitespaces are accounted
* for in the regex.
*/
public static String maskSensitiveJsonKeys(String sqlQuery)
{
StringBuilder maskedSqlQuery = new StringBuilder(sqlQuery);
for (Pattern p : SENSITIVE_KEYS_REGEX_PATTERNS) {
Matcher m = p.matcher(sqlQuery);
while (m.find()) {
String sensitiveData = m.group("sensitive");
int start = maskedSqlQuery.indexOf(sensitiveData);
int end = start + sensitiveData.length();
maskedSqlQuery.replace(start, end, "<masked>");
}
}
return maskedSqlQuery.toString();
}
/**
* Validates if each element of the sort order appears in the final output and if it is not empty then it starts with the
* __time column
*/
public static void validateSegmentSortOrder(final List<String> sortOrder, final Collection<String> allOutputColumns)
{
final Set<String> allOutputColumnsSet = new HashSet<>(allOutputColumns);
for (final String column : sortOrder) {
if (!allOutputColumnsSet.contains(column)) {
throw new IAE("Column [%s] in segment sort order does not appear in the query output", column);
}
}
if (sortOrder.size() > 0
&& allOutputColumns.contains(ColumnHolder.TIME_COLUMN_NAME)
&& !ColumnHolder.TIME_COLUMN_NAME.equals(sortOrder.get(0))) {
throw new IAE("Segment sort order must begin with column [%s]", ColumnHolder.TIME_COLUMN_NAME);
}
}
}

View File

@ -1,53 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.sql;
import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
public class MSQTaskQueryMakerTest
{
@Test
public void testValidateSegmentSortOrder()
{
// These are all OK, so validateSegmentSortOrder does nothing.
MSQTaskQueryMaker.validateSegmentSortOrder(Collections.emptyList(), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMaker.validateSegmentSortOrder(ImmutableList.of("__time"), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMaker.validateSegmentSortOrder(ImmutableList.of("__time", "b"), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMaker.validateSegmentSortOrder(ImmutableList.of("b"), ImmutableList.of("a", "b"));
// These are not OK.
Assert.assertThrows(
IllegalArgumentException.class,
() -> MSQTaskQueryMaker.validateSegmentSortOrder(ImmutableList.of("c"), ImmutableList.of("a", "b"))
);
Assert.assertThrows(
IllegalArgumentException.class,
() -> MSQTaskQueryMaker.validateSegmentSortOrder(
ImmutableList.of("b", "__time"),
ImmutableList.of("__time", "a", "b")
)
);
}
}

View File

@ -0,0 +1,183 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.util;
import com.google.common.collect.ImmutableList;
import org.junit.Assert;
import org.junit.Test;
import java.util.Collections;
public class MSQTaskQueryMakerUtilsTest
{
@Test
public void testValidateSegmentSortOrder()
{
// These are all OK, so validateSegmentSortOrder does nothing.
MSQTaskQueryMakerUtils.validateSegmentSortOrder(Collections.emptyList(), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMakerUtils.validateSegmentSortOrder(ImmutableList.of("__time"), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMakerUtils.validateSegmentSortOrder(ImmutableList.of("__time", "b"), ImmutableList.of("__time", "a", "b"));
MSQTaskQueryMakerUtils.validateSegmentSortOrder(ImmutableList.of("b"), ImmutableList.of("a", "b"));
// These are not OK.
Assert.assertThrows(
IllegalArgumentException.class,
() -> MSQTaskQueryMakerUtils.validateSegmentSortOrder(ImmutableList.of("c"), ImmutableList.of("a", "b"))
);
Assert.assertThrows(
IllegalArgumentException.class,
() -> MSQTaskQueryMakerUtils.validateSegmentSortOrder(
ImmutableList.of("b", "__time"),
ImmutableList.of("__time", "a", "b")
)
);
}
@Test
public void maskSensitiveJsonKeys()
{
String sql1 = "\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"},\\\"secretAccessKey\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"}}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"";
String sql2 = "\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\" :{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"},\\\"secretAccessKey\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"}}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"";
String sql3 = "\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\": {\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"},\\\"secretAccessKey\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"}}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"";
String sql4 = "\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":{ \\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"},\\\"secretAccessKey\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"}}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"";
String sql5 = "\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\" },\\\"secretAccessKey\\\":{\\\"type\\\":\\\"default\\\",\\\"password\\\":\\\"secret_pass\\\"}}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"";
Assert.assertEquals(
"\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":<masked>,\\\"secretAccessKey\\\":<masked>}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"",
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(sql1)
);
Assert.assertEquals(
"\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\" :<masked>,\\\"secretAccessKey\\\":<masked>}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"",
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(sql2)
);
Assert.assertEquals(
"\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\": <masked>,\\\"secretAccessKey\\\":<masked>}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"",
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(sql3)
);
Assert.assertEquals(
"\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":<masked>,\\\"secretAccessKey\\\":<masked>}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"",
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(sql4)
);
Assert.assertEquals(
"\"REPLACE INTO table "
+ "OVERWRITE ALL\\n"
+ "WITH ext AS "
+ "(SELECT *\\nFROM TABLE(\\n "
+ "EXTERN(\\n '{\\\"type\\\":\\\"s3\\\",\\\"prefixes\\\":[\\\"s3://prefix\\\"],\\\"properties\\\":{\\\"accessKeyId\\\":<masked>,\\\"secretAccessKey\\\":<masked>}}',\\n"
+ "'{\\\"type\\\":\\\"json\\\"}',\\n"
+ "'[{\\\"name\\\":\\\"time\\\",\\\"type\\\":\\\"string\\\"},{\\\"name\\\":\\\"name\\\",\\\"type\\\":\\\"string\\\"}]'\\n )\\n))\\n"
+ "SELECT\\n TIME_PARSE(\\\"time\\\") AS __time,\\n name,\\n country "
+ "FROM ext\\n"
+ "PARTITIONED BY DAY\"",
MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(sql5)
);
}
}