MSQ: Support multiple result columns with the same name. (#14025)

* MSQ: Support multiple result columns with the same name.

This is allowed in SQL, and is supported by the regular SQL endpoint.
We retain a validation that INSERT ... SELECT does not allow multiple
columns with the same name, because column names in segments must be
unique.
This commit is contained in:
Gian Merlino 2023-04-12 22:39:39 -07:00 committed by GitHub
parent f86ea5cbc4
commit 81074411a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 436 additions and 133 deletions

View File

@ -36,6 +36,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.data.input.StringTuple;
@ -186,6 +187,7 @@ import org.apache.druid.timeline.partition.DimensionRangeShardSpec;
import org.apache.druid.timeline.partition.NumberedPartialShardSpec;
import org.apache.druid.timeline.partition.NumberedShardSpec;
import org.apache.druid.timeline.partition.ShardSpec;
import org.apache.druid.utils.CollectionUtils;
import org.joda.time.DateTime;
import org.joda.time.Interval;
@ -1435,7 +1437,7 @@ public class ControllerImpl implements Controller
final List<Object[]> retVal = new ArrayList<>();
while (!cursor.isDone()) {
final Object[] row = new Object[columnMappings.getMappings().size()];
final Object[] row = new Object[columnMappings.size()];
for (int i = 0; i < row.length; i++) {
row[i] = selectors.get(i).getObject();
}
@ -1499,6 +1501,8 @@ public class ControllerImpl implements Controller
)
{
final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
final Query<?> queryToPlan;
final ShuffleSpecFactory shuffleSpecFactory;
if (MSQControllerTask.isIngestion(querySpec)) {
@ -1508,24 +1512,32 @@ public class ControllerImpl implements Controller
tuningConfig.getRowsPerSegment(),
aggregate
);
} else if (querySpec.getDestination() instanceof TaskReportMSQDestination) {
shuffleSpecFactory = ShuffleSpecFactories.singlePartition();
} else {
throw new ISE("Unsupported destination [%s]", querySpec.getDestination());
if (!columnMappings.hasUniqueOutputColumnNames()) {
// We do not expect to hit this case in production, because the SQL validator checks that column names
// are unique for INSERT and REPLACE statements (i.e. anything where MSQControllerTask.isIngestion would
// be true). This check is here as defensive programming.
throw new ISE("Column names are not unique: [%s]", columnMappings.getOutputColumnNames());
}
final Query<?> queryToPlan;
if (querySpec.getColumnMappings().hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
// We know there's a single time column, because we've checked columnMappings.hasUniqueOutputColumnNames().
final int timeColumn = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME).getInt(0);
queryToPlan = querySpec.getQuery().withOverriddenContext(
ImmutableMap.of(
QueryKitUtils.CTX_TIME_COLUMN_NAME,
querySpec.getColumnMappings().getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME)
columnMappings.getQueryColumnName(timeColumn)
)
);
} else {
queryToPlan = querySpec.getQuery();
}
} else if (querySpec.getDestination() instanceof TaskReportMSQDestination) {
shuffleSpecFactory = ShuffleSpecFactories.singlePartition();
queryToPlan = querySpec.getQuery();
} else {
throw new ISE("Unsupported destination [%s]", querySpec.getDestination());
}
final QueryDefinition queryDef;
@ -1550,7 +1562,6 @@ public class ControllerImpl implements Controller
if (MSQControllerTask.isIngestion(querySpec)) {
final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature();
final ClusterBy queryClusterBy = queryDef.getFinalStageDefinition().getClusterBy();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
// Find the stage that provides shuffled input to the final segment-generation stage.
StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition();
@ -1679,8 +1690,10 @@ public class ControllerImpl implements Controller
*/
private static boolean timeIsGroupByDimension(GroupByQuery groupByQuery, ColumnMappings columnMappings)
{
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME);
final IntList positions = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME);
if (positions.size() == 1) {
final String queryTimeColumn = columnMappings.getQueryColumnName(positions.getInt(0));
return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
} else {
return false;
@ -1740,7 +1753,7 @@ public class ControllerImpl implements Controller
for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) {
final KeyColumn column = clusterByColumns.get(i);
final List<String> outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName());
final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName());
// DimensionRangeShardSpec only handles ascending order.
if (column.order() != KeyOrder.ASCENDING) {
@ -1759,7 +1772,7 @@ public class ControllerImpl implements Controller
return Collections.emptyList();
}
shardColumns.add(outputColumns.get(0));
shardColumns.add(columnMappings.getOutputColumnName(outputColumns.getInt(0)));
}
return shardColumns;
@ -1830,7 +1843,10 @@ public class ControllerImpl implements Controller
throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName()));
}
outputColumnsInOrder.addAll(columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName()));
final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName());
for (final int outputColumn : outputColumns) {
outputColumnsInOrder.add(columnMappings.getOutputColumnName(outputColumn));
}
}
// Then all other columns.
@ -1841,13 +1857,17 @@ public class ControllerImpl implements Controller
if (isRollupQuery) {
// Populate aggregators from the native query when doing an ingest in rollup mode.
for (AggregatorFactory aggregatorFactory : ((GroupByQuery) query).getAggregatorSpecs()) {
String outputColumn = Iterables.getOnlyElement(columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName()));
if (outputColumnAggregatorFactories.containsKey(outputColumn)) {
throw new ISE("There can only be one aggregator factory for column [%s].", outputColumn);
final int outputColumn = CollectionUtils.getOnlyElement(
columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName()),
xs -> new ISE("Expected single output for query column[%s] but got[%s]", aggregatorFactory.getName(), xs)
);
final String outputColumnName = columnMappings.getOutputColumnName(outputColumn);
if (outputColumnAggregatorFactories.containsKey(outputColumnName)) {
throw new ISE("There can only be one aggregation for column [%s].", outputColumn);
} else {
outputColumnAggregatorFactories.put(
outputColumn,
aggregatorFactory.withName(outputColumn).getCombiningFactory()
outputColumnName,
aggregatorFactory.withName(outputColumnName).getCombiningFactory()
);
}
}
@ -1856,13 +1876,19 @@ public class ControllerImpl implements Controller
// Each column can be of either time, dimension, aggregator. For this method. we can ignore the time column.
// For non-complex columns, If the aggregator factory of the column is not available, we treat the column as
// a dimension. For complex columns, certains hacks are in place.
for (final String outputColumn : outputColumnsInOrder) {
final String queryColumn = columnMappings.getQueryColumnForOutputColumn(outputColumn);
for (final String outputColumnName : outputColumnsInOrder) {
// CollectionUtils.getOnlyElement because this method is only called during ingestion, where we require
// that output names be unique.
final int outputColumn = CollectionUtils.getOnlyElement(
columnMappings.getOutputColumnsByName(outputColumnName),
xs -> new ISE("Expected single output column for name [%s], but got [%s]", outputColumnName, xs)
);
final String queryColumn = columnMappings.getQueryColumnName(outputColumn);
final ColumnType type =
querySignature.getColumnType(queryColumn)
.orElseThrow(() -> new ISE("No type for column [%s]", outputColumn));
.orElseThrow(() -> new ISE("No type for column [%s]", outputColumnName));
if (!outputColumn.equals(ColumnHolder.TIME_COLUMN_NAME)) {
if (!outputColumnName.equals(ColumnHolder.TIME_COLUMN_NAME)) {
if (!type.is(ValueType.COMPLEX)) {
// non complex columns
@ -1870,21 +1896,21 @@ public class ControllerImpl implements Controller
dimensions,
aggregators,
outputColumnAggregatorFactories,
outputColumn,
outputColumnName,
type
);
} else {
// complex columns only
if (DimensionHandlerUtils.DIMENSION_HANDLER_PROVIDERS.containsKey(type.getComplexTypeName())) {
dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumn, type));
dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumnName, type));
} else if (!isRollupQuery) {
aggregators.add(new PassthroughAggregatorFactory(outputColumn, type.getComplexTypeName()));
aggregators.add(new PassthroughAggregatorFactory(outputColumnName, type.getComplexTypeName()));
} else {
populateDimensionsAndAggregators(
dimensions,
aggregators,
outputColumnAggregatorFactories,
outputColumn,
outputColumnName,
type
);
}
@ -1972,12 +1998,14 @@ public class ControllerImpl implements Controller
)
{
final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature();
final RowSignature.Builder mappedSignature = RowSignature.builder();
final ImmutableList.Builder<MSQResultsReport.ColumnAndType> mappedSignature = ImmutableList.builder();
for (final ColumnMapping mapping : columnMappings.getMappings()) {
mappedSignature.add(
new MSQResultsReport.ColumnAndType(
mapping.getOutputColumn(),
querySignature.getColumnType(mapping.getQueryColumn()).orElse(null)
)
);
}

View File

@ -22,12 +22,12 @@ package org.apache.druid.msq.indexing;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntLists;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.RowSignature;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@ -36,23 +36,32 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Maps column names from {@link MSQSpec#getQuery()} to output names desired by the user, in the order
* desired by the user.
*
* The {@link MSQSpec#getQuery()} is translated by {@link org.apache.druid.msq.querykit.QueryKit} into
* a {@link org.apache.druid.msq.kernel.QueryDefinition}. So, this class also represents mappings from
* {@link org.apache.druid.msq.kernel.QueryDefinition#getFinalStageDefinition()} into the output names desired
* by the user.
*/
public class ColumnMappings
{
private final List<ColumnMapping> mappings;
private final Map<String, String> outputToQueryColumnMap;
private final Map<String, List<String>> queryToOutputColumnsMap;
private final Map<String, IntList> outputColumnNameToPositionMap;
private final Map<String, IntList> queryColumnNameToPositionMap;
@JsonCreator
public ColumnMappings(final List<ColumnMapping> mappings)
{
this.mappings = validateNoDuplicateOutputColumns(Preconditions.checkNotNull(mappings, "mappings"));
this.outputToQueryColumnMap = new HashMap<>();
this.queryToOutputColumnsMap = new HashMap<>();
this.mappings = Preconditions.checkNotNull(mappings, "mappings");
this.outputColumnNameToPositionMap = new HashMap<>();
this.queryColumnNameToPositionMap = new HashMap<>();
for (final ColumnMapping mapping : mappings) {
outputToQueryColumnMap.put(mapping.getOutputColumn(), mapping.getQueryColumn());
queryToOutputColumnsMap.computeIfAbsent(mapping.getQueryColumn(), k -> new ArrayList<>())
.add(mapping.getOutputColumn());
for (int i = 0; i < mappings.size(); i++) {
final ColumnMapping mapping = mappings.get(i);
outputColumnNameToPositionMap.computeIfAbsent(mapping.getOutputColumn(), k -> new IntArrayList()).add(i);
queryColumnNameToPositionMap.computeIfAbsent(mapping.getQueryColumn(), k -> new IntArrayList()).add(i);
}
}
@ -66,34 +75,95 @@ public class ColumnMappings
);
}
/**
* Number of output columns.
*/
public int size()
{
return mappings.size();
}
/**
* All output column names, in order. Some names may appear more than once, unless
* {@link #hasUniqueOutputColumnNames()} is true.
*/
public List<String> getOutputColumnNames()
{
return mappings.stream().map(ColumnMapping::getOutputColumn).collect(Collectors.toList());
}
public boolean hasOutputColumn(final String columnName)
/**
* Whether output column names from {@link #getOutputColumnNames()} are all unique.
*/
public boolean hasUniqueOutputColumnNames()
{
return outputToQueryColumnMap.containsKey(columnName);
}
final Set<String> encountered = new HashSet<>();
public String getQueryColumnForOutputColumn(final String outputColumn)
{
final String queryColumn = outputToQueryColumnMap.get(outputColumn);
if (queryColumn != null) {
return queryColumn;
} else {
throw new IAE("No such output column [%s]", outputColumn);
for (final ColumnMapping mapping : mappings) {
if (!encountered.add(mapping.getOutputColumn())) {
return false;
}
}
public List<String> getOutputColumnsForQueryColumn(final String queryColumn)
{
final List<String> outputColumns = queryToOutputColumnsMap.get(queryColumn);
if (outputColumns != null) {
return outputColumns;
} else {
return Collections.emptyList();
return true;
}
/**
* Whether a particular output column name exists.
*/
public boolean hasOutputColumn(final String outputColumnName)
{
return outputColumnNameToPositionMap.containsKey(outputColumnName);
}
/**
* Query column name for a particular output column position.
*
* @throws IllegalArgumentException if the output column position is out of range
*/
public String getQueryColumnName(final int outputColumn)
{
if (outputColumn < 0 || outputColumn >= mappings.size()) {
throw new IAE("Output column position[%d] out of range", outputColumn);
}
return mappings.get(outputColumn).getQueryColumn();
}
/**
* Output column name for a particular output column position.
*
* @throws IllegalArgumentException if the output column position is out of range
*/
public String getOutputColumnName(final int outputColumn)
{
if (outputColumn < 0 || outputColumn >= mappings.size()) {
throw new IAE("Output column position[%d] out of range", outputColumn);
}
return mappings.get(outputColumn).getOutputColumn();
}
/**
* Output column positions for a particular output column name.
*/
public IntList getOutputColumnsByName(final String outputColumnName)
{
return outputColumnNameToPositionMap.getOrDefault(outputColumnName, IntLists.emptyList());
}
/**
* Output column positions for a particular query column name.
*/
public IntList getOutputColumnsForQueryColumn(final String queryColumnName)
{
final IntList outputColumnPositions = queryColumnNameToPositionMap.get(queryColumnName);
if (outputColumnPositions == null) {
return IntLists.emptyList();
}
return outputColumnPositions;
}
@JsonValue
@ -128,17 +198,4 @@ public class ColumnMappings
"mappings=" + mappings +
'}';
}
private static List<ColumnMapping> validateNoDuplicateOutputColumns(final List<ColumnMapping> mappings)
{
final Set<String> encountered = new HashSet<>();
for (final ColumnMapping mapping : mappings) {
if (!encountered.add(mapping.getOutputColumn())) {
throw new ISE("Duplicate output column [%s]", mapping.getOutputColumn());
}
}
return mappings;
}
}

View File

@ -26,20 +26,25 @@ import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.column.ColumnType;
import javax.annotation.Nullable;
import java.util.List;
import java.util.Objects;
public class MSQResultsReport
{
private final RowSignature signature;
/**
* Like {@link org.apache.druid.segment.column.RowSignature}, but allows duplicate column names for compatibility
* with SQL (which also allows duplicate column names in query results).
*/
private final List<ColumnAndType> signature;
@Nullable
private final List<String> sqlTypeNames;
private final Yielder<Object[]> resultYielder;
public MSQResultsReport(
final RowSignature signature,
final List<ColumnAndType> signature,
@Nullable final List<String> sqlTypeNames,
final Yielder<Object[]> resultYielder
)
@ -54,7 +59,7 @@ public class MSQResultsReport
*/
@JsonCreator
static MSQResultsReport fromJson(
@JsonProperty("signature") final RowSignature signature,
@JsonProperty("signature") final List<ColumnAndType> signature,
@JsonProperty("sqlTypeNames") @Nullable final List<String> sqlTypeNames,
@JsonProperty("results") final List<Object[]> results
)
@ -63,7 +68,7 @@ public class MSQResultsReport
}
@JsonProperty("signature")
public RowSignature getSignature()
public List<ColumnAndType> getSignature()
{
return signature;
}
@ -81,4 +86,58 @@ public class MSQResultsReport
{
return resultYielder;
}
public static class ColumnAndType
{
private final String name;
private final ColumnType type;
@JsonCreator
public ColumnAndType(
@JsonProperty("name") String name,
@JsonProperty("type") ColumnType type
)
{
this.name = name;
this.type = type;
}
@JsonProperty
public String getName()
{
return name;
}
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public ColumnType getType()
{
return type;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ColumnAndType that = (ColumnAndType) o;
return Objects.equals(name, that.name) && Objects.equals(type, that.type);
}
@Override
public int hashCode()
{
return Objects.hash(name, type);
}
@Override
public String toString()
{
return name + ":" + type;
}
}
}

View File

@ -183,9 +183,6 @@ public class MSQTaskQueryMaker implements QueryMaker
final List<ColumnMapping> columnMappings = new ArrayList<>();
for (final Pair<Integer, String> entry : fieldMapping) {
// Note: SQL generally allows output columns to be duplicates, but MSQTaskSqlEngine.validateNoDuplicateAliases
// will prevent duplicate output columns from appearing here. So no need to worry about it.
final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey());
final String outputColumns = entry.getValue();

View File

@ -173,8 +173,6 @@ public class MSQTaskSqlEngine implements SqlEngine
final PlannerContext plannerContext
) throws ValidationException
{
validateNoDuplicateAliases(fieldMappings);
if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) {
throw new ValidationException(
StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
@ -248,7 +246,8 @@ public class MSQTaskSqlEngine implements SqlEngine
}
/**
* SQL allows multiple output columns with the same name, but multi-stage queries doesn't.
* SQL allows multiple output columns with the same name. However, we don't allow this for INSERT or REPLACE
* queries, because we use these output names to generate columns in segments. They must be unique.
*/
private static void validateNoDuplicateAliases(final List<Pair<Integer, String>> fieldMappings)
throws ValidationException
@ -257,7 +256,7 @@ public class MSQTaskSqlEngine implements SqlEngine
for (final Pair<Integer, String> field : fieldMappings) {
if (!aliasesSeen.add(field.right)) {
throw new ValidationException("Duplicate field in SELECT: " + field.right);
throw new ValidationException("Duplicate field in SELECT: [" + field.right + "]");
}
}
}

View File

@ -878,6 +878,30 @@ public class MSQInsertTest extends MSQTestBase
.verifyResults();
}
@Test
public void testInsertDuplicateColumnNames()
{
testIngestQuery()
.setSql(" insert into foo1 SELECT\n"
+ " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n"
+ " namespace,\n"
+ " \"user\" AS namespace\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [\"ignored\"],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"namespace\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}, {\"name\": \"__bucket\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") PARTITIONED by day")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Duplicate field in SELECT: [namespace]"))
))
.verifyPlanningErrors();
}
@Test
public void testInsertQueryWithInvalidSubtaskCount()
{

View File

@ -35,6 +35,7 @@ import org.apache.druid.msq.indexing.ColumnMappings;
import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.test.CounterSnapshotMatcher;
import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils;
@ -245,6 +246,73 @@ public class MSQSelectTest extends MSQTestBase
.verifyResults();
}
@Test
public void testSelectOnFooDuplicateColumnNames()
{
// Duplicate column names are OK in SELECT statements.
final RowSignature expectedScanSignature =
RowSignature.builder()
.add("cnt", ColumnType.LONG)
.add("dim1", ColumnType.STRING)
.build();
final ColumnMappings expectedColumnMappings = new ColumnMappings(
ImmutableList.of(
new ColumnMapping("cnt", "x"),
new ColumnMapping("dim1", "x")
)
);
final List<MSQResultsReport.ColumnAndType> expectedOutputSignature =
ImmutableList.of(
new MSQResultsReport.ColumnAndType("x", ColumnType.LONG),
new MSQResultsReport.ColumnAndType("x", ColumnType.STRING)
);
testSelectQuery()
.setSql("select cnt AS x, dim1 AS x from foo")
.setExpectedMSQSpec(
MSQSpec.builder()
.query(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("cnt", "dim1")
.context(defaultScanQueryContext(context, expectedScanSignature))
.build()
)
.columnMappings(expectedColumnMappings)
.tuningConfig(MSQTuningConfig.defaultConfig())
.build()
)
.setQueryContext(context)
.setExpectedRowSignature(expectedOutputSignature)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().totalFiles(1),
0, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "output"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "shuffle"
)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, !useDefault ? "" : null},
new Object[]{1L, "10.1"},
new Object[]{1L, "2"},
new Object[]{1L, "1"},
new Object[]{1L, "def"},
new Object[]{1L, "abc"}
)).verifyResults();
}
@Test
public void testSelectOnFooWhereMatchesNoSegments()
{

View File

@ -51,6 +51,7 @@ import org.junit.rules.TemporaryFolder;
import java.io.File;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -101,7 +102,7 @@ public class MSQTaskReportTest
),
new CounterSnapshotsTree(),
new MSQResultsReport(
RowSignature.builder().add("s", ColumnType.STRING).build(),
Collections.singletonList(new MSQResultsReport.ColumnAndType("s", ColumnType.STRING)),
ImmutableList.of("VARCHAR"),
Yielders.each(Sequences.simple(results))
)

View File

@ -20,17 +20,14 @@
package org.apache.druid.msq.test;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.QueryTestRunner;
import org.junit.Assert;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
/**
@ -98,11 +95,11 @@ public class ExtractResultsFactory implements QueryTestRunner.QueryRunStepFactor
if (!payload.getStatus().getStatus().isComplete()) {
throw new ISE("Query task [%s] should have finished", taskId);
}
Optional<Pair<RowSignature, List<Object[]>>> signatureListPair = MSQTestBase.getSignatureWithRows(payload.getResults());
if (!signatureListPair.isPresent()) {
final List<Object[]> resultRows = MSQTestBase.getRows(payload.getResults());
if (resultRows == null) {
throw new ISE("Results report not present in the task's report payload");
}
extractedResults.add(results.withResults(signatureListPair.get().rhs));
extractedResults.add(results.withResults(resultRows));
}
}

View File

@ -179,7 +179,6 @@ import org.mockito.Mockito;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.File;
import java.io.IOException;
@ -191,7 +190,6 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
@ -775,12 +773,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
);
}
public static Optional<Pair<RowSignature, List<Object[]>>> getSignatureWithRows(MSQResultsReport resultsReport)
@Nullable
public static List<Object[]> getRows(@Nullable MSQResultsReport resultsReport)
{
if (resultsReport == null) {
return Optional.empty();
return null;
} else {
RowSignature rowSignature = resultsReport.getSignature();
Yielder<Object[]> yielder = resultsReport.getResultYielder();
List<Object[]> rows = new ArrayList<>();
while (!yielder.isDone()) {
@ -794,7 +792,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
throw new ISE("Unable to get results from the report");
}
return Optional.of(new Pair<RowSignature, List<Object[]>>(rowSignature, rows));
return rows;
}
}
@ -802,7 +800,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
{
protected String sql = null;
protected Map<String, Object> queryContext = DEFAULT_MSQ_CONTEXT;
protected RowSignature expectedRowSignature = null;
protected List<MSQResultsReport.ColumnAndType> expectedRowSignature = null;
protected MSQSpec expectedMSQSpec = null;
protected MSQTuningConfig expectedTuningConfig = null;
protected Set<SegmentId> expectedSegments = null;
@ -829,10 +827,17 @@ public class MSQTestBase extends BaseCalciteQueryTest
return asBuilder();
}
public Builder setExpectedRowSignature(List<MSQResultsReport.ColumnAndType> expectedRowSignature)
{
Preconditions.checkArgument(!expectedRowSignature.isEmpty(), "Row signature cannot be empty");
this.expectedRowSignature = expectedRowSignature;
return asBuilder();
}
public Builder setExpectedRowSignature(RowSignature expectedRowSignature)
{
Preconditions.checkArgument(!expectedRowSignature.equals(RowSignature.empty()), "Row signature cannot be empty");
this.expectedRowSignature = expectedRowSignature;
this.expectedRowSignature = resultSignatureFromRowSignature(expectedRowSignature);
return asBuilder();
}
@ -1100,7 +1105,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
final StorageAdapter storageAdapter = new QueryableIndexStorageAdapter(queryableIndex);
// assert rowSignature
Assert.assertEquals(expectedRowSignature, storageAdapter.getRowSignature());
Assert.assertEquals(expectedRowSignature, resultSignatureFromRowSignature(storageAdapter.getRowSignature()));
// assert rollup
Assert.assertEquals(expectedRollUp, queryableIndex.getMetadata().isRollup());
@ -1245,7 +1250,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
// Made the visibility public to aid adding ut's easily with minimum parameters to set.
@Nullable
public Pair<MSQSpec, Pair<RowSignature, List<Object[]>>> runQueryWithResult()
public Pair<MSQSpec, Pair<List<MSQResultsReport.ColumnAndType>, List<Object[]>>> runQueryWithResult()
{
readyToRun();
Preconditions.checkArgument(sql != null, "sql cannot be null");
@ -1280,18 +1285,16 @@ public class MSQTestBase extends BaseCalciteQueryTest
if (payload.getStatus().getErrorReport() != null) {
throw new ISE("Query %s failed due to %s", sql, payload.getStatus().getErrorReport().toString());
} else {
Optional<Pair<RowSignature, List<Object[]>>> rowSignatureListPair = getSignatureWithRows(payload.getResults());
if (!rowSignatureListPair.isPresent()) {
final List<Object[]> rows = getRows(payload.getResults());
if (rows == null) {
throw new ISE("Query successful but no results found");
}
log.info("found row signature %s", rowSignatureListPair.get().lhs);
log.info(rowSignatureListPair.get().rhs.stream()
.map(row -> Arrays.toString(row))
.collect(Collectors.joining("\n")));
log.info("found row signature %s", payload.getResults().getSignature());
log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n")));
MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId);
final MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId);
log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec));
return new Pair<>(spec, rowSignatureListPair.get());
return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows));
}
}
catch (Exception e) {
@ -1308,7 +1311,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null");
Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null");
Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec not ");
Pair<MSQSpec, Pair<RowSignature, List<Object[]>>> specAndResults = runQueryWithResult();
Pair<MSQSpec, Pair<List<MSQResultsReport.ColumnAndType>, List<Object[]>>> specAndResults = runQueryWithResult();
if (specAndResults == null) { // A fault was expected and the assertion has been done in the runQueryWithResult
return;
@ -1327,4 +1330,18 @@ public class MSQTestBase extends BaseCalciteQueryTest
}
}
}
private static List<MSQResultsReport.ColumnAndType> resultSignatureFromRowSignature(final RowSignature signature)
{
final List<MSQResultsReport.ColumnAndType> retVal = new ArrayList<>(signature.size());
for (int i = 0; i < signature.size(); i++) {
retVal.add(
new MSQResultsReport.ColumnAndType(
signature.getColumnName(i),
signature.getColumnType(i).orElse(null)
)
);
}
return retVal;
}
}

View File

@ -37,7 +37,6 @@ import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.msq.sql.SqlTaskStatus;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.testing.IntegrationTestingConfig;
import org.apache.druid.testing.clients.SqlResourceTestClient;
@ -203,13 +202,13 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper<MsqQueryWithResu
List<Map<String, Object>> actualResults = new ArrayList<>();
Yielder<Object[]> yielder = resultsReport.getResultYielder();
RowSignature rowSignature = resultsReport.getSignature();
List<MSQResultsReport.ColumnAndType> rowSignature = resultsReport.getSignature();
while (!yielder.isDone()) {
Object[] row = yielder.get();
Map<String, Object> rowWithFieldNames = new LinkedHashMap<>();
for (int i = 0; i < row.length; ++i) {
rowWithFieldNames.put(rowSignature.getColumnName(i), row[i]);
rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]);
}
actualResults.add(rowWithFieldNames);
yielder = yielder.next(null);

View File

@ -33,6 +33,7 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.Spliterator;
import java.util.TreeSet;
@ -129,6 +130,7 @@ public final class CollectionUtils
* can be replaced with Guava's implementation once Druid has upgraded its Guava dependency to a sufficient version.
*
* @param expectedSize the expected size of the LinkedHashMap
*
* @return LinkedHashMap object with appropriate size based on callers expectedSize
*/
@SuppressForbidden(reason = "java.util.LinkedHashMap#<init>(int)")
@ -184,6 +186,27 @@ public final class CollectionUtils
return result;
}
/**
* Like {@link Iterables#getOnlyElement(Iterable)}, but allows a customizable error message.
*/
public static <T, I extends Iterable<T>, X extends Throwable> T getOnlyElement(
final I iterable,
final Function<? super I, ? extends X> exceptionSupplier
) throws X
{
final Iterator<T> iterator = iterable.iterator();
try {
final T object = iterator.next();
if (iterator.hasNext()) {
throw exceptionSupplier.apply(iterable);
}
return object;
}
catch (NoSuchElementException e) {
throw exceptionSupplier.apply(iterable);
}
}
private CollectionUtils()
{
}

View File

@ -19,13 +19,18 @@
package org.apache.druid.utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.ISE;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import java.util.Collections;
import java.util.Set;
import static org.junit.Assert.assertEquals;
public class CollectionUtilsTest
{
// When Java 9 is allowed, use Set.of().
@ -37,28 +42,57 @@ public class CollectionUtilsTest
@Test
public void testSubtract()
{
assertEquals(empty, CollectionUtils.subtract(empty, empty));
assertEquals(abc, CollectionUtils.subtract(abc, empty));
assertEquals(empty, CollectionUtils.subtract(abc, abc));
assertEquals(abc, CollectionUtils.subtract(abc, efg));
assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd));
Assert.assertEquals(empty, CollectionUtils.subtract(empty, empty));
Assert.assertEquals(abc, CollectionUtils.subtract(abc, empty));
Assert.assertEquals(empty, CollectionUtils.subtract(abc, abc));
Assert.assertEquals(abc, CollectionUtils.subtract(abc, efg));
Assert.assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd));
}
@Test
public void testIntersect()
{
assertEquals(empty, CollectionUtils.intersect(empty, empty));
assertEquals(abc, CollectionUtils.intersect(abc, abc));
assertEquals(empty, CollectionUtils.intersect(abc, efg));
assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd));
Assert.assertEquals(empty, CollectionUtils.intersect(empty, empty));
Assert.assertEquals(abc, CollectionUtils.intersect(abc, abc));
Assert.assertEquals(empty, CollectionUtils.intersect(abc, efg));
Assert.assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd));
}
@Test
public void testUnion()
{
assertEquals(empty, CollectionUtils.union(empty, empty));
assertEquals(abc, CollectionUtils.union(abc, abc));
assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg));
assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd));
Assert.assertEquals(empty, CollectionUtils.union(empty, empty));
Assert.assertEquals(abc, CollectionUtils.union(abc, abc));
Assert.assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg));
Assert.assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd));
}
@Test
public void testGetOnlyElement_empty()
{
final IllegalStateException e = Assert.assertThrows(
IllegalStateException.class,
() -> CollectionUtils.getOnlyElement(Collections.emptyList(), xs -> new ISE("oops"))
);
MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops")));
}
@Test
public void testGetOnlyElement_one()
{
Assert.assertEquals(
"a",
CollectionUtils.getOnlyElement(Collections.singletonList("a"), xs -> new ISE("oops"))
);
}
@Test
public void testGetOnlyElement_two()
{
final IllegalStateException e = Assert.assertThrows(
IllegalStateException.class,
() -> CollectionUtils.getOnlyElement(ImmutableList.of("a", "b"), xs -> new ISE("oops"))
);
MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops")));
}
}