MSQ window function: Take segment granularity into consideration to fix NPE issues with ingestion (#16854)

This PR changes the window-function planning logic to build the last window stage's shuffle spec with resultShuffleSpecFactory, taking segment granularity into account (via a segment-granularity virtual column plus a partition-boost column) instead of shuffling everything to a single partition. This fixes NPE issues during MSQ ingestion.
Author: Akshat Jain, 2024-08-21 10:06:04 +05:30 (committed by GitHub)
parent 2bd31603de
commit 0ce1b6b22f
5 changed files with 132 additions and 34 deletions
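
At a glance, the fix replaces the final window stage's null shuffle spec and raw row signature with granularity-aware ones. A condensed sketch of the new planning path, using the identifiers that appear in the WindowOperatorQueryKit diff below:

    // Condensed from the diff below: the last window stage now derives its
    // shuffle spec and signature from the segment granularity in the query context.
    final Granularity segmentGranularity =
        QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext());
    final ClusterBy finalWindowClusterBy = computeClusterByForFinalWindowStage(segmentGranularity);
    final ShuffleSpec finalWindowStageShuffleSpec = resultShuffleSpecFactory.build(finalWindowClusterBy, false);
    final RowSignature finalWindowStageRowSignature =
        computeSignatureForFinalWindowStage(rowSignature, finalWindowClusterBy, segmentGranularity);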

File: WindowOperatorQueryFrameProcessor.java

@@ -31,6 +31,7 @@ import org.apache.druid.frame.processor.FrameProcessors;
import org.apache.druid.frame.processor.FrameRowTooLargeException;
import org.apache.druid.frame.processor.ReturnOrAwait;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.util.SettableLongVirtualColumn;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.Unit;
@@ -51,6 +52,8 @@ import org.apache.druid.query.rowsandcols.semantic.ColumnSelectorFactoryMaker;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.NullableTypeStrategy;
import org.apache.druid.segment.column.RowSignature;
@@ -85,6 +88,9 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor<Object>
private ResultRow outputRow = null;
private FrameWriter frameWriter = null;
private final VirtualColumns frameWriterVirtualColumns;
private final SettableLongVirtualColumn partitionBoostVirtualColumn;
// List of type strategies to compare the partition columns across rows.
// Type strategies are pushed in the same order as column types in frameReader.signature()
private final NullableTypeStrategy[] typeStrategies;
@@ -119,6 +125,16 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor<Object>
for (int i = 0; i < frameReader.signature().size(); i++) {
typeStrategies[i] = frameReader.signature().getColumnType(i).get().getNullableStrategy();
}
// Get virtual columns to be added to the frame writer.
this.partitionBoostVirtualColumn = new SettableLongVirtualColumn(QueryKitUtils.PARTITION_BOOST_COLUMN);
final List<VirtualColumn> frameWriterVirtualColumns = new ArrayList<>();
final VirtualColumn segmentGranularityVirtualColumn =
QueryKitUtils.makeSegmentGranularityVirtualColumn(jsonMapper, query);
if (segmentGranularityVirtualColumn != null) {
frameWriterVirtualColumns.add(segmentGranularityVirtualColumn);
}
this.frameWriterVirtualColumns = VirtualColumns.create(frameWriterVirtualColumns);
}
@Override
@@ -404,7 +420,9 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor<Object>
if (frameWriter == null) {
final ColumnSelectorFactoryMaker csfm = ColumnSelectorFactoryMaker.fromRAC(rac);
final ColumnSelectorFactory frameWriterColumnSelectorFactory = csfm.make(rowId);
frameWriter = frameWriterFactory.newFrameWriter(frameWriterColumnSelectorFactory);
final ColumnSelectorFactory frameWriterColumnSelectorFactoryWithVirtualColumns =
frameWriterVirtualColumns.wrap(frameWriterColumnSelectorFactory);
frameWriter = frameWriterFactory.newFrameWriter(frameWriterColumnSelectorFactoryWithVirtualColumns);
currentAllocatorCapacity = frameWriterFactory.allocatorCapacity();
}
}
@@ -422,6 +440,7 @@ public class WindowOperatorQueryFrameProcessor implements FrameProcessor<Object>
final boolean didAddToFrame = frameWriter.addSelection();
if (didAddToFrame) {
rowId.incrementAndGet();
partitionBoostVirtualColumn.setValue(partitionBoostVirtualColumn.getValue() + 1);
} else if (frameWriter.getNumRows() == 0) {
throw new FrameRowTooLargeException(currentAllocatorCapacity);
} else {
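
For readers unfamiliar with SettableLongVirtualColumn: it is a mutable virtual column whose selectors read whatever value was last set, so incrementing it once per written row gives every output row a distinct partition-boost key. A minimal standalone sketch of that pattern (not part of this diff; baseSelectorFactory is an assumed placeholder, and the real code uses QueryKitUtils.PARTITION_BOOST_COLUMN rather than a literal name):

    import java.util.Collections;
    import org.apache.druid.frame.util.SettableLongVirtualColumn;
    import org.apache.druid.segment.ColumnSelectorFactory;
    import org.apache.druid.segment.VirtualColumns;

    // A mutable virtual column named "__boost" (hypothetical literal for this sketch).
    final SettableLongVirtualColumn boost = new SettableLongVirtualColumn("__boost");
    final VirtualColumns virtualColumns = VirtualColumns.create(Collections.singletonList(boost));

    // Selectors for "__boost" obtained through the wrapped factory read boost.getValue().
    final ColumnSelectorFactory wrapped = virtualColumns.wrap(baseSelectorFactory);

    // Bump the value after each written row so every row carries a distinct boost key.
    boost.setValue(boost.getValue() + 1);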

File: WindowOperatorQueryKit.java

@@ -24,6 +24,7 @@ import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.input.stage.StageInputSpec;
@@ -105,8 +106,14 @@ public class WindowOperatorQueryKit implements QueryKit<WindowOperatorQuery>
final int firstStageNumber = Math.max(minStageNumber, queryDefBuilder.getNextStageNumber());
final WindowOperatorQuery queryToRun = (WindowOperatorQuery) originalQuery.withDataSource(dataSourcePlan.getNewDataSource());
final int maxRowsMaterialized;
// Get segment granularity from query context, and create ShuffleSpec and RowSignature to be used for the final window stage.
final Granularity segmentGranularity = QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext());
final ClusterBy finalWindowClusterBy = computeClusterByForFinalWindowStage(segmentGranularity);
final ShuffleSpec finalWindowStageShuffleSpec = resultShuffleSpecFactory.build(finalWindowClusterBy, false);
final RowSignature finalWindowStageRowSignature = computeSignatureForFinalWindowStage(rowSignature, finalWindowClusterBy, segmentGranularity);
final int maxRowsMaterialized;
if (originalQuery.context() != null && originalQuery.context().containsKey(MultiStageQueryContext.MAX_ROWS_MATERIALIZED_IN_WINDOW)) {
maxRowsMaterialized = (int) originalQuery.context().get(MultiStageQueryContext.MAX_ROWS_MATERIALIZED_IN_WINDOW);
} else {
@@ -122,13 +129,13 @@ public class WindowOperatorQueryKit implements QueryKit<WindowOperatorQuery>
queryDefBuilder.add(
StageDefinition.builder(firstStageNumber)
.inputs(new StageInputSpec(firstStageNumber - 1))
.signature(rowSignature)
.signature(finalWindowStageRowSignature)
.maxWorkerCount(maxWorkerCount)
.shuffleSpec(null)
.shuffleSpec(finalWindowStageShuffleSpec)
.processorFactory(new WindowOperatorQueryFrameProcessorFactory(
queryToRun,
queryToRun.getOperators(),
rowSignature,
finalWindowStageRowSignature,
maxRowsMaterialized,
Collections.emptyList()
))
@@ -178,23 +185,22 @@ public class WindowOperatorQueryKit implements QueryKit<WindowOperatorQuery>
}
}
// find the shuffle spec of the next stage
// if it is the last stage set the next shuffle spec to single partition
if (i + 1 == operatorList.size()) {
nextShuffleSpec = MixShuffleSpec.instance();
} else {
nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(i + 1), maxWorkerCount);
}
final RowSignature intermediateSignature = bob.build();
final RowSignature stageRowSignature;
if (nextShuffleSpec == null) {
stageRowSignature = intermediateSignature;
if (i + 1 == operatorList.size()) {
stageRowSignature = finalWindowStageRowSignature;
nextShuffleSpec = finalWindowStageShuffleSpec;
} else {
stageRowSignature = QueryKitUtils.sortableSignature(
intermediateSignature,
nextShuffleSpec.clusterBy().getColumns()
);
nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(i + 1), maxWorkerCount);
if (nextShuffleSpec == null) {
stageRowSignature = intermediateSignature;
} else {
stageRowSignature = QueryKitUtils.sortableSignature(
intermediateSignature,
nextShuffleSpec.clusterBy().getColumns()
);
}
}
log.info("Using row signature [%s] for window stage.", stageRowSignature);
@@ -346,4 +352,29 @@ public class WindowOperatorQueryKit implements QueryKit<WindowOperatorQuery>
}
return queryDefBuilder;
}
/**
* Computes the ClusterBy for the final window stage. We don't have to take the CLUSTERED BY columns into account,
* as they are handled as {@link org.apache.druid.query.scan.ScanQuery#orderBys}.
*/
private static ClusterBy computeClusterByForFinalWindowStage(Granularity segmentGranularity)
{
final List<KeyColumn> clusterByColumns = Collections.singletonList(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
return QueryKitUtils.clusterByWithSegmentGranularity(new ClusterBy(clusterByColumns, 0), segmentGranularity);
}
/**
* Computes the signature for the final window stage. The finalWindowClusterBy will always have the
* partition boost column as computed in {@link #computeClusterByForFinalWindowStage(Granularity)}.
*/
private static RowSignature computeSignatureForFinalWindowStage(RowSignature rowSignature, ClusterBy finalWindowClusterBy, Granularity segmentGranularity)
{
final RowSignature.Builder finalWindowStageRowSignatureBuilder = RowSignature.builder()
.addAll(rowSignature)
.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);
return QueryKitUtils.sortableSignature(
QueryKitUtils.signatureWithSegmentGranularity(finalWindowStageRowSignatureBuilder.build(), segmentGranularity),
finalWindowClusterBy.getColumns()
);
}
}
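
For a query run with PARTITIONED BY DAY, these helpers produce a ClusterBy of [segment-granularity column, partition-boost column], with rows bucketed by the granularity column, and a signature extended with both columns. A sketch of the expected shape, assuming the standard QueryKitUtils column names __bucket and __boost and a bucketByCount of 1 (assumptions of this sketch, not spelled out in the diff):

    import com.google.common.collect.ImmutableList;
    import org.apache.druid.frame.key.ClusterBy;
    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;

    // Hypothetical equivalent of computeClusterByForFinalWindowStage(Granularities.DAY),
    // assuming "__bucket" / "__boost" are the QueryKitUtils granularity and boost columns.
    final ClusterBy expectedFinalClusterBy = new ClusterBy(
        ImmutableList.of(
            new KeyColumn("__bucket", KeyOrder.ASCENDING), // prepended by clusterByWithSegmentGranularity
            new KeyColumn("__boost", KeyOrder.ASCENDING)   // added by computeClusterByForFinalWindowStage
        ),
        1 // bucketByCount: the leading key buckets rows into time chunks
    );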

File: MSQWindowTest.java

@@ -1272,20 +1272,20 @@ public class MSQWindowTest extends MSQTestBase
.setExpectedResultRows(
NullHandling.replaceWithDefault() ?
ImmutableList.of(
new Object[]{"", 11.0},
new Object[]{"a", 5.0},
new Object[]{"", 11.0},
new Object[]{"", 11.0},
new Object[]{"a", 5.0},
new Object[]{"a", 5.0},
new Object[]{"abc", 5.0}
new Object[]{"abc", 5.0},
new Object[]{"", 11.0}
) :
ImmutableList.of(
new Object[]{null, 8.0},
new Object[]{"a", 5.0},
new Object[]{null, 8.0},
new Object[]{"", 3.0},
new Object[]{"a", 5.0},
new Object[]{"a", 5.0},
new Object[]{"abc", 5.0}
new Object[]{"abc", 5.0},
new Object[]{null, 8.0}
))
.setQueryContext(context)
.verifyResults();
@@ -1935,11 +1935,11 @@ public class MSQWindowTest extends MSQTestBase
.build())
.setExpectedRowSignature(rowSignature)
.setExpectedResultRows(ImmutableList.of(
new Object[]{"Al Ain", 8L, 6334L},
new Object[]{"Dubai", 3L, 6334L},
new Object[]{"Dubai", 6323L, 6334L},
new Object[]{"Tirana", 26L, 26L},
new Object[]{"Benguela", 0L, 0L}
new Object[]{"Auburn", 0L, 1698L},
new Object[]{"Mexico City", 0L, 6136L},
new Object[]{"Seoul", 663L, 5582L},
new Object[]{"Tokyo", 0L, 12615L},
new Object[]{"Santiago", 161L, 401L}
))
.setQueryContext(context)
.verifyResults();
@@ -2266,13 +2266,13 @@ public class MSQWindowTest extends MSQTestBase
2, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher.with().rows(13).bytes(1158).frames(1),
CounterSnapshotMatcher.with().rows(13).bytes(1379).frames(1),
2, 0, "output"
)
// Stage 3, Worker 0
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher.with().rows(13).bytes(1158).frames(1),
CounterSnapshotMatcher.with().rows(13).bytes(1327).frames(1),
3, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
@@ -2285,4 +2285,42 @@
)
.verifyResults();
}
@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testReplaceWithPartitionedByDayOnWikipedia(String contextName, Map<String, Object> context)
{
RowSignature rowSignature = RowSignature.builder()
.add("__time", ColumnType.LONG)
.add("cityName", ColumnType.STRING)
.add("added", ColumnType.LONG)
.add("cc", ColumnType.LONG)
.build();
testIngestQuery().setSql(" REPLACE INTO foo1 OVERWRITE ALL\n"
+ "select __time, cityName, added, SUM(added) OVER () cc from wikipedia \n"
+ "where cityName IN ('Ahmedabad', 'Albuquerque')\n"
+ "GROUP BY __time, cityName, added\n"
+ "PARTITIONED BY DAY")
.setExpectedDataSource("foo1")
.setExpectedRowSignature(rowSignature)
.setQueryContext(context)
.setExpectedDestinationIntervals(Intervals.ONLY_ETERNITY)
.setExpectedResultRows(
ImmutableList.of(
new Object[]{1442055085114L, "Ahmedabad", 0L, 140L},
new Object[]{1442061929238L, "Ahmedabad", 0L, 140L},
new Object[]{1442069353218L, "Albuquerque", 129L, 140L},
new Object[]{1442069411614L, "Albuquerque", 9L, 140L},
new Object[]{1442097803851L, "Albuquerque", 2L, 140L}
)
)
.setExpectedSegments(ImmutableSet.of(SegmentId.of(
"foo1",
Intervals.of("2015-09-12/2015-09-13"),
"test",
0
)))
.verifyResults();
}
}

File: Drill window query test (LAG)

@@ -1 +1,6 @@
SELECT col2 , col8 , LAG(col8 ) OVER ( PARTITION BY col2 ORDER BY col2 , col8 nulls FIRST ) LAG_col8 FROM "fewRowsAllData.parquet" FETCH FIRST 15 ROWS ONLY
SELECT
col2, col8,
LAG(col8) OVER (PARTITION BY col2 ORDER BY col2, col8 nulls FIRST) LAG_col8
FROM "fewRowsAllData.parquet"
ORDER BY col2, col8
FETCH FIRST 15 ROWS ONLY

File: Drill window query test (LEAD)

@@ -1 +1,6 @@
SELECT col2 , col8 , LEAD(col8 ) OVER ( PARTITION BY col2 ORDER BY col8 nulls FIRST ) LEAD_col8 FROM "fewRowsAllData.parquet" limit 10
SELECT
col2, col8,
LEAD(col8) OVER (PARTITION BY col2 ORDER BY col8 nulls FIRST) LEAD_col8
FROM "fewRowsAllData.parquet"
ORDER BY col2, col8
limit 10