Fixes a bug when running queries with a limit clause (#16643)

Adds a shuffle built from the resultShuffleSpecFactory after the limit processor, chosen according to the query destination. LimitFrameProcessors do not currently update the partition boosting column, so the boost column is also added to the previous stage when one is required.
Adarsh Sanjeev 2024-07-09 14:29:12 +05:30 committed by GitHub
parent a9bd0eea2a
commit af5399cd9d
10 changed files with 421 additions and 80 deletions
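
At a glance: the stage that feeds the limit processor now picks its shuffle based on whether a limit/offset follows and whether there is anything real to sort on, and the limit stage itself applies the shuffle built from resultShuffleSpecFactory instead of null. Below is a minimal, self-contained sketch of that decision; ShuffleKind, chooseScanShuffle, and the plain string column list are simplified stand-ins for illustration only — the actual logic, shown in the ScanQueryKit diff further down, uses Druid's ShuffleSpec, ShuffleSpecFactories, MixShuffleSpec, and KeyColumn types.

import java.util.List;

public class LimitShuffleSketch
{
  enum ShuffleKind
  {
    FINAL_SHUFFLE,      // shuffle built from resultShuffleSpecFactory
    SINGLE_PARTITION,   // sort everything into one partition ahead of the limit stage
    MIX                 // no sort; just mix frames into one partition
  }

  // Stand-in for QueryKitUtils.PARTITION_BOOST_COLUMN.
  static final String PARTITION_BOOST_COLUMN = "__boost";

  // Decision for the stage feeding the limit processor.
  static ShuffleKind chooseScanShuffle(boolean hasLimitOrOffset, List<String> clusterByColumns)
  {
    if (!hasLimitOrOffset) {
      // No limit stage follows, so the destination's final shuffle is applied here directly.
      return ShuffleKind.FINAL_SHUFFLE;
    }
    final boolean requiresSort =
        clusterByColumns.stream().anyMatch(column -> !PARTITION_BOOST_COLUMN.equals(column));
    // With a limit, sort into a single partition if there is any non-boost column to sort on;
    // otherwise a mix shuffle is enough.
    return requiresSort ? ShuffleKind.SINGLE_PARTITION : ShuffleKind.MIX;
  }

  public static void main(String[] args)
  {
    // No limit: the scan stage applies the destination's final shuffle directly.
    System.out.println(chooseScanShuffle(false, List.of("dim1", PARTITION_BOOST_COLUMN)));
    // Limit with a real sort column: sort into a single partition before the limit stage.
    System.out.println(chooseScanShuffle(true, List.of("dim1", PARTITION_BOOST_COLUMN)));
    // Limit with only the boost column: mix shuffle; the boost column stays in the signature
    // so the limit stage can populate it and shuffle on it afterwards.
    System.out.println(chooseScanShuffle(true, List.of(PARTITION_BOOST_COLUMN)));
  }
}

The limit stage (OffsetLimitFrameProcessorFactory) then carries the finalShuffleSpec built from resultShuffleSpecFactory rather than null, so results written to durable storage or the task report are partitioned and ordered as the destination expects; the new tests below exercise this for INSERT, REPLACE, export, and SELECT paths.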

ControllerImpl.java

@@ -101,9 +101,7 @@ import org.apache.druid.msq.indexing.MSQTuningConfig;
 import org.apache.druid.msq.indexing.WorkerCount;
 import org.apache.druid.msq.indexing.client.ControllerChatHandler;
 import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
-import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
-import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.indexing.error.CanceledFault;
 import org.apache.druid.msq.indexing.error.CannotParseExternalDataFault;
 import org.apache.druid.msq.indexing.error.FaultsExceededChecker;
@@ -1828,9 +1826,9 @@ public class ControllerImpl implements Controller
       );
       return builder.build();
-    } else if (querySpec.getDestination() instanceof TaskReportMSQDestination) {
+    } else if (MSQControllerTask.writeFinalResultsToTaskReport(querySpec)) {
       return queryDef;
-    } else if (querySpec.getDestination() instanceof DurableStorageMSQDestination) {
+    } else if (MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
       // attaching new query results stage if the final stage does sort during shuffle so that results are ordered.
       StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition();
@@ -2933,12 +2931,12 @@ public class ControllerImpl implements Controller
     final InputChannelFactory inputChannelFactory;
-    if (queryKernelConfig.isDurableStorage() || MSQControllerTask.writeResultsToDurableStorage(querySpec)) {
+    if (queryKernelConfig.isDurableStorage() || MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)) {
       inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation(
          queryId(),
          MSQTasks.makeStorageConnector(context.injector()),
          closer,
-         MSQControllerTask.writeResultsToDurableStorage(querySpec)
+         MSQControllerTask.writeFinalStageResultsToDurableStorage(querySpec)
      );
    } else {
      inputChannelFactory = new WorkerInputChannelFactory(netClient, () -> taskIds);

MSQControllerTask.java

@@ -52,6 +52,7 @@ import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination;
 import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination;
 import org.apache.druid.msq.indexing.destination.ExportMSQDestination;
 import org.apache.druid.msq.indexing.destination.MSQDestination;
+import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
 import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.query.QueryContext;
 import org.apache.druid.rpc.ServiceClientFactory;
@@ -305,16 +306,38 @@ public class MSQControllerTask extends AbstractTask implements ClientTaskQuery,
     return querySpec.getDestination().getDestinationResource();
   }

+  /**
+   * Checks whether the task is an ingestion into a Druid datasource.
+   */
   public static boolean isIngestion(final MSQSpec querySpec)
   {
     return querySpec.getDestination() instanceof DataSourceMSQDestination;
   }

+  /**
+   * Checks whether the task is an export into external files.
+   */
   public static boolean isExport(final MSQSpec querySpec)
   {
     return querySpec.getDestination() instanceof ExportMSQDestination;
   }

+  /**
+   * Checks whether the task is an async query which writes frame files containing the final results into durable storage.
+   */
+  public static boolean writeFinalStageResultsToDurableStorage(final MSQSpec querySpec)
+  {
+    return querySpec.getDestination() instanceof DurableStorageMSQDestination;
+  }
+
+  /**
+   * Checks whether the task is an async query which writes the final results into the task report.
+   */
+  public static boolean writeFinalResultsToTaskReport(final MSQSpec querySpec)
+  {
+    return querySpec.getDestination() instanceof TaskReportMSQDestination;
+  }
+
   /**
    * Returns true if the task reads from the same table as the destination. In this case, we would prefer to fail
    * instead of reading any unused segments to ensure that old data is not read.
@@ -330,11 +353,6 @@ public class MSQControllerTask extends AbstractTask implements ClientTaskQuery,
     }
   }

-  public static boolean writeResultsToDurableStorage(final MSQSpec querySpec)
-  {
-    return querySpec.getDestination() instanceof DurableStorageMSQDestination;
-  }
-
   @Override
   public LookupLoadingSpec getLookupLoadingSpec()
   {

GroupByQueryKit.java

@@ -185,13 +185,14 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
     );

     if (doLimitOrOffset) {
+      final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false);
       final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec();
       queryDefBuilder.add(
           StageDefinition.builder(firstStageNumber + 2)
                          .inputs(new StageInputSpec(firstStageNumber + 1))
                          .signature(resultSignature)
                          .maxWorkerCount(1)
-                         .shuffleSpec(null) // no shuffling should be required after a limit processor.
+                         .shuffleSpec(finalShuffleSpec)
                          .processorFactory(
                              new OffsetLimitFrameProcessorFactory(
                                  limitSpec.getOffset(),
@@ -224,12 +225,13 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
     );

     if (doLimitOrOffset) {
       final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec();
+      final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false);
       queryDefBuilder.add(
           StageDefinition.builder(firstStageNumber + 2)
                          .inputs(new StageInputSpec(firstStageNumber + 1))
                          .signature(resultSignature)
                          .maxWorkerCount(1)
-                         .shuffleSpec(null)
+                         .shuffleSpec(finalShuffleSpec)
                          .processorFactory(
                              new OffsetLimitFrameProcessorFactory(
                                  limitSpec.getOffset(),

ScanQueryKit.java

@@ -34,6 +34,7 @@ import org.apache.druid.msq.kernel.StageDefinition;
 import org.apache.druid.msq.querykit.DataSourcePlan;
 import org.apache.druid.msq.querykit.QueryKit;
 import org.apache.druid.msq.querykit.QueryKitUtils;
+import org.apache.druid.msq.querykit.ShuffleSpecFactories;
 import org.apache.druid.msq.querykit.ShuffleSpecFactory;
 import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
 import org.apache.druid.msq.util.MultiStageQueryContext;
@@ -111,18 +112,8 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
     final ScanQuery queryToRun = originalQuery.withDataSource(dataSourcePlan.getNewDataSource());
     final int firstStageNumber = Math.max(minStageNumber, queryDefBuilder.getNextStageNumber());
     final RowSignature scanSignature = getAndValidateSignature(queryToRun, jsonMapper);
-    final ShuffleSpec shuffleSpec;
-    final RowSignature signatureToUse;
     final boolean hasLimitOrOffset = queryToRun.isLimited() || queryToRun.getScanRowsOffset() > 0;

-    // We ignore the resultShuffleSpecFactory in case:
-    // 1. There is no cluster by
-    // 2. There is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
-    if (queryToRun.getOrderBys().isEmpty() && hasLimitOrOffset) {
-      shuffleSpec = MixShuffleSpec.instance();
-      signatureToUse = scanSignature;
-    } else {
     final RowSignature.Builder signatureBuilder = RowSignature.builder().addAll(scanSignature);
     final Granularity segmentGranularity =
         QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext());
@@ -159,21 +150,39 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
       signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);
     }

     final ClusterBy clusterBy =
         QueryKitUtils.clusterByWithSegmentGranularity(new ClusterBy(clusterByColumns, 0), segmentGranularity);
-      shuffleSpec = resultShuffleSpecFactory.build(clusterBy, false);
-      signatureToUse = QueryKitUtils.sortableSignature(
+    final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(clusterBy, false);
+
+    final RowSignature signatureToUse = QueryKitUtils.sortableSignature(
         QueryKitUtils.signatureWithSegmentGranularity(signatureBuilder.build(), segmentGranularity),
         clusterBy.getColumns()
     );

+    ShuffleSpec scanShuffleSpec;
+    if (!hasLimitOrOffset) {
+      // If there is no limit spec, apply the final shuffling here itself. This will ensure partition sizes etc are respected.
+      scanShuffleSpec = finalShuffleSpec;
+    } else {
+      // If there is a limit spec, check if there are any non-boost columns to sort in.
+      boolean requiresSort = clusterByColumns.stream()
+                                             .anyMatch(keyColumn -> !QueryKitUtils.PARTITION_BOOST_COLUMN.equals(keyColumn.columnName()));
+      if (requiresSort) {
+        // If yes, do a sort into a single partition.
+        scanShuffleSpec = ShuffleSpecFactories.singlePartition().build(clusterBy, false);
+      } else {
+        // If the only clusterBy column is the boost column, we just use a mix shuffle to avoid unused shuffling.
+        // Note that we still need the boost column to be present in the row signature, since the limit stage would
+        // need it to be populated to do its own shuffling later.
+        scanShuffleSpec = MixShuffleSpec.instance();
+      }
     }

     queryDefBuilder.add(
         StageDefinition.builder(Math.max(minStageNumber, queryDefBuilder.getNextStageNumber()))
                        .inputs(dataSourcePlan.getInputSpecs())
                        .broadcastInputs(dataSourcePlan.getBroadcastInputs())
-                       .shuffleSpec(shuffleSpec)
+                       .shuffleSpec(scanShuffleSpec)
                        .signature(signatureToUse)
                        .maxWorkerCount(dataSourcePlan.isSingleWorker() ? 1 : maxWorkerCount)
                        .processorFactory(new ScanQueryFrameProcessorFactory(queryToRun))
@@ -185,7 +194,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
                        .inputs(new StageInputSpec(firstStageNumber))
                        .signature(signatureToUse)
                        .maxWorkerCount(1)
-                       .shuffleSpec(null) // no shuffling should be required after a limit processor.
+                       .shuffleSpec(finalShuffleSpec) // Apply the final shuffling after limit spec.
                        .processorFactory(
                            new OffsetLimitFrameProcessorFactory(
                                queryToRun.getScanRowsOffset(),

MSQExportTest.java

@@ -316,6 +316,52 @@ public class MSQExportTest extends MSQTestBase
     }
   }

+  @Test
+  public void testExportWithLimit() throws IOException
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", ColumnType.LONG).build();
+
+    File exportDir = newTempFolder("export");
+
+    Map<String, Object> queryContext = new HashMap<>(DEFAULT_MSQ_CONTEXT);
+    queryContext.put(MultiStageQueryContext.CTX_ROWS_PER_PAGE, 1);
+
+    final String sql = StringUtils.format("insert into extern(local(exportPath=>'%s')) as csv select cnt, dim1 from foo limit 3", exportDir.getAbsolutePath());
+
+    testIngestQuery().setSql(sql)
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(ImmutableSet.of())
+                     .setExpectedResultRows(ImmutableList.of())
+                     .verifyResults();
+
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,"
+        ),
+        readResultsFromFile(new File(exportDir, "query-test-query-worker0-partition0.csv"))
+    );
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,10.1"
+        ),
+        readResultsFromFile(new File(exportDir, "query-test-query-worker0-partition1.csv"))
+    );
+    Assert.assertEquals(
+        ImmutableList.of(
+            "cnt,dim1",
+            "1,2"
+        ),
+        readResultsFromFile(new File(exportDir, "query-test-query-worker0-partition2.csv"))
+    );
+  }
+
   private void verifyManifestFile(File exportDir, List<File> resultFiles) throws IOException
   {
     final File manifestFile = new File(exportDir, ExportMetadataManager.MANIFEST_FILE);

MSQFaultsTest.java

@@ -46,6 +46,7 @@ import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
 import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;
 import org.apache.druid.msq.test.MSQTestBase;
 import org.apache.druid.msq.test.MSQTestTaskActionClient;
+import org.apache.druid.msq.util.MultiStageQueryContext;
 import org.apache.druid.segment.column.ColumnType;
 import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.realtime.appenderator.SegmentIdWithShardSpec;
@@ -291,7 +292,7 @@ public class MSQFaultsTest extends MSQTestBase
   {
     Map<String, Object> context = ImmutableMap.<String, Object>builder()
                                               .putAll(DEFAULT_MSQ_CONTEXT)
-                                              .put("rowsPerSegment", 1)
+                                              .put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 1)
                                               .build();

MSQInsertTest.java

@@ -1455,7 +1455,7 @@ public class MSQInsertTest extends MSQTestBase
                          + "SELECT __time, m1 "
                          + "FROM foo "
                          + "LIMIT 50 "
-                         + "OFFSET 10"
+                         + "OFFSET 10 "
                          + "PARTITIONED BY ALL TIME")
                      .setExpectedValidationErrorMatcher(
                          invalidSqlContains("INSERT and REPLACE queries cannot have an OFFSET")
@@ -1464,6 +1464,44 @@ public class MSQInsertTest extends MSQTestBase
                      .verifyPlanningErrors();
   }

+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testInsertOnFoo1WithLimit(String contextName, Map<String, Object> context)
+  {
+    Map<String, Object> queryContext = ImmutableMap.<String, Object>builder()
+                                                   .putAll(context)
+                                                   .put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 2)
+                                                   .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{946771200000L, "10.1", 1L},
+        new Object[]{978307200000L, "1", 1L},
+        new Object[]{946857600000L, "2", 1L},
+        new Object[]{978480000000L, "abc", 1L}
+    );
+
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("cnt", ColumnType.LONG)
+                                            .build();
+
+    testIngestQuery().setSql(
+                         "insert into foo1 select __time, dim1, cnt from foo where dim1 != '' limit 4 partitioned by all clustered by dim1")
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0), SegmentId.of("foo1", Intervals.ETERNITY, "test", 1)))
+                     .setExpectedResultRows(expectedRows)
+                     .setExpectedMSQSegmentReport(
+                         new MSQSegmentReport(
+                             NumberedShardSpec.class.getSimpleName(),
+                             "Using NumberedShardSpec to generate segments since the query is inserting rows."
+                         )
+                     )
+                     .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testCorrectNumberOfWorkersUsedAutoModeWithoutBytesLimit(String contextName, Map<String, Object> context) throws IOException

MSQReplaceTest.java

@@ -906,6 +906,51 @@ public class MSQReplaceTest extends MSQTestBase
                      .verifyResults();
   }

+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testReplaceOnFoo1WithLimit(String contextName, Map<String, Object> context)
+  {
+    Map<String, Object> queryContext = ImmutableMap.<String, Object>builder()
+                                                   .putAll(context)
+                                                   .put(MultiStageQueryContext.CTX_ROWS_PER_SEGMENT, 2)
+                                                   .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{946684800000L, NullHandling.sqlCompatible() ? "" : null},
+        new Object[]{978307200000L, "1"},
+        new Object[]{946771200000L, "10.1"},
+        new Object[]{946857600000L, "2"}
+    );
+
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("__time", ColumnType.LONG)
+                                            .add("dim1", ColumnType.STRING)
+                                            .build();
+
+    testIngestQuery().setSql(
+                         "REPLACE INTO \"foo1\" OVERWRITE ALL\n"
+                         + "SELECT\n"
+                         + " \"__time\",\n"
+                         + " \"dim1\"\n"
+                         + "FROM foo\n"
+                         + "LIMIT 4\n"
+                         + "PARTITIONED BY ALL\n"
+                         + "CLUSTERED BY dim1")
+                     .setExpectedDataSource("foo1")
+                     .setQueryContext(queryContext)
+                     .setExpectedRowSignature(rowSignature)
+                     .setExpectedShardSpec(DimensionRangeShardSpec.class)
+                     .setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0), SegmentId.of("foo1", Intervals.ETERNITY, "test", 1)))
+                     .setExpectedResultRows(expectedRows)
+                     .setExpectedMSQSegmentReport(
+                         new MSQSegmentReport(
+                             DimensionRangeShardSpec.class.getSimpleName(),
+                             "Using RangeShardSpec to generate segments."
+                         )
+                     )
+                     .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testReplaceTimeChunksLargerThanData(String contextName, Map<String, Object> context)

MSQSelectTest.java

@@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.data.input.impl.CsvInputFormat;
+import org.apache.druid.data.input.impl.InlineInputSource;
 import org.apache.druid.data.input.impl.JsonInputFormat;
 import org.apache.druid.data.input.impl.LocalInputSource;
 import org.apache.druid.data.input.impl.systemfield.SystemFields;
@@ -624,8 +625,17 @@ public class MSQSelectTest extends MSQTestBase
                                             .add("dim1", ColumnType.STRING)
                                             .build();

+    final ImmutableList<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{1L, ""},
+        new Object[]{1L, "10.1"},
+        new Object[]{1L, "2"},
+        new Object[]{1L, "1"},
+        new Object[]{1L, "def"},
+        new Object[]{1L, "abc"}
+    );
+
     testSelectQuery()
-        .setSql("select cnt,dim1 from foo limit 10")
+        .setSql("select cnt, dim1 from foo limit 10")
         .setExpectedMSQSpec(
             MSQSpec.builder()
                    .query(
@@ -646,6 +656,7 @@ public class MSQSelectTest extends MSQTestBase
         )
         .setQueryContext(context)
         .setExpectedRowSignature(resultSignature)
+        .setExpectedResultRows(expectedResults)
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
                 .with().totalFiles(1),
@@ -653,22 +664,31 @@ public class MSQSelectTest extends MSQTestBase
         )
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
-                .with().rows(6).frames(1),
+                .with().rows(6),
             0, 0, "output"
         )
         .setExpectedCountersForStageWorkerChannel(
             CounterSnapshotMatcher
-                .with().rows(6).frames(1),
+                .with().rows(6),
             0, 0, "shuffle"
         )
-        .setExpectedResultRows(ImmutableList.of(
-            new Object[]{1L, ""},
-            new Object[]{1L, "10.1"},
-            new Object[]{1L, "2"},
-            new Object[]{1L, "1"},
-            new Object[]{1L, "def"},
-            new Object[]{1L, "abc"}
-        )).verifyResults();
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(6),
+            1, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(6),
+            1, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(!context.containsKey(MultiStageQueryContext.CTX_ROWS_PER_PAGE) ? new long[] {6} : new long[] {2, 2, 2}),
+            1, 0, "shuffle"
+        )
+        .setExpectedResultRows(expectedResults)
+        .verifyResults();
   }
@@ -1699,6 +1719,166 @@ public class MSQSelectTest extends MSQTestBase
         )).verifyResults();
   }

+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testGroupByWithLimit(String contextName, Map<String, Object> context)
+  {
+    RowSignature expectedResultSignature = RowSignature.builder()
+                                                       .add("dim1", ColumnType.STRING)
+                                                       .add("cnt", ColumnType.LONG)
+                                                       .build();
+
+    GroupByQuery query = GroupByQuery.builder()
+                                     .setDataSource(CalciteTests.DATASOURCE1)
+                                     .setInterval(querySegmentSpec(Filtration.eternity()))
+                                     .setGranularity(Granularities.ALL)
+                                     .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0")))
+                                     .setAggregatorSpecs(
+                                         aggregators(
+                                             new CountAggregatorFactory(
+                                                 "a0"
+                                             )
+                                         )
+                                     )
+                                     .setDimFilter(not(equality("dim1", "", ColumnType.STRING)))
+                                     .setLimit(1)
+                                     .setContext(context)
+                                     .build();
+
+    testSelectQuery()
+        .setSql("SELECT dim1, cnt FROM (SELECT dim1, COUNT(*) AS cnt FROM foo GROUP BY dim1 HAVING dim1 != '' LIMIT 1) LIMIT 20")
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(query)
+                                   .columnMappings(new ColumnMappings(ImmutableList.of(
+                                       new ColumnMapping("d0", "dim1"),
+                                       new ColumnMapping("a0", "cnt")
+                                   )))
+                                   .tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .destination(isDurableStorageDestination(contextName, context)
+                                                ? DurableStorageMSQDestination.INSTANCE
+                                                : TaskReportMSQDestination.INSTANCE)
+                                   .build())
+        .setExpectedRowSignature(expectedResultSignature)
+        .setQueryContext(context)
+        .setExpectedResultRows(ImmutableList.of(
+            new Object[]{"1", 1L}
+        )).verifyResults();
+  }
+
+  @MethodSource("data")
+  @ParameterizedTest(name = "{index}:with context {0}")
+  public void testGroupByWithLimitAndOrdering(String contextName, Map<String, Object> context)
+  {
+    RowSignature rowSignature = RowSignature.builder()
+                                            .add("dim1", ColumnType.STRING)
+                                            .add("count", ColumnType.LONG)
+                                            .build();
+
+    GroupByQuery query = GroupByQuery.builder()
+                                     .setDataSource(
+                                         new ExternalDataSource(
+                                             new InlineInputSource("dim1\nabc\nxyz\ndef\nxyz\nabc\nxyz\nabc\nxyz\ndef\nbbb\naaa"),
+                                             new CsvInputFormat(null, null, null, true, 0),
+                                             RowSignature.builder().add("dim1", ColumnType.STRING).build()
+                                         )
+                                     )
+                                     .setInterval(querySegmentSpec(Filtration.eternity()))
+                                     .setGranularity(Granularities.ALL)
+                                     .addOrderByColumn(new OrderByColumnSpec("a0", OrderByColumnSpec.Direction.DESCENDING, StringComparators.NUMERIC))
+                                     .addOrderByColumn(new OrderByColumnSpec("d0", OrderByColumnSpec.Direction.ASCENDING, StringComparators.LEXICOGRAPHIC))
+                                     .setDimensions(dimensions(new DefaultDimensionSpec("dim1", "d0")))
+                                     .setAggregatorSpecs(
+                                         aggregators(
+                                             new CountAggregatorFactory(
+                                                 "a0"
+                                             )
+                                         )
+                                     )
+                                     .setLimit(4)
+                                     .setContext(context)
+                                     .build();
+
+    List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{"xyz", 4L},
+        new Object[]{"abc", 3L},
+        new Object[]{"def", 2L},
+        new Object[]{"aaa", 1L}
+    );
+
+    testSelectQuery()
+        .setSql("WITH \"ext\" AS (\n"
+                + " SELECT *\n"
+                + " FROM TABLE(\n"
+                + " EXTERN(\n"
+                + " '{\"type\":\"inline\",\"data\":\"dim1\\nabc\\nxyz\\ndef\\nxyz\\nabc\\nxyz\\nabc\\nxyz\\ndef\\nbbb\\naaa\"}',\n"
+                + " '{\"type\":\"csv\",\"findColumnsFromHeader\":true}'\n"
+                + " )\n"
+                + " ) EXTEND (\"dim1\" VARCHAR)\n"
+                + ")\n"
+                + "SELECT\n"
+                + " \"dim1\",\n"
+                + " COUNT(*) AS \"count\"\n"
+                + "FROM \"ext\"\n"
+                + "GROUP BY 1\n"
+                + "ORDER BY 2 DESC, 1\n"
+                + "LIMIT 4\n")
+        .setExpectedMSQSpec(MSQSpec.builder()
+                                   .query(query)
+                                   .columnMappings(new ColumnMappings(ImmutableList.of(
+                                       new ColumnMapping("d0", "dim1"),
+                                       new ColumnMapping("a0", "count")
+                                   )))
+                                   .tuningConfig(MSQTuningConfig.defaultConfig())
+                                   .destination(isDurableStorageDestination(contextName, context)
+                                                ? DurableStorageMSQDestination.INSTANCE
+                                                : TaskReportMSQDestination.INSTANCE)
+                                   .build())
+        .setExpectedRowSignature(rowSignature)
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().totalFiles(1),
+            0, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            0, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "shuffle"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            1, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(5),
+            2, 0, "input0"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(4),
+            2, 0, "output"
+        )
+        .setExpectedCountersForStageWorkerChannel(
+            CounterSnapshotMatcher
+                .with().rows(!context.containsKey(MultiStageQueryContext.CTX_ROWS_PER_PAGE) ? new long[] {4} : new long[] {2, 2}),
+            2, 0, "shuffle"
+        )
+        .setQueryContext(context)
+        .setExpectedResultRows(expectedRows)
+        .verifyResults();
+  }
+
   @MethodSource("data")
   @ParameterizedTest(name = "{index}:with context {0}")
   public void testHavingOnApproximateCountDistinct(String contextName, Map<String, Object> context)

MSQTestBase.java

@@ -1288,6 +1288,10 @@ public class MSQTestBase extends BaseCalciteQueryTest
             .stream()
             .filter(segmentId -> segmentId.getInterval()
                                           .contains((Long) row[0]))
+            .filter(segmentId -> {
+              List<List<Object>> lists = segmentIdVsOutputRowsMap.get(segmentId);
+              return lists.contains(Arrays.asList(row));
+            })
             .collect(Collectors.toList());
     if (diskSegmentList.size() != 1) {
       throw new IllegalStateException("Single key in multiple partitions");