MSQ: Support multiple result columns with the same name. (#14025)

* MSQ: Support multiple result columns with the same name.

This is allowed in SQL, and is supported by the regular SQL endpoint.
We retain a validation that INSERT ... SELECT does not allow multiple
columns with the same name, because column names in segments must be
unique.
This commit is contained in:
Gian Merlino 2023-04-12 22:39:39 -07:00 committed by GitHub
parent f86ea5cbc4
commit 81074411a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 436 additions and 133 deletions

View File

@ -36,6 +36,7 @@ import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntArraySet; import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.common.guava.FutureUtils; import org.apache.druid.common.guava.FutureUtils;
import org.apache.druid.data.input.StringTuple; import org.apache.druid.data.input.StringTuple;
@ -186,6 +187,7 @@ import org.apache.druid.timeline.partition.DimensionRangeShardSpec;
import org.apache.druid.timeline.partition.NumberedPartialShardSpec; import org.apache.druid.timeline.partition.NumberedPartialShardSpec;
import org.apache.druid.timeline.partition.NumberedShardSpec; import org.apache.druid.timeline.partition.NumberedShardSpec;
import org.apache.druid.timeline.partition.ShardSpec; import org.apache.druid.timeline.partition.ShardSpec;
import org.apache.druid.utils.CollectionUtils;
import org.joda.time.DateTime; import org.joda.time.DateTime;
import org.joda.time.Interval; import org.joda.time.Interval;
@ -1435,7 +1437,7 @@ public class ControllerImpl implements Controller
final List<Object[]> retVal = new ArrayList<>(); final List<Object[]> retVal = new ArrayList<>();
while (!cursor.isDone()) { while (!cursor.isDone()) {
final Object[] row = new Object[columnMappings.getMappings().size()]; final Object[] row = new Object[columnMappings.size()];
for (int i = 0; i < row.length; i++) { for (int i = 0; i < row.length; i++) {
row[i] = selectors.get(i).getObject(); row[i] = selectors.get(i).getObject();
} }
@ -1499,6 +1501,8 @@ public class ControllerImpl implements Controller
) )
{ {
final MSQTuningConfig tuningConfig = querySpec.getTuningConfig(); final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
final Query<?> queryToPlan;
final ShuffleSpecFactory shuffleSpecFactory; final ShuffleSpecFactory shuffleSpecFactory;
if (MSQControllerTask.isIngestion(querySpec)) { if (MSQControllerTask.isIngestion(querySpec)) {
@ -1508,24 +1512,32 @@ public class ControllerImpl implements Controller
tuningConfig.getRowsPerSegment(), tuningConfig.getRowsPerSegment(),
aggregate aggregate
); );
} else if (querySpec.getDestination() instanceof TaskReportMSQDestination) {
shuffleSpecFactory = ShuffleSpecFactories.singlePartition(); if (!columnMappings.hasUniqueOutputColumnNames()) {
} else { // We do not expect to hit this case in production, because the SQL validator checks that column names
throw new ISE("Unsupported destination [%s]", querySpec.getDestination()); // are unique for INSERT and REPLACE statements (i.e. anything where MSQControllerTask.isIngestion would
// be true). This check is here as defensive programming.
throw new ISE("Column names are not unique: [%s]", columnMappings.getOutputColumnNames());
} }
final Query<?> queryToPlan; if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) {
// We know there's a single time column, because we've checked columnMappings.hasUniqueOutputColumnNames().
if (querySpec.getColumnMappings().hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { final int timeColumn = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME).getInt(0);
queryToPlan = querySpec.getQuery().withOverriddenContext( queryToPlan = querySpec.getQuery().withOverriddenContext(
ImmutableMap.of( ImmutableMap.of(
QueryKitUtils.CTX_TIME_COLUMN_NAME, QueryKitUtils.CTX_TIME_COLUMN_NAME,
querySpec.getColumnMappings().getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME) columnMappings.getQueryColumnName(timeColumn)
) )
); );
} else { } else {
queryToPlan = querySpec.getQuery(); queryToPlan = querySpec.getQuery();
} }
} else if (querySpec.getDestination() instanceof TaskReportMSQDestination) {
shuffleSpecFactory = ShuffleSpecFactories.singlePartition();
queryToPlan = querySpec.getQuery();
} else {
throw new ISE("Unsupported destination [%s]", querySpec.getDestination());
}
final QueryDefinition queryDef; final QueryDefinition queryDef;
@ -1550,7 +1562,6 @@ public class ControllerImpl implements Controller
if (MSQControllerTask.isIngestion(querySpec)) { if (MSQControllerTask.isIngestion(querySpec)) {
final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature();
final ClusterBy queryClusterBy = queryDef.getFinalStageDefinition().getClusterBy(); final ClusterBy queryClusterBy = queryDef.getFinalStageDefinition().getClusterBy();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
// Find the stage that provides shuffled input to the final segment-generation stage. // Find the stage that provides shuffled input to the final segment-generation stage.
StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition(); StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition();
@ -1679,8 +1690,10 @@ public class ControllerImpl implements Controller
*/ */
private static boolean timeIsGroupByDimension(GroupByQuery groupByQuery, ColumnMappings columnMappings) private static boolean timeIsGroupByDimension(GroupByQuery groupByQuery, ColumnMappings columnMappings)
{ {
if (columnMappings.hasOutputColumn(ColumnHolder.TIME_COLUMN_NAME)) { final IntList positions = columnMappings.getOutputColumnsByName(ColumnHolder.TIME_COLUMN_NAME);
final String queryTimeColumn = columnMappings.getQueryColumnForOutputColumn(ColumnHolder.TIME_COLUMN_NAME);
if (positions.size() == 1) {
final String queryTimeColumn = columnMappings.getQueryColumnName(positions.getInt(0));
return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD)); return queryTimeColumn.equals(groupByQuery.context().getString(GroupByQuery.CTX_TIMESTAMP_RESULT_FIELD));
} else { } else {
return false; return false;
@ -1740,7 +1753,7 @@ public class ControllerImpl implements Controller
for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) { for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) {
final KeyColumn column = clusterByColumns.get(i); final KeyColumn column = clusterByColumns.get(i);
final List<String> outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName()); final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName());
// DimensionRangeShardSpec only handles ascending order. // DimensionRangeShardSpec only handles ascending order.
if (column.order() != KeyOrder.ASCENDING) { if (column.order() != KeyOrder.ASCENDING) {
@ -1759,7 +1772,7 @@ public class ControllerImpl implements Controller
return Collections.emptyList(); return Collections.emptyList();
} }
shardColumns.add(outputColumns.get(0)); shardColumns.add(columnMappings.getOutputColumnName(outputColumns.getInt(0)));
} }
return shardColumns; return shardColumns;
@ -1830,7 +1843,10 @@ public class ControllerImpl implements Controller
throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName())); throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName()));
} }
outputColumnsInOrder.addAll(columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName())); final IntList outputColumns = columnMappings.getOutputColumnsForQueryColumn(clusterByColumn.columnName());
for (final int outputColumn : outputColumns) {
outputColumnsInOrder.add(columnMappings.getOutputColumnName(outputColumn));
}
} }
// Then all other columns. // Then all other columns.
@ -1841,13 +1857,17 @@ public class ControllerImpl implements Controller
if (isRollupQuery) { if (isRollupQuery) {
// Populate aggregators from the native query when doing an ingest in rollup mode. // Populate aggregators from the native query when doing an ingest in rollup mode.
for (AggregatorFactory aggregatorFactory : ((GroupByQuery) query).getAggregatorSpecs()) { for (AggregatorFactory aggregatorFactory : ((GroupByQuery) query).getAggregatorSpecs()) {
String outputColumn = Iterables.getOnlyElement(columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName())); final int outputColumn = CollectionUtils.getOnlyElement(
if (outputColumnAggregatorFactories.containsKey(outputColumn)) { columnMappings.getOutputColumnsForQueryColumn(aggregatorFactory.getName()),
throw new ISE("There can only be one aggregator factory for column [%s].", outputColumn); xs -> new ISE("Expected single output for query column[%s] but got[%s]", aggregatorFactory.getName(), xs)
);
final String outputColumnName = columnMappings.getOutputColumnName(outputColumn);
if (outputColumnAggregatorFactories.containsKey(outputColumnName)) {
throw new ISE("There can only be one aggregation for column [%s].", outputColumn);
} else { } else {
outputColumnAggregatorFactories.put( outputColumnAggregatorFactories.put(
outputColumn, outputColumnName,
aggregatorFactory.withName(outputColumn).getCombiningFactory() aggregatorFactory.withName(outputColumnName).getCombiningFactory()
); );
} }
} }
@ -1856,13 +1876,19 @@ public class ControllerImpl implements Controller
// Each column can be of either time, dimension, aggregator. For this method. we can ignore the time column. // Each column can be of either time, dimension, aggregator. For this method. we can ignore the time column.
// For non-complex columns, If the aggregator factory of the column is not available, we treat the column as // For non-complex columns, If the aggregator factory of the column is not available, we treat the column as
// a dimension. For complex columns, certains hacks are in place. // a dimension. For complex columns, certains hacks are in place.
for (final String outputColumn : outputColumnsInOrder) { for (final String outputColumnName : outputColumnsInOrder) {
final String queryColumn = columnMappings.getQueryColumnForOutputColumn(outputColumn); // CollectionUtils.getOnlyElement because this method is only called during ingestion, where we require
// that output names be unique.
final int outputColumn = CollectionUtils.getOnlyElement(
columnMappings.getOutputColumnsByName(outputColumnName),
xs -> new ISE("Expected single output column for name [%s], but got [%s]", outputColumnName, xs)
);
final String queryColumn = columnMappings.getQueryColumnName(outputColumn);
final ColumnType type = final ColumnType type =
querySignature.getColumnType(queryColumn) querySignature.getColumnType(queryColumn)
.orElseThrow(() -> new ISE("No type for column [%s]", outputColumn)); .orElseThrow(() -> new ISE("No type for column [%s]", outputColumnName));
if (!outputColumn.equals(ColumnHolder.TIME_COLUMN_NAME)) { if (!outputColumnName.equals(ColumnHolder.TIME_COLUMN_NAME)) {
if (!type.is(ValueType.COMPLEX)) { if (!type.is(ValueType.COMPLEX)) {
// non complex columns // non complex columns
@ -1870,21 +1896,21 @@ public class ControllerImpl implements Controller
dimensions, dimensions,
aggregators, aggregators,
outputColumnAggregatorFactories, outputColumnAggregatorFactories,
outputColumn, outputColumnName,
type type
); );
} else { } else {
// complex columns only // complex columns only
if (DimensionHandlerUtils.DIMENSION_HANDLER_PROVIDERS.containsKey(type.getComplexTypeName())) { if (DimensionHandlerUtils.DIMENSION_HANDLER_PROVIDERS.containsKey(type.getComplexTypeName())) {
dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumn, type)); dimensions.add(DimensionSchemaUtils.createDimensionSchema(outputColumnName, type));
} else if (!isRollupQuery) { } else if (!isRollupQuery) {
aggregators.add(new PassthroughAggregatorFactory(outputColumn, type.getComplexTypeName())); aggregators.add(new PassthroughAggregatorFactory(outputColumnName, type.getComplexTypeName()));
} else { } else {
populateDimensionsAndAggregators( populateDimensionsAndAggregators(
dimensions, dimensions,
aggregators, aggregators,
outputColumnAggregatorFactories, outputColumnAggregatorFactories,
outputColumn, outputColumnName,
type type
); );
} }
@ -1972,12 +1998,14 @@ public class ControllerImpl implements Controller
) )
{ {
final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature(); final RowSignature querySignature = queryDef.getFinalStageDefinition().getSignature();
final RowSignature.Builder mappedSignature = RowSignature.builder(); final ImmutableList.Builder<MSQResultsReport.ColumnAndType> mappedSignature = ImmutableList.builder();
for (final ColumnMapping mapping : columnMappings.getMappings()) { for (final ColumnMapping mapping : columnMappings.getMappings()) {
mappedSignature.add( mappedSignature.add(
new MSQResultsReport.ColumnAndType(
mapping.getOutputColumn(), mapping.getOutputColumn(),
querySignature.getColumnType(mapping.getQueryColumn()).orElse(null) querySignature.getColumnType(mapping.getQueryColumn()).orElse(null)
)
); );
} }

View File

@ -22,12 +22,12 @@ package org.apache.druid.msq.indexing;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import it.unimi.dsi.fastutil.ints.IntLists;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
@ -36,23 +36,32 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/**
* Maps column names from {@link MSQSpec#getQuery()} to output names desired by the user, in the order
* desired by the user.
*
* The {@link MSQSpec#getQuery()} is translated by {@link org.apache.druid.msq.querykit.QueryKit} into
* a {@link org.apache.druid.msq.kernel.QueryDefinition}. So, this class also represents mappings from
* {@link org.apache.druid.msq.kernel.QueryDefinition#getFinalStageDefinition()} into the output names desired
* by the user.
*/
public class ColumnMappings public class ColumnMappings
{ {
private final List<ColumnMapping> mappings; private final List<ColumnMapping> mappings;
private final Map<String, String> outputToQueryColumnMap; private final Map<String, IntList> outputColumnNameToPositionMap;
private final Map<String, List<String>> queryToOutputColumnsMap; private final Map<String, IntList> queryColumnNameToPositionMap;
@JsonCreator @JsonCreator
public ColumnMappings(final List<ColumnMapping> mappings) public ColumnMappings(final List<ColumnMapping> mappings)
{ {
this.mappings = validateNoDuplicateOutputColumns(Preconditions.checkNotNull(mappings, "mappings")); this.mappings = Preconditions.checkNotNull(mappings, "mappings");
this.outputToQueryColumnMap = new HashMap<>(); this.outputColumnNameToPositionMap = new HashMap<>();
this.queryToOutputColumnsMap = new HashMap<>(); this.queryColumnNameToPositionMap = new HashMap<>();
for (final ColumnMapping mapping : mappings) { for (int i = 0; i < mappings.size(); i++) {
outputToQueryColumnMap.put(mapping.getOutputColumn(), mapping.getQueryColumn()); final ColumnMapping mapping = mappings.get(i);
queryToOutputColumnsMap.computeIfAbsent(mapping.getQueryColumn(), k -> new ArrayList<>()) outputColumnNameToPositionMap.computeIfAbsent(mapping.getOutputColumn(), k -> new IntArrayList()).add(i);
.add(mapping.getOutputColumn()); queryColumnNameToPositionMap.computeIfAbsent(mapping.getQueryColumn(), k -> new IntArrayList()).add(i);
} }
} }
@ -66,34 +75,95 @@ public class ColumnMappings
); );
} }
/**
* Number of output columns.
*/
public int size()
{
return mappings.size();
}
/**
* All output column names, in order. Some names may appear more than once, unless
* {@link #hasUniqueOutputColumnNames()} is true.
*/
public List<String> getOutputColumnNames() public List<String> getOutputColumnNames()
{ {
return mappings.stream().map(ColumnMapping::getOutputColumn).collect(Collectors.toList()); return mappings.stream().map(ColumnMapping::getOutputColumn).collect(Collectors.toList());
} }
public boolean hasOutputColumn(final String columnName) /**
* Whether output column names from {@link #getOutputColumnNames()} are all unique.
*/
public boolean hasUniqueOutputColumnNames()
{ {
return outputToQueryColumnMap.containsKey(columnName); final Set<String> encountered = new HashSet<>();
}
public String getQueryColumnForOutputColumn(final String outputColumn) for (final ColumnMapping mapping : mappings) {
{ if (!encountered.add(mapping.getOutputColumn())) {
final String queryColumn = outputToQueryColumnMap.get(outputColumn); return false;
if (queryColumn != null) {
return queryColumn;
} else {
throw new IAE("No such output column [%s]", outputColumn);
} }
} }
public List<String> getOutputColumnsForQueryColumn(final String queryColumn) return true;
{
final List<String> outputColumns = queryToOutputColumnsMap.get(queryColumn);
if (outputColumns != null) {
return outputColumns;
} else {
return Collections.emptyList();
} }
/**
* Whether a particular output column name exists.
*/
public boolean hasOutputColumn(final String outputColumnName)
{
return outputColumnNameToPositionMap.containsKey(outputColumnName);
}
/**
* Query column name for a particular output column position.
*
* @throws IllegalArgumentException if the output column position is out of range
*/
public String getQueryColumnName(final int outputColumn)
{
if (outputColumn < 0 || outputColumn >= mappings.size()) {
throw new IAE("Output column position[%d] out of range", outputColumn);
}
return mappings.get(outputColumn).getQueryColumn();
}
/**
* Output column name for a particular output column position.
*
* @throws IllegalArgumentException if the output column position is out of range
*/
public String getOutputColumnName(final int outputColumn)
{
if (outputColumn < 0 || outputColumn >= mappings.size()) {
throw new IAE("Output column position[%d] out of range", outputColumn);
}
return mappings.get(outputColumn).getOutputColumn();
}
/**
* Output column positions for a particular output column name.
*/
public IntList getOutputColumnsByName(final String outputColumnName)
{
return outputColumnNameToPositionMap.getOrDefault(outputColumnName, IntLists.emptyList());
}
/**
* Output column positions for a particular query column name.
*/
public IntList getOutputColumnsForQueryColumn(final String queryColumnName)
{
final IntList outputColumnPositions = queryColumnNameToPositionMap.get(queryColumnName);
if (outputColumnPositions == null) {
return IntLists.emptyList();
}
return outputColumnPositions;
} }
@JsonValue @JsonValue
@ -128,17 +198,4 @@ public class ColumnMappings
"mappings=" + mappings + "mappings=" + mappings +
'}'; '}';
} }
private static List<ColumnMapping> validateNoDuplicateOutputColumns(final List<ColumnMapping> mappings)
{
final Set<String> encountered = new HashSet<>();
for (final ColumnMapping mapping : mappings) {
if (!encountered.add(mapping.getOutputColumn())) {
throw new ISE("Duplicate output column [%s]", mapping.getOutputColumn());
}
}
return mappings;
}
} }

View File

@ -26,20 +26,25 @@ import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.guava.Sequences; import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.guava.Yielder; import org.apache.druid.java.util.common.guava.Yielder;
import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.ColumnType;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.List; import java.util.List;
import java.util.Objects;
public class MSQResultsReport public class MSQResultsReport
{ {
private final RowSignature signature; /**
* Like {@link org.apache.druid.segment.column.RowSignature}, but allows duplicate column names for compatibility
* with SQL (which also allows duplicate column names in query results).
*/
private final List<ColumnAndType> signature;
@Nullable @Nullable
private final List<String> sqlTypeNames; private final List<String> sqlTypeNames;
private final Yielder<Object[]> resultYielder; private final Yielder<Object[]> resultYielder;
public MSQResultsReport( public MSQResultsReport(
final RowSignature signature, final List<ColumnAndType> signature,
@Nullable final List<String> sqlTypeNames, @Nullable final List<String> sqlTypeNames,
final Yielder<Object[]> resultYielder final Yielder<Object[]> resultYielder
) )
@ -54,7 +59,7 @@ public class MSQResultsReport
*/ */
@JsonCreator @JsonCreator
static MSQResultsReport fromJson( static MSQResultsReport fromJson(
@JsonProperty("signature") final RowSignature signature, @JsonProperty("signature") final List<ColumnAndType> signature,
@JsonProperty("sqlTypeNames") @Nullable final List<String> sqlTypeNames, @JsonProperty("sqlTypeNames") @Nullable final List<String> sqlTypeNames,
@JsonProperty("results") final List<Object[]> results @JsonProperty("results") final List<Object[]> results
) )
@ -63,7 +68,7 @@ public class MSQResultsReport
} }
@JsonProperty("signature") @JsonProperty("signature")
public RowSignature getSignature() public List<ColumnAndType> getSignature()
{ {
return signature; return signature;
} }
@ -81,4 +86,58 @@ public class MSQResultsReport
{ {
return resultYielder; return resultYielder;
} }
public static class ColumnAndType
{
private final String name;
private final ColumnType type;
@JsonCreator
public ColumnAndType(
@JsonProperty("name") String name,
@JsonProperty("type") ColumnType type
)
{
this.name = name;
this.type = type;
}
@JsonProperty
public String getName()
{
return name;
}
@JsonProperty
@JsonInclude(JsonInclude.Include.NON_NULL)
public ColumnType getType()
{
return type;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ColumnAndType that = (ColumnAndType) o;
return Objects.equals(name, that.name) && Objects.equals(type, that.type);
}
@Override
public int hashCode()
{
return Objects.hash(name, type);
}
@Override
public String toString()
{
return name + ":" + type;
}
}
} }

View File

@ -183,9 +183,6 @@ public class MSQTaskQueryMaker implements QueryMaker
final List<ColumnMapping> columnMappings = new ArrayList<>(); final List<ColumnMapping> columnMappings = new ArrayList<>();
for (final Pair<Integer, String> entry : fieldMapping) { for (final Pair<Integer, String> entry : fieldMapping) {
// Note: SQL generally allows output columns to be duplicates, but MSQTaskSqlEngine.validateNoDuplicateAliases
// will prevent duplicate output columns from appearing here. So no need to worry about it.
final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey());
final String outputColumns = entry.getValue(); final String outputColumns = entry.getValue();

View File

@ -173,8 +173,6 @@ public class MSQTaskSqlEngine implements SqlEngine
final PlannerContext plannerContext final PlannerContext plannerContext
) throws ValidationException ) throws ValidationException
{ {
validateNoDuplicateAliases(fieldMappings);
if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) { if (plannerContext.queryContext().containsKey(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) {
throw new ValidationException( throw new ValidationException(
StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY) StringUtils.format("Cannot use \"%s\" without INSERT", DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)
@ -248,7 +246,8 @@ public class MSQTaskSqlEngine implements SqlEngine
} }
/** /**
* SQL allows multiple output columns with the same name, but multi-stage queries doesn't. * SQL allows multiple output columns with the same name. However, we don't allow this for INSERT or REPLACE
* queries, because we use these output names to generate columns in segments. They must be unique.
*/ */
private static void validateNoDuplicateAliases(final List<Pair<Integer, String>> fieldMappings) private static void validateNoDuplicateAliases(final List<Pair<Integer, String>> fieldMappings)
throws ValidationException throws ValidationException
@ -257,7 +256,7 @@ public class MSQTaskSqlEngine implements SqlEngine
for (final Pair<Integer, String> field : fieldMappings) { for (final Pair<Integer, String> field : fieldMappings) {
if (!aliasesSeen.add(field.right)) { if (!aliasesSeen.add(field.right)) {
throw new ValidationException("Duplicate field in SELECT: " + field.right); throw new ValidationException("Duplicate field in SELECT: [" + field.right + "]");
} }
} }
} }

View File

@ -878,6 +878,30 @@ public class MSQInsertTest extends MSQTestBase
.verifyResults(); .verifyResults();
} }
@Test
public void testInsertDuplicateColumnNames()
{
testIngestQuery()
.setSql(" insert into foo1 SELECT\n"
+ " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n"
+ " namespace,\n"
+ " \"user\" AS namespace\n"
+ "FROM TABLE(\n"
+ " EXTERN(\n"
+ " '{ \"files\": [\"ignored\"],\"type\":\"local\"}',\n"
+ " '{\"type\": \"json\"}',\n"
+ " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"namespace\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}, {\"name\": \"__bucket\", \"type\": \"string\"}]'\n"
+ " )\n"
+ ") PARTITIONED by day")
.setQueryContext(context)
.setExpectedValidationErrorMatcher(CoreMatchers.allOf(
CoreMatchers.instanceOf(SqlPlanningException.class),
ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
"Duplicate field in SELECT: [namespace]"))
))
.verifyPlanningErrors();
}
@Test @Test
public void testInsertQueryWithInvalidSubtaskCount() public void testInsertQueryWithInvalidSubtaskCount()
{ {

View File

@ -35,6 +35,7 @@ import org.apache.druid.msq.indexing.ColumnMappings;
import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQSpec;
import org.apache.druid.msq.indexing.MSQTuningConfig; import org.apache.druid.msq.indexing.MSQTuningConfig;
import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault; import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.test.CounterSnapshotMatcher; import org.apache.druid.msq.test.CounterSnapshotMatcher;
import org.apache.druid.msq.test.MSQTestBase; import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.test.MSQTestFileUtils;
@ -245,6 +246,73 @@ public class MSQSelectTest extends MSQTestBase
.verifyResults(); .verifyResults();
} }
@Test
public void testSelectOnFooDuplicateColumnNames()
{
// Duplicate column names are OK in SELECT statements.
final RowSignature expectedScanSignature =
RowSignature.builder()
.add("cnt", ColumnType.LONG)
.add("dim1", ColumnType.STRING)
.build();
final ColumnMappings expectedColumnMappings = new ColumnMappings(
ImmutableList.of(
new ColumnMapping("cnt", "x"),
new ColumnMapping("dim1", "x")
)
);
final List<MSQResultsReport.ColumnAndType> expectedOutputSignature =
ImmutableList.of(
new MSQResultsReport.ColumnAndType("x", ColumnType.LONG),
new MSQResultsReport.ColumnAndType("x", ColumnType.STRING)
);
testSelectQuery()
.setSql("select cnt AS x, dim1 AS x from foo")
.setExpectedMSQSpec(
MSQSpec.builder()
.query(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("cnt", "dim1")
.context(defaultScanQueryContext(context, expectedScanSignature))
.build()
)
.columnMappings(expectedColumnMappings)
.tuningConfig(MSQTuningConfig.defaultConfig())
.build()
)
.setQueryContext(context)
.setExpectedRowSignature(expectedOutputSignature)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().totalFiles(1),
0, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "output"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "shuffle"
)
.setExpectedResultRows(ImmutableList.of(
new Object[]{1L, !useDefault ? "" : null},
new Object[]{1L, "10.1"},
new Object[]{1L, "2"},
new Object[]{1L, "1"},
new Object[]{1L, "def"},
new Object[]{1L, "abc"}
)).verifyResults();
}
@Test @Test
public void testSelectOnFooWhereMatchesNoSegments() public void testSelectOnFooWhereMatchesNoSegments()
{ {

View File

@ -51,6 +51,7 @@ import org.junit.rules.TemporaryFolder;
import java.io.File; import java.io.File;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -101,7 +102,7 @@ public class MSQTaskReportTest
), ),
new CounterSnapshotsTree(), new CounterSnapshotsTree(),
new MSQResultsReport( new MSQResultsReport(
RowSignature.builder().add("s", ColumnType.STRING).build(), Collections.singletonList(new MSQResultsReport.ColumnAndType("s", ColumnType.STRING)),
ImmutableList.of("VARCHAR"), ImmutableList.of("VARCHAR"),
Yielders.each(Sequences.simple(results)) Yielders.each(Sequences.simple(results))
) )

View File

@ -20,17 +20,14 @@
package org.apache.druid.msq.test; package org.apache.druid.msq.test;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.QueryTestBuilder; import org.apache.druid.sql.calcite.QueryTestBuilder;
import org.apache.druid.sql.calcite.QueryTestRunner; import org.apache.druid.sql.calcite.QueryTestRunner;
import org.junit.Assert; import org.junit.Assert;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.function.Supplier; import java.util.function.Supplier;
/** /**
@ -98,11 +95,11 @@ public class ExtractResultsFactory implements QueryTestRunner.QueryRunStepFactor
if (!payload.getStatus().getStatus().isComplete()) { if (!payload.getStatus().getStatus().isComplete()) {
throw new ISE("Query task [%s] should have finished", taskId); throw new ISE("Query task [%s] should have finished", taskId);
} }
Optional<Pair<RowSignature, List<Object[]>>> signatureListPair = MSQTestBase.getSignatureWithRows(payload.getResults()); final List<Object[]> resultRows = MSQTestBase.getRows(payload.getResults());
if (!signatureListPair.isPresent()) { if (resultRows == null) {
throw new ISE("Results report not present in the task's report payload"); throw new ISE("Results report not present in the task's report payload");
} }
extractedResults.add(results.withResults(signatureListPair.get().rhs)); extractedResults.add(results.withResults(resultRows));
} }
} }

View File

@ -179,7 +179,6 @@ import org.mockito.Mockito;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.Closeable; import java.io.Closeable;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -191,7 +190,6 @@ import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.SortedMap; import java.util.SortedMap;
import java.util.TreeMap; import java.util.TreeMap;
@ -775,12 +773,12 @@ public class MSQTestBase extends BaseCalciteQueryTest
); );
} }
public static Optional<Pair<RowSignature, List<Object[]>>> getSignatureWithRows(MSQResultsReport resultsReport) @Nullable
public static List<Object[]> getRows(@Nullable MSQResultsReport resultsReport)
{ {
if (resultsReport == null) { if (resultsReport == null) {
return Optional.empty(); return null;
} else { } else {
RowSignature rowSignature = resultsReport.getSignature();
Yielder<Object[]> yielder = resultsReport.getResultYielder(); Yielder<Object[]> yielder = resultsReport.getResultYielder();
List<Object[]> rows = new ArrayList<>(); List<Object[]> rows = new ArrayList<>();
while (!yielder.isDone()) { while (!yielder.isDone()) {
@ -794,7 +792,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
throw new ISE("Unable to get results from the report"); throw new ISE("Unable to get results from the report");
} }
return Optional.of(new Pair<RowSignature, List<Object[]>>(rowSignature, rows)); return rows;
} }
} }
@ -802,7 +800,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
{ {
protected String sql = null; protected String sql = null;
protected Map<String, Object> queryContext = DEFAULT_MSQ_CONTEXT; protected Map<String, Object> queryContext = DEFAULT_MSQ_CONTEXT;
protected RowSignature expectedRowSignature = null; protected List<MSQResultsReport.ColumnAndType> expectedRowSignature = null;
protected MSQSpec expectedMSQSpec = null; protected MSQSpec expectedMSQSpec = null;
protected MSQTuningConfig expectedTuningConfig = null; protected MSQTuningConfig expectedTuningConfig = null;
protected Set<SegmentId> expectedSegments = null; protected Set<SegmentId> expectedSegments = null;
@ -829,10 +827,17 @@ public class MSQTestBase extends BaseCalciteQueryTest
return asBuilder(); return asBuilder();
} }
public Builder setExpectedRowSignature(List<MSQResultsReport.ColumnAndType> expectedRowSignature)
{
Preconditions.checkArgument(!expectedRowSignature.isEmpty(), "Row signature cannot be empty");
this.expectedRowSignature = expectedRowSignature;
return asBuilder();
}
public Builder setExpectedRowSignature(RowSignature expectedRowSignature) public Builder setExpectedRowSignature(RowSignature expectedRowSignature)
{ {
Preconditions.checkArgument(!expectedRowSignature.equals(RowSignature.empty()), "Row signature cannot be empty"); Preconditions.checkArgument(!expectedRowSignature.equals(RowSignature.empty()), "Row signature cannot be empty");
this.expectedRowSignature = expectedRowSignature; this.expectedRowSignature = resultSignatureFromRowSignature(expectedRowSignature);
return asBuilder(); return asBuilder();
} }
@ -1100,7 +1105,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
final StorageAdapter storageAdapter = new QueryableIndexStorageAdapter(queryableIndex); final StorageAdapter storageAdapter = new QueryableIndexStorageAdapter(queryableIndex);
// assert rowSignature // assert rowSignature
Assert.assertEquals(expectedRowSignature, storageAdapter.getRowSignature()); Assert.assertEquals(expectedRowSignature, resultSignatureFromRowSignature(storageAdapter.getRowSignature()));
// assert rollup // assert rollup
Assert.assertEquals(expectedRollUp, queryableIndex.getMetadata().isRollup()); Assert.assertEquals(expectedRollUp, queryableIndex.getMetadata().isRollup());
@ -1245,7 +1250,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
// Made the visibility public to aid adding ut's easily with minimum parameters to set. // Made the visibility public to aid adding ut's easily with minimum parameters to set.
@Nullable @Nullable
public Pair<MSQSpec, Pair<RowSignature, List<Object[]>>> runQueryWithResult() public Pair<MSQSpec, Pair<List<MSQResultsReport.ColumnAndType>, List<Object[]>>> runQueryWithResult()
{ {
readyToRun(); readyToRun();
Preconditions.checkArgument(sql != null, "sql cannot be null"); Preconditions.checkArgument(sql != null, "sql cannot be null");
@ -1280,18 +1285,16 @@ public class MSQTestBase extends BaseCalciteQueryTest
if (payload.getStatus().getErrorReport() != null) { if (payload.getStatus().getErrorReport() != null) {
throw new ISE("Query %s failed due to %s", sql, payload.getStatus().getErrorReport().toString()); throw new ISE("Query %s failed due to %s", sql, payload.getStatus().getErrorReport().toString());
} else { } else {
Optional<Pair<RowSignature, List<Object[]>>> rowSignatureListPair = getSignatureWithRows(payload.getResults()); final List<Object[]> rows = getRows(payload.getResults());
if (!rowSignatureListPair.isPresent()) { if (rows == null) {
throw new ISE("Query successful but no results found"); throw new ISE("Query successful but no results found");
} }
log.info("found row signature %s", rowSignatureListPair.get().lhs); log.info("found row signature %s", payload.getResults().getSignature());
log.info(rowSignatureListPair.get().rhs.stream() log.info(rows.stream().map(Arrays::toString).collect(Collectors.joining("\n")));
.map(row -> Arrays.toString(row))
.collect(Collectors.joining("\n")));
MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId); final MSQSpec spec = indexingServiceClient.getQuerySpecForTask(controllerId);
log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec)); log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec));
return new Pair<>(spec, rowSignatureListPair.get()); return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows));
} }
} }
catch (Exception e) { catch (Exception e) {
@ -1308,7 +1311,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null"); Preconditions.checkArgument(expectedResultRows != null, "Result rows cannot be null");
Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null"); Preconditions.checkArgument(expectedRowSignature != null, "Row signature cannot be null");
Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec not "); Preconditions.checkArgument(expectedMSQSpec != null, "MultiStageQuery Query spec not ");
Pair<MSQSpec, Pair<RowSignature, List<Object[]>>> specAndResults = runQueryWithResult(); Pair<MSQSpec, Pair<List<MSQResultsReport.ColumnAndType>, List<Object[]>>> specAndResults = runQueryWithResult();
if (specAndResults == null) { // A fault was expected and the assertion has been done in the runQueryWithResult if (specAndResults == null) { // A fault was expected and the assertion has been done in the runQueryWithResult
return; return;
@ -1327,4 +1330,18 @@ public class MSQTestBase extends BaseCalciteQueryTest
} }
} }
} }
private static List<MSQResultsReport.ColumnAndType> resultSignatureFromRowSignature(final RowSignature signature)
{
final List<MSQResultsReport.ColumnAndType> retVal = new ArrayList<>(signature.size());
for (int i = 0; i < signature.size(); i++) {
retVal.add(
new MSQResultsReport.ColumnAndType(
signature.getColumnName(i),
signature.getColumnType(i).orElse(null)
)
);
}
return retVal;
}
} }

View File

@ -37,7 +37,6 @@ import org.apache.druid.msq.indexing.report.MSQResultsReport;
import org.apache.druid.msq.indexing.report.MSQTaskReport; import org.apache.druid.msq.indexing.report.MSQTaskReport;
import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
import org.apache.druid.msq.sql.SqlTaskStatus; import org.apache.druid.msq.sql.SqlTaskStatus;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.http.SqlQuery; import org.apache.druid.sql.http.SqlQuery;
import org.apache.druid.testing.IntegrationTestingConfig; import org.apache.druid.testing.IntegrationTestingConfig;
import org.apache.druid.testing.clients.SqlResourceTestClient; import org.apache.druid.testing.clients.SqlResourceTestClient;
@ -203,13 +202,13 @@ public class MsqTestQueryHelper extends AbstractTestQueryHelper<MsqQueryWithResu
List<Map<String, Object>> actualResults = new ArrayList<>(); List<Map<String, Object>> actualResults = new ArrayList<>();
Yielder<Object[]> yielder = resultsReport.getResultYielder(); Yielder<Object[]> yielder = resultsReport.getResultYielder();
RowSignature rowSignature = resultsReport.getSignature(); List<MSQResultsReport.ColumnAndType> rowSignature = resultsReport.getSignature();
while (!yielder.isDone()) { while (!yielder.isDone()) {
Object[] row = yielder.get(); Object[] row = yielder.get();
Map<String, Object> rowWithFieldNames = new LinkedHashMap<>(); Map<String, Object> rowWithFieldNames = new LinkedHashMap<>();
for (int i = 0; i < row.length; ++i) { for (int i = 0; i < row.length; ++i) {
rowWithFieldNames.put(rowSignature.getColumnName(i), row[i]); rowWithFieldNames.put(rowSignature.get(i).getName(), row[i]);
} }
actualResults.add(rowWithFieldNames); actualResults.add(rowWithFieldNames);
yielder = yielder.next(null); yielder = yielder.next(null);

View File

@ -33,6 +33,7 @@ import java.util.HashSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set; import java.util.Set;
import java.util.Spliterator; import java.util.Spliterator;
import java.util.TreeSet; import java.util.TreeSet;
@ -129,6 +130,7 @@ public final class CollectionUtils
* can be replaced with Guava's implementation once Druid has upgraded its Guava dependency to a sufficient version. * can be replaced with Guava's implementation once Druid has upgraded its Guava dependency to a sufficient version.
* *
* @param expectedSize the expected size of the LinkedHashMap * @param expectedSize the expected size of the LinkedHashMap
*
* @return LinkedHashMap object with appropriate size based on callers expectedSize * @return LinkedHashMap object with appropriate size based on callers expectedSize
*/ */
@SuppressForbidden(reason = "java.util.LinkedHashMap#<init>(int)") @SuppressForbidden(reason = "java.util.LinkedHashMap#<init>(int)")
@ -184,6 +186,27 @@ public final class CollectionUtils
return result; return result;
} }
/**
* Like {@link Iterables#getOnlyElement(Iterable)}, but allows a customizable error message.
*/
public static <T, I extends Iterable<T>, X extends Throwable> T getOnlyElement(
final I iterable,
final Function<? super I, ? extends X> exceptionSupplier
) throws X
{
final Iterator<T> iterator = iterable.iterator();
try {
final T object = iterator.next();
if (iterator.hasNext()) {
throw exceptionSupplier.apply(iterable);
}
return object;
}
catch (NoSuchElementException e) {
throw exceptionSupplier.apply(iterable);
}
}
private CollectionUtils() private CollectionUtils()
{ {
} }

View File

@ -19,13 +19,18 @@
package org.apache.druid.utils; package org.apache.druid.utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import org.apache.druid.java.util.common.ISE;
import org.hamcrest.CoreMatchers;
import org.hamcrest.MatcherAssert;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.internal.matchers.ThrowableMessageMatcher;
import java.util.Collections;
import java.util.Set; import java.util.Set;
import static org.junit.Assert.assertEquals;
public class CollectionUtilsTest public class CollectionUtilsTest
{ {
// When Java 9 is allowed, use Set.of(). // When Java 9 is allowed, use Set.of().
@ -37,28 +42,57 @@ public class CollectionUtilsTest
@Test @Test
public void testSubtract() public void testSubtract()
{ {
assertEquals(empty, CollectionUtils.subtract(empty, empty)); Assert.assertEquals(empty, CollectionUtils.subtract(empty, empty));
assertEquals(abc, CollectionUtils.subtract(abc, empty)); Assert.assertEquals(abc, CollectionUtils.subtract(abc, empty));
assertEquals(empty, CollectionUtils.subtract(abc, abc)); Assert.assertEquals(empty, CollectionUtils.subtract(abc, abc));
assertEquals(abc, CollectionUtils.subtract(abc, efg)); Assert.assertEquals(abc, CollectionUtils.subtract(abc, efg));
assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd)); Assert.assertEquals(ImmutableSet.of("a"), CollectionUtils.subtract(abc, bcd));
} }
@Test @Test
public void testIntersect() public void testIntersect()
{ {
assertEquals(empty, CollectionUtils.intersect(empty, empty)); Assert.assertEquals(empty, CollectionUtils.intersect(empty, empty));
assertEquals(abc, CollectionUtils.intersect(abc, abc)); Assert.assertEquals(abc, CollectionUtils.intersect(abc, abc));
assertEquals(empty, CollectionUtils.intersect(abc, efg)); Assert.assertEquals(empty, CollectionUtils.intersect(abc, efg));
assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd)); Assert.assertEquals(ImmutableSet.of("b", "c"), CollectionUtils.intersect(abc, bcd));
} }
@Test @Test
public void testUnion() public void testUnion()
{ {
assertEquals(empty, CollectionUtils.union(empty, empty)); Assert.assertEquals(empty, CollectionUtils.union(empty, empty));
assertEquals(abc, CollectionUtils.union(abc, abc)); Assert.assertEquals(abc, CollectionUtils.union(abc, abc));
assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg)); Assert.assertEquals(ImmutableSet.of("a", "b", "c", "e", "f", "g"), CollectionUtils.union(abc, efg));
assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd)); Assert.assertEquals(ImmutableSet.of("a", "b", "c", "d"), CollectionUtils.union(abc, bcd));
}
@Test
public void testGetOnlyElement_empty()
{
final IllegalStateException e = Assert.assertThrows(
IllegalStateException.class,
() -> CollectionUtils.getOnlyElement(Collections.emptyList(), xs -> new ISE("oops"))
);
MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops")));
}
@Test
public void testGetOnlyElement_one()
{
Assert.assertEquals(
"a",
CollectionUtils.getOnlyElement(Collections.singletonList("a"), xs -> new ISE("oops"))
);
}
@Test
public void testGetOnlyElement_two()
{
final IllegalStateException e = Assert.assertThrows(
IllegalStateException.class,
() -> CollectionUtils.getOnlyElement(ImmutableList.of("a", "b"), xs -> new ISE("oops"))
);
MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("oops")));
} }
} }