From 81074411a99aadd8a899ceae7d8d12f2de347ac6 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Wed, 12 Apr 2023 22:39:39 -0700 Subject: [PATCH] MSQ: Support multiple result columns with the same name. (#14025) * MSQ: Support multiple result columns with the same name. This is allowed in SQL, and is supported by the regular SQL endpoint. We retain a validation that INSERT ... SELECT does not allow multiple columns with the same name, because column names in segments must be unique. --- .../apache/druid/msq/exec/ControllerImpl.java | 100 ++++++++----- .../druid/msq/indexing/ColumnMappings.java | 137 +++++++++++++----- .../msq/indexing/report/MSQResultsReport.java | 69 ++++++++- .../druid/msq/sql/MSQTaskQueryMaker.java | 3 - .../druid/msq/sql/MSQTaskSqlEngine.java | 7 +- .../apache/druid/msq/exec/MSQInsertTest.java | 24 +++ .../apache/druid/msq/exec/MSQSelectTest.java | 68 +++++++++ .../indexing/report/MSQTaskReportTest.java | 3 +- .../druid/msq/test/ExtractResultsFactory.java | 9 +- .../apache/druid/msq/test/MSQTestBase.java | 57 +++++--- .../testing/utils/MsqTestQueryHelper.java | 5 +- .../apache/druid/utils/CollectionUtils.java | 23 +++ .../druid/utils/CollectionUtilsTest.java | 64 ++++++-- 13 files changed, 436 insertions(+), 133 deletions(-) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 9a928ffa344..e2ec3ced216 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -36,6 +36,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.IntArraySet; +import it.unimi.dsi.fastutil.ints.IntList; import 
it.unimi.dsi.fastutil.ints.IntSet; import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.data.input.StringTuple; @@ -186,6 +187,7 @@ import org.apache.druid.timeline.partition.DimensionRangeShardSpec; import org.apache.druid.timeline.partition.NumberedPartialShardSpec; import org.apache.druid.timeline.partition.NumberedShardSpec; import org.apache.druid.timeline.partition.ShardSpec; +import org.apache.druid.utils.CollectionUtils; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -1435,7 +1437,7 @@ public class ControllerImpl implements Controller final List retVal = new ArrayList<>(); while (!cursor.isDone()) { - final Object[] row = new Object[columnMappings.getMappings().size()]; + final Object[] row = new Object[columnMappings.size()]; for (int i = 0; i < row.length; i++) { row[i] = selectors.get(i).getObject(); } @@ -1499,6 +1501,8 @@ public class ControllerImpl implements Controller ) { final MSQTuningConfig tuningConfig = querySpec.getTuningConfig(); + final ColumnMappings columnMappings = querySpec.getColumnMappings(); + final Query queryToPlan; final ShuffleSpecFactory shuffleSpecFactory; if (MSQControllerTask.isIngestion(querySpec)) { @@ -1508,25 +1512,33 @@ public class ControllerImpl implements Controller tuningConfig.getRowsPerSegment(), aggregate ); + + if (!columnMappings.hasUniqueOutputColumnNames()) { + // We do not expect to hit this case in production, because the SQL validator checks that column names + // are unique for INSERT and REPLACE statements (i.e. anything where MSQControllerTask.isIngestion would + // be true). This check is here as defensive programming. + throw new ISE("Column names are not unique: [%s]", columnMappings.getOutputColumnNames()); + } + + if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { + // We know there's a single time column, because we've checked columnMappings.hasUniqueOutputColumnNames(). 
+ final int timeColumn = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME).getInt(0); + queryToPlan = querySpec.getQuery().withOverriddenContext( + ImmutableMap.of( + QueryKitUtils.CTX_TIME_COLUMN_NAME, + columnMappings.getQueryColumnName(timeColumn) + ) + ); + } else { + queryToPlan = querySpec.getQuery(); + } } else if (querySpec.getDestination() instanceof TaskReportMSQDestination) { shuffleSpecFactory = ShuffleSpecFactories.singlePartition(); + queryToPlan = querySpec.getQuery(); } else { throw new ISE("Unsupported destination [%s]", querySpec.getDestination()); } - final Query queryToPlan; - - if (querySpec.getColumnMappings().hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { - queryToPlan = querySpec.getQuery().withOverriddenContext( - ImmutableMap.of( - QueryKitUtils.CTX_TIME_COLUMN_NAME, - querySpec.getColumnMappings().getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME) - ) - ); - } else { - queryToPlan = querySpec.getQuery(); - } - final QueryDefinition queryDef; try { @@ -1550,7 +1562,6 @@ public class ControllerImpl implements Controller if (MSQControllerTask.isIngestion(querySpec)) { final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); final ClusterBy queryClusterBy = queryDef.getFinalStageDefinition().getClusterBy(); - final ColumnMappings columnMappings = querySpec.getColumnMappings(); // Find the stage that provides shuffled input to the final segment-generation stage. 
StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition(); @@ -1679,8 +1690,10 @@ public class ControllerImpl implements Controller */ private static boolean timeIsGroupByDimension(GroupByQuery groupByQuery, ColumnMappings columnMappings) { - if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { - final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME); + final IntList positions = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME); + + if (positions.size() == 1) { + final String queryTimeColumn = columnMappings.getQueryColumnName(positions.getInt(0)); return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD)); } else { return false; @@ -1740,7 +1753,7 @@ public class ControllerImpl implements Controller for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) { final KeyColumn column = clusterByColumns.get(i); - final List outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName()); + final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName()); // DimensionRangeShardSpec only handles ascending order. 
if (column.order() != KeyOrder.ASCENDING) { @@ -1759,7 +1772,7 @@ public class ControllerImpl implements Controller return Collections.emptyList(); } - shardColumns.add(outputColumns.get(0)); + shardColumns.add(columnMappings.getOutputColumnName(outputColumns.getInt(0))); } return shardColumns; @@ -1830,7 +1843,10 @@ public class ControllerImpl implements Controller throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName())); } - outputColumnsInOrder.addAll(columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName())); + final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName()); + for (final int outputColumn : outputColumns) { + outputColumnsInOrder.add(columnMappings.getOutputColumnName(outputColumn)); + } } // Then all other columns. @@ -1841,13 +1857,17 @@ public class ControllerImpl implements Controller if (isRollupQuery) { // Populate aggregators from the native query when doing an ingest in rollup mode. 
for (AggregatorFactory aggregatorFactory : ((GroupByQuery) query).getAggregatorSpecs()) { - String outputColumn = Iterables.getOnlyElement(columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName())); - if (outputColumnAggregatorFactories.containsKey(outputColumn)) { - throw new ISE("There can only be one aggregator factory for column [%s].", outputColumn); + final int outputColumn = CollectionUtils.getOnlyElement( + columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName()), + xs -> new ISE("Expected single output for query column[%s] but got[%s]", aggregatorFactory.getName(), xs) + ); + final String outputColumnName = columnMappings.getOutputColumnName(outputColumn); + if (outputColumnAggregatorFactories.containsKey(outputColumnName)) { + throw new ISE("There can only be one aggregation for column [%s].", outputColumnName); } else { outputColumnAggregatorFactories.put( - outputColumn, - aggregatorFactory.withName(outputColumn).getCombiningFactory() + outputColumnName, + aggregatorFactory.withName(outputColumnName).getCombiningFactory() ); } } @@ -1856,13 +1876,19 @@ public class ControllerImpl implements Controller // Each column can be of either time, dimension, aggregator. For this method. we can ignore the time column. // For non-complex columns, If the aggregator factory of the column is not available, we treat the column as // a dimension. For complex columns, certains hacks are in place. - for (final String outputColumn : outputColumnsInOrder) { - final String queryColumn = columnMappings.getQueryColumnForOutputColumn(outputColumn); + for (final String outputColumnName : outputColumnsInOrder) { + // CollectionUtils.getOnlyElement because this method is only called during ingestion, where we require + // that output names be unique. 
+ final int outputColumn = CollectionUtils.getOnlyElement( + columnMappings.getOutputColumnsByName(outputColumnName), + xs -> new ISE("Expected single output column for name [%s], but got [%s]", outputColumnName, xs) + ); + final String queryColumn = columnMappings.getQueryColumnName(outputColumn); final ColumnType type = querySignature.getColumnType(queryColumn) - .orElseThrow(() -> new ISE("No type for column [%s]", outputColumn)); + .orElseThrow(() -> new ISE("No type for column [%s]", outputColumnName)); - if (!outputColumn.equals(ColumnHolder.TIME_COLUMN_NAME)) { + if (!outputColumnName.equals(ColumnHolder.TIME_COLUMN_NAME)) { if (!type.is(ValueType.COMPLEX)) { // non complex columns @@ -1870,21 +1896,21 @@ public class ControllerImpl implements Controller dimensions, aggregators, outputColumnAggregatorFactories, - outputColumn, + outputColumnName, type ); } else { // complex columns only if (DimensionHandlerUtils.DIMENSION_HANDLER_PROVIDERS.containsKey(type.getComplexTypeName())) { - dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumn, type)); + dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumnName, type)); } else if (!isRollupQuery) { - aggregators.add(new PassthroughAggregatorFactory(outputColumn, type.getComplexTypeName())); + aggregators.add(new PassthroughAggregatorFactory(outputColumnName, type.getComplexTypeName())); } else { populateDimensionsAndAggregators( dimensions, aggregators, outputColumnAggregatorFactories, - outputColumn, + outputColumnName, type ); } @@ -1972,12 +1998,14 @@ public class ControllerImpl implements Controller ) { final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); - final RowSignature.Builder mappedSignature = RowSignature.builder(); + final ImmutableList.Builder mappedSignature = ImmutableList.builder(); for (final ColumnMapping mapping : columnMappings.getMappings()) { mappedSignature.add( - mapping.getOutputColumn(), - 
querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) + new MSQResultsReport.ColumnAndType( + mapping.getOutputColumn(), + querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) + ) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ColumnMappings.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ColumnMappings.java index fddfbeefc76..5dc502f7959 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ColumnMappings.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/ColumnMappings.java @@ -22,12 +22,12 @@ package org.apache.druid.msq.indexing; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Preconditions; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.ints.IntList; +import it.unimi.dsi.fastutil.ints.IntLists; import org.apache.druid.java.util.common.IAE; -import org.apache.druid.java.util.common.ISE; import org.apache.druid.segment.column.RowSignature; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -36,23 +36,32 @@ import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; +/** + * Maps column names from {@link MSQSpec#getQuery()} to output names desired by the user, in the order + * desired by the user. + * + * The {@link MSQSpec#getQuery()} is translated by {@link org.apache.druid.msq.querykit.QueryKit} into + * a {@link org.apache.druid.msq.kernel.QueryDefinition}. So, this class also represents mappings from + * {@link org.apache.druid.msq.kernel.QueryDefinition#getFinalStageDefinition()} into the output names desired + * by the user. 
+ */ public class ColumnMappings { private final List mappings; - private final Map outputToQueryColumnMap; - private final Map> queryToOutputColumnsMap; + private final Map outputColumnNameToPositionMap; + private final Map queryColumnNameToPositionMap; @JsonCreator public ColumnMappings(final List mappings) { - this.mappings = validateNoDuplicateOutputColumns(Preconditions.checkNotNull(mappings, "mappings")); - this.outputToQueryColumnMap = new HashMap<>(); - this.queryToOutputColumnsMap = new HashMap<>(); + this.mappings = Preconditions.checkNotNull(mappings, "mappings"); + this.outputColumnNameToPositionMap = new HashMap<>(); + this.queryColumnNameToPositionMap = new HashMap<>(); - for (final ColumnMapping mapping : mappings) { - outputToQueryColumnMap.put(mapping.getOutputColumn(), mapping.getQueryColumn()); - queryToOutputColumnsMap.computeIfAbsent(mapping.getQueryColumn(), k -> new ArrayList<>()) - .add(mapping.getOutputColumn()); + for (int i = 0; i < mappings.size(); i++) { + final ColumnMapping mapping = mappings.get(i); + outputColumnNameToPositionMap.computeIfAbsent(mapping.getOutputColumn(), k -> new IntArrayList()).add(i); + queryColumnNameToPositionMap.computeIfAbsent(mapping.getQueryColumn(), k -> new IntArrayList()).add(i); } } @@ -66,34 +75,95 @@ public class ColumnMappings ); } + /** + * Number of output columns. + */ + public int size() + { + return mappings.size(); + } + + /** + * All output column names, in order. Some names may appear more than once, unless + * {@link #hasUniqueOutputColumnNames()} is true. + */ public List getOutputColumnNames() { return mappings.stream().map(ColumnMapping::getOutputColumn).collect(Collectors.toList()); } - public boolean hasOutputColumn(final String columnName) + /** + * Whether output column names from {@link #getOutputColumnNames()} are all unique. 
+ */ + public boolean hasUniqueOutputColumnNames() { - return outputToQueryColumnMap.containsKey(columnName); + final Set encountered = new HashSet<>(); + + for (final ColumnMapping mapping : mappings) { + if (!encountered.add(mapping.getOutputColumn())) { + return false; + } + } + + return true; } - public String getQueryColumnForOutputColumn(final String outputColumn) + /** + * Whether a particular output column name exists. + */ + public boolean hasOutputColumn(final String outputColumnName) { - final String queryColumn = outputToQueryColumnMap.get(outputColumn); - if (queryColumn != null) { - return queryColumn; - } else { - throw new IAE("No such output column [%s]", outputColumn); - } + return outputColumnNameToPositionMap.containsKey(outputColumnName); } - public List getOutputColumnsForQueryColumn(final String queryColumn) + /** + * Query column name for a particular output column position. + * + * @throws IllegalArgumentException if the output column position is out of range + */ + public String getQueryColumnName(final int outputColumn) { - final List outputColumns = queryToOutputColumnsMap.get(queryColumn); - if (outputColumns != null) { - return outputColumns; - } else { - return Collections.emptyList(); + if (outputColumn < 0 || outputColumn >= mappings.size()) { + throw new IAE("Output column position[%d] out of range", outputColumn); } + + return mappings.get(outputColumn).getQueryColumn(); + } + + /** + * Output column name for a particular output column position. + * + * @throws IllegalArgumentException if the output column position is out of range + */ + public String getOutputColumnName(final int outputColumn) + { + if (outputColumn < 0 || outputColumn >= mappings.size()) { + throw new IAE("Output column position[%d] out of range", outputColumn); + } + + return mappings.get(outputColumn).getOutputColumn(); + } + + /** + * Output column positions for a particular output column name. 
+ */ + public IntList getOutputColumnsByName(final String outputColumnName) + { + return IntLists.unmodifiable(outputColumnNameToPositionMap.getOrDefault(outputColumnName, IntLists.emptyList())); + } + + /** + * Output column positions for a particular query column name, as an unmodifiable view. + */ + public IntList getOutputColumnsForQueryColumn(final String queryColumnName) + { + final IntList outputColumnPositions = queryColumnNameToPositionMap.get(queryColumnName); + + if (outputColumnPositions == null) { + return IntLists.emptyList(); + } + + return IntLists.unmodifiable(outputColumnPositions); } @JsonValue @@ -128,17 +198,4 @@ public class ColumnMappings "mappings=" + mappings + '}'; } - - private static List validateNoDuplicateOutputColumns(final List mappings) - { - final Set encountered = new HashSet<>(); - - for (final ColumnMapping mapping : mappings) { - if (!encountered.add(mapping.getOutputColumn())) { - throw new ISE("Duplicate output column [%s]", mapping.getOutputColumn()); - } - } - - return mappings; - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java index 3d75e3986ca..58911389b36 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/report/MSQResultsReport.java @@ -26,20 +26,25 @@ import com.google.common.base.Preconditions; import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielders; -import org.apache.druid.segment.column.RowSignature; +import org.apache.druid.segment.column.ColumnType; import javax.annotation.Nullable; import java.util.List; +import java.util.Objects; public class MSQResultsReport { - private final RowSignature signature; + /** + * Like {@link 
org.apache.druid.segment.column.RowSignature}, but allows duplicate column names for compatibility + * with SQL (which also allows duplicate column names in query results). + */ + private final List signature; @Nullable private final List sqlTypeNames; private final Yielder resultYielder; public MSQResultsReport( - final RowSignature signature, + final List signature, @Nullable final List sqlTypeNames, final Yielder resultYielder ) @@ -54,7 +59,7 @@ public class MSQResultsReport */ @JsonCreator static MSQResultsReport fromJson( - @JsonProperty("signature") final RowSignature signature, + @JsonProperty("signature") final List signature, @JsonProperty("sqlTypeNames") @Nullable final List sqlTypeNames, @JsonProperty("results") final List results ) @@ -63,7 +68,7 @@ public class MSQResultsReport } @JsonProperty("signature") - public RowSignature getSignature() + public List getSignature() { return signature; } @@ -81,4 +86,58 @@ public class MSQResultsReport { return resultYielder; } + + public static class ColumnAndType + { + private final String name; + private final ColumnType type; + + @JsonCreator + public ColumnAndType( + @JsonProperty("name") String name, + @JsonProperty("type") ColumnType type + ) + { + this.name = name; + this.type = type; + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public ColumnType getType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ColumnAndType that = (ColumnAndType) o; + return Objects.equals(name, that.name) && Objects.equals(type, that.type); + } + + @Override + public int hashCode() + { + return Objects.hash(name, type); + } + + @Override + public String toString() + { + return name + ":" + type; + } + } } diff --git 
a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java index 35b439e1ecc..783b3714fc0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java @@ -183,9 +183,6 @@ public class MSQTaskQueryMaker implements QueryMaker final List columnMappings = new ArrayList<>(); for (final Pair entry : fieldMapping) { - // Note: SQL generally allows output columns to be duplicates, but MSQTaskSqlEngine.validateNoDuplicateAliases - // will prevent duplicate output columns from appearing here. So no need to worry about it. - final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); final String outputColumns = entry.getValue(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index b2c4fab0ee8..94c2532ca79 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -173,8 +173,6 @@ public class MSQTaskSqlEngine implements SqlEngine final PlannerContext plannerContext ) throws ValidationException { - validateNoDuplicateAliases(fieldMappings); - if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) { throw new ValidationException( StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) @@ -248,7 +246,8 @@ public class MSQTaskSqlEngine implements SqlEngine } /** - * SQL allows multiple output columns with the same name, but multi-stage queries doesn't. 
+ * SQL allows multiple output columns with the same name. However, we don't allow this for INSERT or REPLACE + * queries, because we use these output names to generate columns in segments. They must be unique. */ private static void validateNoDuplicateAliases(final List> fieldMappings) throws ValidationException @@ -257,7 +256,7 @@ public class MSQTaskSqlEngine implements SqlEngine for (final Pair field : fieldMappings) { if (!aliasesSeen.add(field.right)) { - throw new ValidationException("Duplicate field in SELECT: " + field.right); + throw new ValidationException("Duplicate field in SELECT: [" + field.right + "]"); } } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java index 82f35838f84..b20e44313ad 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQInsertTest.java @@ -878,6 +878,30 @@ public class MSQInsertTest extends MSQTestBase .verifyResults(); } + @Test + public void testInsertDuplicateColumnNames() + { + testIngestQuery() + .setSql(" insert into foo1 SELECT\n" + + " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n" + + " namespace,\n" + + " \"user\" AS namespace\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [\"ignored\"],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"namespace\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}, {\"name\": \"__bucket\", \"type\": \"string\"}]'\n" + + " )\n" + + ") PARTITIONED by day") + .setQueryContext(context) + .setExpectedValidationErrorMatcher(CoreMatchers.allOf( + CoreMatchers.instanceOf(SqlPlanningException.class), + ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString( + "Duplicate field in SELECT: 
[namespace]")) + )) + .verifyPlanningErrors(); + } + @Test public void testInsertQueryWithInvalidSubtaskCount() { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index d6524466266..4d5a2c0eee3 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -35,6 +35,7 @@ import org.apache.druid.msq.indexing.ColumnMappings; import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; +import org.apache.druid.msq.indexing.report.MSQResultsReport; import org.apache.druid.msq.test.CounterSnapshotMatcher; import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestFileUtils; @@ -245,6 +246,73 @@ public class MSQSelectTest extends MSQTestBase .verifyResults(); } + @Test + public void testSelectOnFooDuplicateColumnNames() + { + // Duplicate column names are OK in SELECT statements. 
+ + final RowSignature expectedScanSignature = + RowSignature.builder() + .add("cnt", ColumnType.LONG) + .add("dim1", ColumnType.STRING) + .build(); + + final ColumnMappings expectedColumnMappings = new ColumnMappings( + ImmutableList.of( + new ColumnMapping("cnt", "x"), + new ColumnMapping("dim1", "x") + ) + ); + + final List expectedOutputSignature = + ImmutableList.of( + new MSQResultsReport.ColumnAndType("x", ColumnType.LONG), + new MSQResultsReport.ColumnAndType("x", ColumnType.STRING) + ); + + testSelectQuery() + .setSql("select cnt AS x, dim1 AS x from foo") + .setExpectedMSQSpec( + MSQSpec.builder() + .query( + newScanQueryBuilder() + .dataSource(CalciteTests.DATASOURCE1) + .intervals(querySegmentSpec(Filtration.eternity())) + .columns("cnt", "dim1") + .context(defaultScanQueryContext(context, expectedScanSignature)) + .build() + ) + .columnMappings(expectedColumnMappings) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .build() + ) + .setQueryContext(context) + .setExpectedRowSignature(expectedOutputSignature) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().totalFiles(1), + 0, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(6).frames(1), + 0, 0, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(6).frames(1), + 0, 0, "shuffle" + ) + .setExpectedResultRows(ImmutableList.of( + new Object[]{1L, !useDefault ? 
"" : null}, + new Object[]{1L, "10.1"}, + new Object[]{1L, "2"}, + new Object[]{1L, "1"}, + new Object[]{1L, "def"}, + new Object[]{1L, "abc"} + )).verifyResults(); + } + @Test public void testSelectOnFooWhereMatchesNoSegments() { diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java index 366e9c50993..b4baaea45a7 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/indexing/report/MSQTaskReportTest.java @@ -51,6 +51,7 @@ import org.junit.rules.TemporaryFolder; import java.io.File; import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -101,7 +102,7 @@ public class MSQTaskReportTest ), new CounterSnapshotsTree(), new MSQResultsReport( - RowSignature.builder().add("s", ColumnType.STRING).build(), + Collections.singletonList(new MSQResultsReport.ColumnAndType("s", ColumnType.STRING)), ImmutableList.of("VARCHAR"), Yielders.each(Sequences.simple(results)) ) diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/ExtractResultsFactory.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/ExtractResultsFactory.java index 7b8079dbb8a..7b694e1384b 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/ExtractResultsFactory.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/ExtractResultsFactory.java @@ -20,17 +20,14 @@ package org.apache.druid.msq.test; import org.apache.druid.java.util.common.ISE; -import org.apache.druid.java.util.common.Pair; import org.apache.druid.msq.indexing.report.MSQTaskReport; import 
org.apache.druid.msq.indexing.report.MSQTaskReportPayload; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.calcite.QueryTestBuilder; import org.apache.druid.sql.calcite.QueryTestRunner; import org.junit.Assert; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.function.Supplier; /** @@ -98,11 +95,11 @@ public class ExtractResultsFactory implements QueryTestRunner.QueryRunStepFactor if (!payload.getStatus().getStatus().isComplete()) { throw new ISE("Query task [%s] should have finished", taskId); } - Optional>> signatureListPair = MSQTestBase.getSignatureWithRows(payload.getResults()); - if (!signatureListPair.isPresent()) { + final List resultRows = MSQTestBase.getRows(payload.getResults()); + if (resultRows == null) { throw new ISE("Results report not present in the task's report payload"); } - extractedResults.add(results.withResults(signatureListPair.get().rhs)); + extractedResults.add(results.withResults(resultRows)); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java index 1ee2bb0ebc2..a87f30c6784 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java @@ -179,7 +179,6 @@ import org.mockito.Mockito; import javax.annotation.Nonnull; import javax.annotation.Nullable; - import java.io.Closeable; import java.io.File; import java.io.IOException; @@ -191,7 +190,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; @@ -775,12 +773,12 @@ public class MSQTestBase extends BaseCalciteQueryTest ); } - public static Optional>> 
getSignatureWithRows(MSQResultsReport resultsReport) + @Nullable + public static List getRows(@Nullable MSQResultsReport resultsReport) { if (resultsReport == null) { - return Optional.empty(); + return null; } else { - RowSignature rowSignature = resultsReport.getSignature(); Yielder yielder = resultsReport.getResultYielder(); List rows = new ArrayList<>(); while (!yielder.isDone()) { @@ -794,7 +792,7 @@ public class MSQTestBase extends BaseCalciteQueryTest throw new ISE("Unable to get results from the report"); } - return Optional.of(new Pair>(rowSignature, rows)); + return rows; } } @@ -802,7 +800,7 @@ public class MSQTestBase extends BaseCalciteQueryTest { protected String sql = null; protected Map queryContext = DEFAULT_MSQ_CONTEXT; - protected RowSignature expectedRowSignature = null; + protected List expectedRowSignature = null; protected MSQSpec expectedMSQSpec = null; protected MSQTuningConfig expectedTuningConfig = null; protected Set expectedSegments = null; @@ -829,10 +827,17 @@ public class MSQTestBase extends BaseCalciteQueryTest return asBuilder(); } + public Builder setExpectedRowSignature(List expectedRowSignature) + { + Preconditions.checkArgument(!expectedRowSignature.isEmpty(), "Row signature cannot be empty"); + this.expectedRowSignature = expectedRowSignature; + return asBuilder(); + } + public Builder setExpectedRowSignature(RowSignature expectedRowSignature) { Preconditions.checkArgument(!expectedRowSignature.equals(RowSignature.empty()), "Row signature cannot be empty"); - this.expectedRowSignature = expectedRowSignature; + this.expectedRowSignature = resultSignatureFromRowSignature(expectedRowSignature); return asBuilder(); } @@ -1100,7 +1105,7 @@ public class MSQTestBase extends BaseCalciteQueryTest final StorageAdapter storageAdapter = new QueryableIndexStorageAdapter(queryableIndex); // assert rowSignature - Assert.assertEquals(expectedRowSignature, storageAdapter.getRowSignature()); + Assert.assertEquals(expectedRowSignature, 
resultSignatureFromRowSignature(storageAdapter.getRowSignature())); // assert rollup Assert.assertEquals(expectedRollUp, queryableIndex.getMetadata().isRollup()); @@ -1172,7 +1177,7 @@ public class MSQTestBase extends BaseCalciteQueryTest Assert.assertTrue(segmentIdVsOutputRowsMap.get(diskSegment).contains(Arrays.asList(row))); } } - + // Assert on the tombstone intervals // Tombstone segments are only published, but since they donot have any data, they are not pushed by the // SegmentGeneratorFrameProcessorFactory. We can get the tombstone segment ids published by taking a set @@ -1245,7 +1250,7 @@ public class MSQTestBase extends BaseCalciteQueryTest // Made the visibility public to aid adding ut's easily with minimum parameters to set. @Nullable - public Pair>> runQueryWithResult() + public Pair, List>> runQueryWithResult() { readyToRun(); Preconditions.checkArgument(sql != null, "sql cannot be null"); @@ -1280,18 +1285,16 @@ public class MSQTestBase extends BaseCalciteQueryTest if (payload.getStatus().getErrorReport() != null) { throw new ISE("Query %s failed due to %s", sql, payload.getStatus().getErrorReport().toString()); } else { - Optional>> rowSignatureListPair = getSignatureWithRows(payload.getResults()); - if (!rowSignatureListPair.isPresent()) { + final List rows = getRows(payload.getResults()); + if (rows == null) { throw new ISE("Query successful but no results found"); } - log.info("found row signature %s", rowSignatureListPair.get().lhs); - log.info(rowSignatureListPair.get().rhs.stream() - .map(row -> Arrays.toString(row)) - .collect(Collectors.joining("\n"))); + log.info("found row signature %s", payload.getResults().getSignature()); + log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n"))); - MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId); + final MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId); log.info("Found spec: %s", 
objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec)); - return new Pair<>(spec, rowSignatureListPair.get()); + return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows)); } } catch (Exception e) { @@ -1308,7 +1311,7 @@ public class MSQTestBase extends BaseCalciteQueryTest Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null"); Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null"); Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec not "); - Pair>> specAndResults = runQueryWithResult(); + Pair, List>> specAndResults = runQueryWithResult(); if (specAndResults == null) { // A fault was expected and the assertion has been done in the runQueryWithResult return; @@ -1327,4 +1330,18 @@ public class MSQTestBase extends BaseCalciteQueryTest } } } + + private static List resultSignatureFromRowSignature(final RowSignature signature) + { + final List retVal = new ArrayList<>(signature.size()); + for (int i = 0; i < signature.size(); i++) { + retVal.add( + new MSQResultsReport.ColumnAndType( + signature.getColumnName(i), + signature.getColumnType(i).orElse(null) + ) + ); + } + return retVal; + } } diff --git a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java index 9051e1e138b..7525cb9d874 100644 --- a/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java +++ b/integration-tests/src/main/java/org/apache/druid/testing/utils/MsqTestQueryHelper.java @@ -37,7 +37,6 @@ import org.apache.druid.msq.indexing.report.MSQResultsReport; import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.apache.druid.msq.sql.SqlTaskStatus; -import org.apache.druid.segment.column.RowSignature; import org.apache.druid.sql.http.SqlQuery; 
import org.apache.druid.testing.IntegrationTestingConfig; import org.apache.druid.testing.clients.SqlResourceTestClient; @@ -203,13 +202,13 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper> actualResults = new ArrayList<>(); Yielder yielder = resultsReport.getResultYielder(); - RowSignature rowSignature = resultsReport.getSignature(); + List rowSignature = resultsReport.getSignature(); while (!yielder.isDone()) { Object[] row = yielder.get(); Map rowWithFieldNames = new LinkedHashMap<>(); for (int i = 0; i < row.length; ++i) { - rowWithFieldNames.put(rowSignature.getColumnName(i), row[i]); + rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]); } actualResults.add(rowWithFieldNames); yielder = yielder.next(null); diff --git a/processing/src/main/java/org/apache/druid/utils/CollectionUtils.java b/processing/src/main/java/org/apache/druid/utils/CollectionUtils.java index d4e1adf830c..15c11ebea55 100644 --- a/processing/src/main/java/org/apache/druid/utils/CollectionUtils.java +++ b/processing/src/main/java/org/apache/druid/utils/CollectionUtils.java @@ -33,6 +33,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.Spliterator; import java.util.TreeSet; @@ -129,6 +130,7 @@ public final class CollectionUtils * can be replaced with Guava's implementation once Druid has upgraded its Guava dependency to a sufficient version. * * @param expectedSize the expected size of the LinkedHashMap + * * @return LinkedHashMap object with appropriate size based on callers expectedSize */ @SuppressForbidden(reason = "java.util.LinkedHashMap#(int)") @@ -184,6 +186,27 @@ public final class CollectionUtils return result; } + /** + * Like {@link Iterables#getOnlyElement(Iterable)}, but allows a customizable error message. 
+ */ + public static , X extends Throwable> T getOnlyElement( + final I iterable, + final Function exceptionSupplier + ) throws X + { + final Iterator iterator = iterable.iterator(); + try { + final T object = iterator.next(); + if (iterator.hasNext()) { + throw exceptionSupplier.apply(iterable); + } + return object; + } + catch (NoSuchElementException e) { + throw exceptionSupplier.apply(iterable); + } + } + private CollectionUtils() { } diff --git a/processing/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java b/processing/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java index 522b9e0ade7..d84a56bb613 100644 --- a/processing/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java +++ b/processing/src/test/java/org/apache/druid/utils/CollectionUtilsTest.java @@ -19,13 +19,18 @@ package org.apache.druid.utils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.apache.druid.java.util.common.ISE; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; +import java.util.Collections; import java.util.Set; -import static org.junit.Assert.assertEquals; - public class CollectionUtilsTest { // When Java 9 is allowed, use Set.of(). 
@@ -37,28 +42,57 @@ public class CollectionUtilsTest @Test public void testSubtract() { - assertEquals(empty, CollectionUtils.subtract(empty, empty)); - assertEquals(abc, CollectionUtils.subtract(abc, empty)); - assertEquals(empty, CollectionUtils.subtract(abc, abc)); - assertEquals(abc, CollectionUtils.subtract(abc, efg)); - assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd)); + Assert.assertEquals(empty, CollectionUtils.subtract(empty, empty)); + Assert.assertEquals(abc, CollectionUtils.subtract(abc, empty)); + Assert.assertEquals(empty, CollectionUtils.subtract(abc, abc)); + Assert.assertEquals(abc, CollectionUtils.subtract(abc, efg)); + Assert.assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd)); } @Test public void testIntersect() { - assertEquals(empty, CollectionUtils.intersect(empty, empty)); - assertEquals(abc, CollectionUtils.intersect(abc, abc)); - assertEquals(empty, CollectionUtils.intersect(abc, efg)); - assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd)); + Assert.assertEquals(empty, CollectionUtils.intersect(empty, empty)); + Assert.assertEquals(abc, CollectionUtils.intersect(abc, abc)); + Assert.assertEquals(empty, CollectionUtils.intersect(abc, efg)); + Assert.assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd)); } @Test public void testUnion() { - assertEquals(empty, CollectionUtils.union(empty, empty)); - assertEquals(abc, CollectionUtils.union(abc, abc)); - assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg)); - assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd)); + Assert.assertEquals(empty, CollectionUtils.union(empty, empty)); + Assert.assertEquals(abc, CollectionUtils.union(abc, abc)); + Assert.assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg)); + Assert.assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd)); + } + + 
@Test + public void testGetOnlyElement_empty() + { + final IllegalStateException e = Assert.assertThrows( + IllegalStateException.class, + () -> CollectionUtils.getOnlyElement(Collections.emptyList(), xs -> new ISE("oops")) + ); + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops"))); + } + + @Test + public void testGetOnlyElement_one() + { + Assert.assertEquals( + "a", + CollectionUtils.getOnlyElement(Collections.singletonList("a"), xs -> new ISE("oops")) + ); + } + + @Test + public void testGetOnlyElement_two() + { + final IllegalStateException e = Assert.assertThrows( + IllegalStateException.class, + () -> CollectionUtils.getOnlyElement(ImmutableList.of("a", "b"), xs -> new ISE("oops")) + ); + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops"))); } }