Sort-merge join and hash shuffles for MSQ. (#13506)

* Sort-merge join and hash shuffles for MSQ.

The main changes are in the processing, multi-stage-query, and sql modules.

processing module:

1) Rename SortColumn to KeyColumn, replace boolean descending with KeyOrder.
   This makes it nicer to model hash keys, which use KeyOrder.NONE.
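As an illustration of the new model (a sketch only; the KeyColumn(columnName, order)
constructor is assumed from its usages elsewhere in this change):

```java
import com.google.common.collect.ImmutableList;
import java.util.List;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;

class KeyColumnSketch
{
  // Sort keys carry an ordering; hash keys use KeyOrder.NONE, since ordering is irrelevant to hashing.
  static final List<KeyColumn> SORT_KEY = ImmutableList.of(
      new KeyColumn("channel", KeyOrder.ASCENDING),
      new KeyColumn("cityName", KeyOrder.DESCENDING)
  );

  static final List<KeyColumn> HASH_KEY = ImmutableList.of(
      new KeyColumn("userId", KeyOrder.NONE)
  );
}
```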

2) Add nullability checkers to the FieldReader interface, and an
   "isPartiallyNullKey" method to FrameComparisonWidget. The join
   processor uses this to detect null keys, which must not match in
   SQL equality join conditions.

3) Add WritableFrameChannel.isClosed and OutputChannel.isReadableChannelReady
   so callers can tell which OutputChannels are ready for reading and which
   aren't.
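A rough sketch of how a caller might use the OutputChannel side of this
(getReadableChannel() is assumed from the existing interface; only the readiness check is new here):

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.processor.OutputChannel;

class OutputChannelSketch
{
  // Collect only the channels whose readable side can be opened right now; skip the rest.
  static List<ReadableFrameChannel> readyChannels(final List<OutputChannel> outputChannels)
  {
    final List<ReadableFrameChannel> ready = new ArrayList<>();
    for (final OutputChannel channel : outputChannels) {
      if (channel.isReadableChannelReady()) {
        ready.add(channel.getReadableChannel());
      }
    }
    return ready;
  }
}
```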

4) Specialize FrameProcessors.makeCursor to return FrameCursor, a random-access
   implementation. The join processor uses this to rewind when it needs to
   replay a set of rows with a particular key.
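To make the rewind requirement concrete, here is a stand-alone sketch of the replay step over
two key-sorted lists of {key, value} rows. The real processor does the same thing over frames,
using FrameCursor's random access rather than list indices:

```java
import java.util.List;

class SortMergeSketch
{
  // Each row is {key, value}; both inputs must be sorted by key.
  static void join(final List<int[]> left, final List<int[]> right)
  {
    int l = 0;
    int r = 0;
    while (l < left.size() && r < right.size()) {
      final int cmp = Integer.compare(left.get(l)[0], right.get(r)[0]);
      if (cmp < 0) {
        l++;
      } else if (cmp > 0) {
        r++;
      } else {
        // Matching keys: emit the cross product of the two runs that share this key.
        final int key = left.get(l)[0];
        final int rightRunStart = r;
        while (l < left.size() && left.get(l)[0] == key) {
          r = rightRunStart; // the "rewind": replay the right-hand run for each left row with this key
          while (r < right.size() && right.get(r)[0] == key) {
            System.out.println(left.get(l)[1] + " joins " + right.get(r)[1]);
            r++;
          }
          l++;
        }
      }
    }
  }
}
```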

5) Add MemoryAllocatorFactory, which FrameWriterFactory now embeds
   instead of a particular MemoryAllocator. This allows FrameWriterFactory
   to be shared in more scenarios.
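For example (a sketch; HeapMemoryAllocator.unlimited() is an assumption, and the
makeFrameWriterFactory signature follows the call sites visible elsewhere in this change):

```java
import java.util.Collections;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.HeapMemoryAllocator;
import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.segment.column.RowSignature;

class FrameWriterFactorySketch
{
  // The factory receives a MemoryAllocatorFactory rather than a single MemoryAllocator,
  // so the same FrameWriterFactory can serve callers that each need their own allocator.
  static FrameWriterFactory makeRowBasedFactory(final RowSignature signature)
  {
    return FrameWriters.makeFrameWriterFactory(
        FrameType.ROW_BASED,
        new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
        signature,
        Collections.emptyList() // no sort columns
    );
  }
}
```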

multi-stage-query module:

1) ShuffleSpec: Add hash-based shuffles. New enum ShuffleKind helps callers
   figure out what kind of shuffle is happening. The change from SortColumn
   to KeyColumn allows ClusterBy to be used for both hash-based and sort-based
   shuffling.
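A sketch of the kind-based branching this enables (the ShuffleSpec method names are taken from
their usages in this change; the string output is purely for illustration):

```java
import org.apache.druid.msq.kernel.ShuffleSpec;

class ShuffleKindSketch
{
  // Summarize what a worker would set up for a stage's shuffle.
  static String describe(final ShuffleSpec shuffleSpec)
  {
    if (shuffleSpec.kind().isHash()) {
      return "hash shuffle into " + shuffleSpec.partitionCount() + " partitions on " + shuffleSpec.clusterBy();
    } else {
      return "sort-based shuffle ordered by " + shuffleSpec.clusterBy();
    }
  }
}
```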

2) WorkerImpl: Add ability to handle hash-based shuffles. Refactor the logic
   to be more readable by moving the work-order-running code to the inner
   class RunWorkOrder, and the shuffle-pipeline-building code to the inner
   class ShufflePipelineBuilder.

3) Add SortMergeJoinFrameProcessor and factory.

4) WorkerMemoryParameters: Adjust logic to reserve space for output frames
   for hash partitioning. (We need one frame per partition.)
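Back-of-the-envelope illustration (the ~1 MB standard frame size is an assumption for the
example; see STANDARD_FRAME_SIZE in WorkerMemoryParameters):

```java
class HashShuffleMemorySketch
{
  // One output frame stays buffered per hash partition, so the reservation grows linearly
  // with the partition count of the stage's hash shuffle.
  static long reservedBytes(final int numHashOutputPartitions)
  {
    final long standardFrameSize = 1_000_000L; // assumed frame size, for illustration only
    return standardFrameSize * numHashOutputPartitions; // e.g. 100 partitions -> ~100 MB per processing bundle
  }
}
```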

sql module:

1) Add sqlJoinAlgorithm context parameter; can be "broadcast" or
   "sortMerge". With native, it must always be "broadcast", or it's a
   validation error. MSQ supports both. Default is "broadcast" in
   both engines.
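For instance, a caller could opt into the new algorithm through the SQL query context (sketch;
the key name matches the `sqlJoinAlgorithm` parameter documented in the reference changes below):

```java
import com.google.common.collect.ImmutableMap;
import java.util.Map;

class JoinAlgorithmContextSketch
{
  // "sqlJoinAlgorithm" is the new context key; "sortMerge" selects the sort-merge join in MSQ.
  static final Map<String, Object> QUERY_CONTEXT =
      ImmutableMap.<String, Object>of("sqlJoinAlgorithm", "sortMerge");
}
```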

2) Validate that MSQs do not use broadcast join with RIGHT or FULL join,
   as results are not correct for broadcast join with those types. Allow
   this in native for two reasons: legacy (the docs caution against it,
   but it's always been allowed), and the fact that it actually *does*
   generate correct results in native when the join is processed on the
   Broker. It is much less likely that MSQ will plan in such a way that
   generates correct results.

3) Remove subquery penalty in DruidJoinQueryRel when using sort-merge
   join, because subqueries are always required, so there's no reason
   to penalize them.

4) Move previously-disabled join reordering and manipulation rules to
   FANCY_JOIN_RULES, and enable them when using sort-merge join. Helps
   get to better plans where projections and filters are pushed down.

* Work around compiler problem.

* Updates from static analysis.

* Fix @param tag.

* Fix declared exception.

* Fix spelling.

* Minor adjustments.

* wip

* Merge fixups

* fixes

* Fix CalciteSelectQueryMSQTest

* Empty keys are sortable.

* Address comments from code review. Rename mux -> mix.

* Restore inspection config.

* Restore original doc.

* Reorder imports.

* Adjustments

* Fix.

* Fix imports.

* Adjustments from review.

* Update header.

* Adjust docs.
Gian Merlino 2023-03-08 14:19:39 -08:00 committed by GitHub
parent f0fb094cc7
commit 82f7a56475
153 changed files with 6853 additions and 1648 deletions


@ -592,6 +592,7 @@ The following table lists the context parameters for the MSQ task engine:
| `maxNumTasks` | SELECT, INSERT, REPLACE<br /><br />The maximum total number of tasks to launch, including the controller task. The lowest possible value for this setting is 2: one controller and one worker. All tasks must be able to launch simultaneously. If they cannot, the query returns a `TaskStartTimeout` error code after approximately 10 minutes.<br /><br />May also be provided as `numTasks`. If both are present, `maxNumTasks` takes priority.| 2 |
| `taskAssignment` | SELECT, INSERT, REPLACE<br /><br />Determines how many tasks to use. Possible values include: <ul><li>`max`: Uses as many tasks as possible, up to `maxNumTasks`.</li><li>`auto`: When file sizes can be determined through directory listing (for example: local files, S3, GCS, HDFS) uses as few tasks as possible without exceeding 10 GiB or 10,000 files per task, unless exceeding these limits is necessary to stay within `maxNumTasks`. When file sizes cannot be determined through directory listing (for example: http), behaves the same as `max`.</li></ul> | `max` |
| `finalizeAggregations` | SELECT, INSERT, REPLACE<br /><br />Determines the type of aggregation to return. If true, Druid finalizes the results of complex aggregations that directly appear in query results. If false, Druid returns the aggregation's intermediate type rather than finalized type. This parameter is useful during ingestion, where it enables storing sketches directly in Druid tables. For more information about aggregations, see [SQL aggregation functions](../querying/sql-aggregations.md). | true |
| `sqlJoinAlgorithm` | SELECT, INSERT, REPLACE<br /><br />Algorithm to use for JOIN. Use `broadcast` (the default) for broadcast hash join or `sortMerge` for sort-merge join. Affects all JOIN operations in the query. See [Joins](#joins) for more details. | `broadcast` |
| `rowsInMemory` | INSERT or REPLACE<br /><br />Maximum number of rows to store in memory at once before flushing to disk during the segment generation process. Ignored for non-INSERT queries. In most cases, use the default value. You may need to override the default if you run into one of the [known issues](./known-issues.md) around memory usage. | 100,000 |
| `segmentSortOrder` | INSERT or REPLACE<br /><br />Normally, Druid sorts rows in individual segments using `__time` first, followed by the [CLUSTERED BY](#clustered-by) clause. When you set `segmentSortOrder`, Druid sorts rows in segments using this column list first, followed by the CLUSTERED BY order.<br /><br />You provide the column list as comma-separated values or as a JSON array in string form. If your query includes `__time`, then this list must begin with `__time`. For example, consider an INSERT query that uses `CLUSTERED BY country` and has `segmentSortOrder` set to `__time,city`. Within each time chunk, Druid assigns rows to segments based on `country`, and then within each of those segments, Druid sorts those rows by `__time` first, then `city`, then `country`. | empty list |
| `maxParseExceptions`| SELECT, INSERT, REPLACE<br /><br />Maximum number of parse exceptions that are ignored while executing the query before it stops with `TooManyWarningsFault`. To ignore all the parse exceptions, set the value to -1.| 0 |
@ -604,6 +605,92 @@ The following table lists the context parameters for the MSQ task engine:
| `intermediateSuperSorterStorageMaxLocalBytes` | SELECT, INSERT, REPLACE<br /><br /> Whether to enable a byte limit on local storage for sorting's intermediate data. If that limit is crossed, the task fails with `ResourceLimitExceededException`.| `9223372036854775807` |
| `maxInputBytesPerWorker` | Should be used in conjunction with taskAssignment `auto` mode. When dividing the input of a stage among the workers, this parameter determines the maximum size in bytes that are given to a single worker before the next worker is chosen. This parameter is only used as a guideline during input slicing, and does not guarantee that the input cannot be larger. For example, if we have 3 files of 3, 7, and 12 GB each, then we would end up using 2 workers: worker 1 -> 3, 7 and worker 2 -> 12. This value is used for all stages in a query. | `10737418240` |
## Joins
Joins in multi-stage queries use one of two algorithms, based on the [context parameter](#context-parameters)
`sqlJoinAlgorithm`. This context parameter applies to the entire SQL statement, so it is not possible to mix different
join algorithms in the same query.
### Broadcast
Set `sqlJoinAlgorithm` to `broadcast`.
The default join algorithm for multi-stage queries is a broadcast hash join, which is similar to how
[joins are executed with native queries](../querying/query-execution.md#join). First, any adjacent joins are flattened
into a structure with a "base" input (the bottom-leftmost one) and other leaf inputs (the rest). Next, any subqueries
that are inputs the join (either base or other leafs) are planned into independent stages. Then, the non-base leaf
inputs are all connected as broadcast inputs to the "base" stage.
Together, all of these non-base leaf inputs must not exceed the [limit on broadcast table footprint](#limits). There
is no limit on the size of the base (leftmost) input.
Only LEFT JOIN, INNER JOIN, and CROSS JOIN are supported with `broadcast`.
Join conditions, if present, must be equalities. It is not necessary to include a join condition; for example,
`CROSS JOIN` and comma join do not require join conditions.
As an example, the following statement has a single join chain where `orders` is the base input, and `products` and
`customers` are non-base leaf inputs. The query will first read `products` and `customers`, then broadcast both to
the stage that reads `orders`. That stage loads the broadcast inputs (`products` and `customers`) in memory, and walks
through `orders` row by row. The results are then aggregated and written to the table `orders_enriched`. The broadcast
inputs (`products` and `customers`) must fall under the limit on broadcast table footprint, but the base `orders` input
can be unlimited in size.
```
REPLACE INTO orders_enriched
OVERWRITE ALL
SELECT
orders.__time,
products.name AS product_name,
customers.name AS customer_name,
SUM(orders.amount) AS amount
FROM orders
LEFT JOIN products ON orders.product_id = products.id
LEFT JOIN customers ON orders.customer_id = customers.id
GROUP BY 1, 2
PARTITIONED BY HOUR
CLUSTERED BY product_name
```
### Sort-merge
Set `sqlJoinAlgorithm` to `sortMerge`.
Multi-stage queries can use a sort-merge join algorithm. With this algorithm, each pairwise join is planned into its own
stage with two inputs. The two inputs are hash-partitioned on the join key and sorted within each partition. This
approach is generally less performant, but more scalable, than `broadcast`. There are various scenarios where broadcast
join would return a [`BroadcastTablesTooLarge`](#errors) error, but a sort-merge join would succeed.
There is no limit on the overall size of either input, so sort-merge is a good choice for performing a join of two large
inputs, or for performing a self-join of a large input with itself.
There is a limit on the amount of data associated with each individual key. If _both_ sides of the join exceed this
limit, the query returns a [`TooManyRowsWithSameKey`](#errors) error. If only one side exceeds the limit, the query
does not return this error.
Join conditions, if present, must be equalities. It is not necessary to include a join condition; for example,
`CROSS JOIN` and comma join do not require join conditions.
All join types are supported with `sortMerge`: LEFT, RIGHT, INNER, FULL, and CROSS.
As an example, the following statement runs using a single sort-merge join stage that receives `eventstream`
(partitioned on `user_id`) and `users` (partitioned on `id`) as inputs. There is no limit on the size of either input.
```
REPLACE INTO eventstream_enriched
OVERWRITE ALL
SELECT
eventstream.__time,
eventstream.user_id,
eventstream.event_type,
eventstream.event_details,
users.signup_date AS user_signup_date
FROM eventstream
LEFT JOIN users ON eventstream.user_id = users.id
PARTITIONED BY HOUR
CLUSTERED BY user_id
```
## Sketch Merging Mode
This section details the advantages and performance of various Cluster By Statistics Merge Modes.
@ -656,6 +743,7 @@ The following table lists query limits:
| Number of cluster by columns that can appear in a stage | 1,500 | [`TooManyClusteredByColumns`](#error_TooManyClusteredByColumns) |
| Number of workers for any one stage. | Hard limit is 1,000. Memory-dependent soft limit may be lower. | [`TooManyWorkers`](#error_TooManyWorkers) |
| Maximum memory occupied by broadcasted tables. | 30% of each [processor memory bundle](concepts.md#memory-usage). | [`BroadcastTablesTooLarge`](#error_BroadcastTablesTooLarge) |
| Maximum memory occupied by buffered data during sort-merge join. Only relevant when `sqlJoinAlgorithm` is `sortMerge`. | 10 MB | `TooManyRowsWithSameKey` |
| Maximum relaunch attempts per worker. Initial run is not a relaunch. The worker will be spawned 1 + `workerRelaunchLimit` times before the job fails. | 2 | `TooManyAttemptsForWorker` |
| Maximum relaunch attempts for a job across all workers. | 100 | `TooManyAttemptsForJob` |
<a name="errors"></a>
@ -687,6 +775,7 @@ The following table describes error codes you may encounter in the `multiStageQu
| <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br /><br />`minNumWorker`: The minimum number of workers required for a successful run. | | <a name="error_TooManyInputFiles">`TooManyInputFiles`</a> | Exceeded the maximum number of input files or segments per worker (10,000 files or segments).<br /><br />If you encounter this limit, consider adding more workers, or breaking up your query into smaller queries that process fewer files or segments per query. | `numInputFiles`: The total number of input files/segments for the stage.<br /><br />`maxInputFiles`: The maximum number of input files/segments per worker per stage.<br /><br />`minNumWorker`: The minimum number of workers required for a successful run. |
| <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded | | <a name="error_TooManyPartitions">`TooManyPartitions`</a> | Exceeded the maximum number of partitions for a stage (25,000 partitions).<br /><br />This can occur with INSERT or REPLACE statements that generate large numbers of segments, since each segment is associated with a partition. If you encounter this limit, consider breaking up your INSERT or REPLACE statement into smaller statements that process less data per statement. | `maxPartitions`: The limit on partitions which was exceeded |
| <a name="error_TooManyClusteredByColumns">`TooManyClusteredByColumns`</a> | Exceeded the maximum number of clustering columns for a stage (1,500 columns).<br /><br />This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit<br /><br /> | | <a name="error_TooManyClusteredByColumns">`TooManyClusteredByColumns`</a> | Exceeded the maximum number of clustering columns for a stage (1,500 columns).<br /><br />This can occur with `CLUSTERED BY`, `ORDER BY`, or `GROUP BY` with a large number of columns. | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded.`stage`: The stage number exceeding the limit<br /><br /> |
| <a name="error_TooManyRowsWithSameKey">`TooManyRowsWithSameKey`</a> | The number of rows for a given key exceeded the maximum number of buffered bytes on both sides of a join. See the [Limits](#limits) table for the specific limit. Only occurs when `sqlJoinAlgorithm` is `sortMerge`. | `key`: The key that had a large number of rows.<br /><br />`numBytes`: Number of bytes buffered, which may include other keys.<br /><br />`maxBytes`: Maximum number of bytes buffered. |
| <a name="error_TooManyColumns">`TooManyColumns`</a> | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded. | | <a name="error_TooManyColumns">`TooManyColumns`</a> | Exceeded the maximum number of columns for a stage (2,000 columns). | `numColumns`: The number of columns requested.<br /><br />`maxColumns`: The limit on columns which was exceeded. |
| <a name="error_TooManyWarnings">`TooManyWarnings`</a> | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit. <br /><br />`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. | | <a name="error_TooManyWarnings">`TooManyWarnings`</a> | Exceeded the maximum allowed number of warnings of a particular type. | `rootErrorCode`: The error code corresponding to the exception that exceeded the required limit. <br /><br />`maxWarnings`: Maximum number of warnings that are allowed for the corresponding `rootErrorCode`. |
| <a name="error_TooManyWorkers">`TooManyWorkers`</a> | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously. <br /><br />`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1,000 workers), then you can increase the limit by adding more memory to each task. | | <a name="error_TooManyWorkers">`TooManyWorkers`</a> | Exceeded the maximum number of simultaneously-running workers. See the [Limits](#limits) table for more details. | `workers`: The number of simultaneously running workers that exceeded a hard or soft limit. This may be larger than the number of workers in any one stage if multiple stages are running simultaneously. <br /><br />`maxWorkers`: The hard or soft limit on workers that was exceeded. If this is lower than the hard limit (1,000 workers), then you can increase the limit by adding more memory to each task. |


@ -289,10 +289,10 @@ GROUP BY
Join datasources allow you to do a SQL-style join of two datasources. Stacking joins on top of each other allows
you to join arbitrarily many datasources.
In Druid {{DRUIDVERSION}}, joins are implemented with a broadcast hash-join algorithm. This means that all datasources
other than the leftmost "base" datasource must fit in memory. It also means that the join condition must be an equality. This
feature is intended mainly to allow joining regular Druid tables with [lookup](#lookup), [inline](#inline), and
[query](#query) datasources.
In Druid {{DRUIDVERSION}}, joins in native queries are implemented with a broadcast hash-join algorithm. This means
that all datasources other than the leftmost "base" datasource must fit in memory. It also means that the join condition
must be an equality. This feature is intended mainly to allow joining regular Druid tables with [lookup](#lookup),
[inline](#inline), and [query](#query) datasources.
Refer to the [Query execution](query-execution.md#join) page for more details on how queries are executed when you
use join datasources.
@ -362,13 +362,11 @@ Also, as a result of this, comma joins should be avoided.
Joins are an area of active development in Druid. The following features are missing today but may appear in
future versions:
- Reordering of predicates and filters (pushing up and/or pushing down) to get the most performant plan.
- Reordering of join operations to get the most performant plan.
- Preloaded dimension tables that are wider than lookups (i.e. supporting more than a single key and single value).
- RIGHT OUTER and FULL OUTER joins. Currently, they are partially implemented. Queries will run but results will not
always be correct.
- RIGHT OUTER and FULL OUTER joins in the native query engine. Currently, they are partially implemented. Queries run
but results are not always correct.
- Performance-related optimizations as mentioned in the [previous section](#join-performance).
- Join algorithms other than broadcast hash-joins.
- Join condition on a column compared to a constant value.
- Join conditions on a column containing a multi-value dimension.
### `unnest`


@ -26,7 +26,9 @@ Apache Druid has two features related to joining of data:
1. [Join](datasource.md#join) operators. These are available using a [join datasource](datasource.md#join) in native
queries, or using the [JOIN operator](sql.md) in Druid SQL. Refer to the
[join datasource](datasource.md#join) documentation for information about how joins work in Druid.
[join datasource](datasource.md#join) documentation for information about how joins work in Druid native queries,
or the [multi-stage query join documentation](../multi-stage-query/reference.md#joins) for information about how joins
work in multi-stage query tasks.
2. [Query-time lookups](lookups.md), simple key-to-value mappings. These are preloaded on all servers that are involved
in queries and can be accessed with or without an explicit join operator. Refer to the [lookups](lookups.md)
documentation for more details.


@ -47,9 +47,10 @@ import org.apache.druid.frame.channel.FrameChannelSequence;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.RowKeyReader;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.frame.processor.FrameProcessors;
import org.apache.druid.frame.util.DurableStorageUtils;
@ -132,12 +133,12 @@ import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.input.stage.StageInputSpecSlicer;
import org.apache.druid.msq.input.table.TableInputSpec;
import org.apache.druid.msq.input.table.TableInputSpecSlicer;
import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.kernel.TargetSizeShuffleSpec;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.kernel.controller.ControllerStagePhase;
@ -698,8 +699,8 @@ public class ControllerImpl implements Controller
final StageDefinition stageDef = queryKernel.getStageDefinition(stageId);
final ObjectMapper mapper = MSQTasks.decorateObjectMapperForKeyCollectorSnapshot(
context.jsonMapper(),
stageDef.getShuffleSpec().get().getClusterBy(),
stageDef.getShuffleSpec().get().doesAggregateByClusterKey()
stageDef.getShuffleSpec().clusterBy(),
stageDef.getShuffleSpec().doesAggregate()
);
final PartialKeyStatisticsInformation partialKeyStatisticsInformation;
@ -1502,7 +1503,7 @@ public class ControllerImpl implements Controller
if (MSQControllerTask.isIngestion(querySpec)) {
shuffleSpecFactory = (clusterBy, aggregate) ->
new TargetSizeShuffleSpec(
new GlobalSortTargetSizeShuffleSpec(
clusterBy,
tuningConfig.getRowsPerSegment(),
aggregate
@ -1728,7 +1729,7 @@ public class ControllerImpl implements Controller
final ColumnMappings columnMappings
)
{
final List<SortColumn> clusterByColumns = clusterBy.getColumns();
final List<KeyColumn> clusterByColumns = clusterBy.getColumns();
final List<String> shardColumns = new ArrayList<>();
final boolean boosted = isClusterByBoosted(clusterBy);
final int numShardColumns = clusterByColumns.size() - clusterBy.getBucketByCount() - (boosted ? 1 : 0);
@ -1738,11 +1739,11 @@ public class ControllerImpl implements Controller
}
for (int i = clusterBy.getBucketByCount(); i < clusterBy.getBucketByCount() + numShardColumns; i++) {
final SortColumn column = clusterByColumns.get(i);
final KeyColumn column = clusterByColumns.get(i);
final List<String> outputColumns = columnMappings.getOutputColumnsForQueryColumn(column.columnName());
// DimensionRangeShardSpec only handles ascending order.
if (column.descending()) {
if (column.order() != KeyOrder.ASCENDING) {
return Collections.emptyList();
}
@ -1824,8 +1825,8 @@ public class ControllerImpl implements Controller
// Note: this doesn't work when CLUSTERED BY specifies an expression that is not being selected.
// Such fields in CLUSTERED BY still control partitioning as expected, but do not affect sort order of rows
// within an individual segment.
for (final SortColumn clusterByColumn : queryClusterBy.getColumns()) {
if (clusterByColumn.descending()) {
for (final KeyColumn clusterByColumn : queryClusterBy.getColumns()) {
if (clusterByColumn.order() == KeyOrder.DESCENDING) {
throw new MSQException(new InsertCannotOrderByDescendingFault(clusterByColumn.columnName()));
}
@ -2400,7 +2401,7 @@ public class ControllerImpl implements Controller
segmentsToGenerate = generateSegmentIdsWithShardSpecs(
(DataSourceMSQDestination) task.getQuerySpec().getDestination(),
queryKernel.getStageDefinition(shuffleStageId).getSignature(),
queryKernel.getStageDefinition(shuffleStageId).getShuffleSpec().get().getClusterBy(),
queryKernel.getStageDefinition(shuffleStageId).getClusterBy(),
partitionBoundaries,
mayHaveMultiValuedClusterByFields
);


@ -69,6 +69,12 @@ public class Limits
*/
public static final int MAX_KERNEL_MANIPULATION_QUEUE_SIZE = 100_000;
/**
* Maximum number of bytes buffered for each side of a
* {@link org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessor}, not counting the most recent frame read.
*/
public static final int MAX_BUFFERED_BYTES_FOR_SORT_MERGE_JOIN = 10_000_000;
/**
* Maximum relaunches across all workers.
*/


@ -22,6 +22,7 @@ package org.apache.druid.msq.exec;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Injector;
import org.apache.druid.frame.processor.Bouncer;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.kernel.QueryDefinition;
@ -73,4 +74,9 @@ public interface WorkerContext
DruidNode selfNode();
Bouncer processorBouncer();
default File tempDir(int stageNumber, String id)
{
return new File(StringUtils.format("%s/stage_%02d/%s", tempDir(), stageNumber, id));
}
}


@ -31,6 +31,7 @@ import org.apache.druid.msq.indexing.error.NotEnoughMemoryFault;
import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
import org.apache.druid.msq.input.InputSpecs;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
@ -51,7 +52,7 @@ import java.util.Objects;
* entirely on server configuration; this makes the calculation robust to different queries running simultaneously in
* the same JVM.
*
* Then, we split up the resources for each bundle in two different ways: one assuming it'll be used for a
* Within each bundle, we split up memory in two different ways: one assuming it'll be used for a
* {@link org.apache.druid.frame.processor.SuperSorter}, and one assuming it'll be used for a regular
* processor. Callers can then use whichever set of allocations makes sense. (We assume no single bundle
* will be used for both purposes.)
@ -166,6 +167,7 @@ public class WorkerMemoryParameters
computeNumWorkersInJvm(injector),
computeNumProcessorsInJvm(injector),
0,
0,
totalLookupFootprint
);
}
@ -179,19 +181,27 @@ public class WorkerMemoryParameters
final int stageNumber
)
{
final IntSet inputStageNumbers =
InputSpecs.getStageNumbers(queryDef.getStageDefinition(stageNumber).getInputSpecs());
final StageDefinition stageDef = queryDef.getStageDefinition(stageNumber);
final IntSet inputStageNumbers = InputSpecs.getStageNumbers(stageDef.getInputSpecs());
final int numInputWorkers =
inputStageNumbers.intStream()
.map(inputStageNumber -> queryDef.getStageDefinition(inputStageNumber).getMaxWorkerCount())
.sum();
long totalLookupFootprint = computeTotalLookupFootprint(injector);
final int numHashOutputPartitions;
if (stageDef.doesShuffle() && stageDef.getShuffleSpec().kind().isHash()) {
numHashOutputPartitions = stageDef.getShuffleSpec().partitionCount();
} else {
numHashOutputPartitions = 0;
}
return createInstance(
Runtime.getRuntime().maxMemory(),
computeNumWorkersInJvm(injector),
computeNumProcessorsInJvm(injector),
numInputWorkers,
numHashOutputPartitions,
totalLookupFootprint
);
}
@ -206,15 +216,18 @@ public class WorkerMemoryParameters
* @param numWorkersInJvm number of workers that can run concurrently in this JVM. Generally equal to
* the task capacity.
* @param numProcessingThreadsInJvm size of the processing thread pool in the JVM.
* @param numInputWorkers number of workers across input stages that need to be merged together.
* @param totalLookUpFootprint estimated size of the lookups loaded by the process.
* @param numInputWorkers total number of workers across all input stages.
* @param numHashOutputPartitions total number of output partitions, if using hash partitioning; zero if not using
* hash partitioning.
* @param totalLookupFootprint estimated size of the lookups loaded by the process.
*/
public static WorkerMemoryParameters createInstance(
final long maxMemoryInJvm,
final int numWorkersInJvm,
final int numProcessingThreadsInJvm,
final int numInputWorkers,
final long totalLookUpFootprint
final int numHashOutputPartitions,
final long totalLookupFootprint
)
{
Preconditions.checkArgument(maxMemoryInJvm > 0, "Max memory passed: [%s] should be > 0", maxMemoryInJvm);
@ -226,18 +239,25 @@ public class WorkerMemoryParameters
);
Preconditions.checkArgument(numInputWorkers >= 0, "Number of input workers: [%s] should be >=0", numInputWorkers);
Preconditions.checkArgument(
totalLookUpFootprint >= 0,
totalLookupFootprint >= 0,
"Lookup memory footprint: [%s] should be >= 0",
totalLookUpFootprint
totalLookupFootprint
);
final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookUpFootprint);
final long usableMemoryInJvm = computeUsableMemoryInJvm(maxMemoryInJvm, totalLookupFootprint);
final long workerMemory = memoryPerWorker(usableMemoryInJvm, numWorkersInJvm);
final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
final long bundleMemoryForInputChannels = memoryNeededForInputChannels(numInputWorkers);
final long bundleMemoryForProcessing = bundleMemory - bundleMemoryForInputChannels;
final long bundleMemoryForHashPartitioning = memoryNeededForHashPartitioning(numHashOutputPartitions);
final long bundleMemoryForProcessing =
bundleMemory - bundleMemoryForInputChannels - bundleMemoryForHashPartitioning;
if (bundleMemoryForProcessing < PROCESSING_MINIMUM_BYTES) {
final int maxWorkers = computeMaxWorkers(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
final int maxWorkers = computeMaxWorkers(
usableMemoryInJvm,
numWorkersInJvm,
numProcessingThreadsInJvm,
numHashOutputPartitions
);
if (maxWorkers > 0) {
throw new MSQException(new TooManyWorkersFault(numInputWorkers, Math.min(Limits.MAX_WORKERS, maxWorkers)));
@ -250,7 +270,7 @@ public class WorkerMemoryParameters
numWorkersInJvm,
numProcessingThreadsInJvm,
PROCESSING_MINIMUM_BYTES + BUFFER_BYTES_FOR_ESTIMATION + bundleMemoryForInputChannels
), totalLookUpFootprint),
), totalLookupFootprint),
maxMemoryInJvm,
usableMemoryInJvm,
numWorkersInJvm,
@ -271,7 +291,7 @@ public class WorkerMemoryParameters
numWorkersInJvm,
(MIN_SUPER_SORTER_FRAMES + BUFFER_BYTES_FOR_ESTIMATION) * LARGE_FRAME_SIZE
),
totalLookUpFootprint
totalLookupFootprint
),
maxMemoryInJvm,
usableMemoryInJvm,
@ -393,13 +413,19 @@ public class WorkerMemoryParameters
static int computeMaxWorkers(
final long usableMemoryInJvm,
final int numWorkersInJvm,
final int numProcessingThreadsInJvm
final int numProcessingThreadsInJvm,
final int numHashOutputPartitions
)
{
final long bundleMemory = memoryPerBundle(usableMemoryInJvm, numWorkersInJvm, numProcessingThreadsInJvm);
// Inverse of memoryNeededForInputChannels.
return Math.max(0, Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / STANDARD_FRAME_SIZE - 1));
// Compute number of workers that gives us PROCESSING_MINIMUM_BYTES of memory per bundle, while accounting for
// memoryNeededForInputChannels + memoryNeededForHashPartitioning.
final int isHashing = numHashOutputPartitions > 0 ? 1 : 0;
return Math.max(
0,
Ints.checkedCast((bundleMemory - PROCESSING_MINIMUM_BYTES) / ((long) STANDARD_FRAME_SIZE * (1 + isHashing)) - 1)
);
}
/**
@ -499,17 +525,29 @@ public class WorkerMemoryParameters
return (long) STANDARD_FRAME_SIZE * (numInputWorkers + 1);
}
/**
* Amount of heap memory available for our usage. Any computation changes done to this method should also be done in its corresponding method {@link WorkerMemoryParameters#calculateSuggestedMinMemoryFromUsableMemory}
*/
private static long computeUsableMemoryInJvm(final long maxMemory, final long totalLookupFootprint)
{
// since lookups are essentially in memory hashmap's, the object overhead is trivial hence its subtracted prior to usable memory calculations.
return (long) ((maxMemory - totalLookupFootprint) * USABLE_MEMORY_FRACTION);
}
/**
* Estimate amount of heap memory for the given workload to use in case usable memory is provided. This method is used for better exception messages when {@link NotEnoughMemoryFault} is thrown.
*/
private static long memoryNeededForHashPartitioning(final int numOutputPartitions)
{
// One standard frame for each processor output.
// May be zero, since numOutputPartitions is zero if not using hash partitioning.
return (long) STANDARD_FRAME_SIZE * numOutputPartitions;
}
/**
* Amount of heap memory available for our usage. Any computation changes done to this method should also be done in
* its corresponding method {@link WorkerMemoryParameters#calculateSuggestedMinMemoryFromUsableMemory}
*/
private static long computeUsableMemoryInJvm(final long maxMemory, final long totalLookupFootprint)
{
// Always report at least one byte, to simplify the math in createInstance.
return Math.max(
1,
(long) ((maxMemory - totalLookupFootprint) * USABLE_MEMORY_FRACTION)
);
}
/**
* Estimate amount of heap memory for the given workload to use in case usable memory is provided. This method is used
* for better exception messages when {@link NotEnoughMemoryFault} is thrown.
*/
private static long calculateSuggestedMinMemoryFromUsableMemory(long usuableMemeory, final long totalLookupFootprint)
{


@ -267,7 +267,6 @@ public class WorkerSketchFetcher implements AutoCloseable
),
retryOperation
);
}); });
}
}


@ -61,6 +61,7 @@ import org.apache.druid.msq.indexing.error.TooManyClusteredByColumnsFault;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
import org.apache.druid.msq.indexing.error.TooManyInputFilesFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
@ -78,6 +79,7 @@ import org.apache.druid.msq.input.table.TableInputSpec;
import org.apache.druid.msq.kernel.NilExtraInfoHolder;
import org.apache.druid.msq.querykit.InputNumberDataSource;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
import org.apache.druid.msq.querykit.groupby.GroupByPostShuffleFrameProcessorFactory;
import org.apache.druid.msq.querykit.groupby.GroupByPreShuffleFrameProcessorFactory;
import org.apache.druid.msq.querykit.scan.ScanQueryFrameProcessorFactory;
@ -118,6 +120,7 @@ public class MSQIndexingModule implements DruidModule
TooManyColumnsFault.class,
TooManyInputFilesFault.class,
TooManyPartitionsFault.class,
TooManyRowsWithSameKeyFault.class,
TooManyWarningsFault.class,
TooManyWorkersFault.class,
TooManyAttemptsForJob.class,
@ -150,6 +153,7 @@ public class MSQIndexingModule implements DruidModule
ScanQueryFrameProcessorFactory.class,
GroupByPreShuffleFrameProcessorFactory.class,
GroupByPostShuffleFrameProcessorFactory.class,
SortMergeJoinFrameProcessorFactory.class,
OffsetLimitFrameProcessorFactory.class,
NilExtraInfoHolder.class,


@ -63,6 +63,12 @@ public class CountingWritableFrameChannel implements WritableFrameChannel
baseChannel.close();
}
@Override
public boolean isClosed()
{
return baseChannel.isClosed();
}
@Override
public ListenableFuture<?> writabilityFuture()
{


@ -22,11 +22,12 @@ package org.apache.druid.msq.indexing;
import com.google.common.collect.Iterables;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameChannelMerger;
import org.apache.druid.frame.processor.FrameChannelMuxer;
import org.apache.druid.frame.processor.FrameChannelMixer;
import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriters;
@ -90,8 +91,8 @@ public class InputChannelsImpl implements InputChannels
{
final StageDefinition stageDef = queryDefinition.getStageDefinition(stagePartition.getStageId());
final ReadablePartition readablePartition = readablePartitionMap.get(stagePartition);
final ClusterBy inputClusterBy = stageDef.getClusterBy();
final boolean isSorted = inputClusterBy.getBucketByCount() != inputClusterBy.getColumns().size();
final ClusterBy clusterBy = stageDef.getClusterBy();
final boolean isSorted = clusterBy.sortable() && (clusterBy.getColumns().size() - clusterBy.getBucketByCount() > 0);
if (isSorted) {
return openSorted(stageDef, readablePartition);
@ -129,13 +130,13 @@ public class InputChannelsImpl implements InputChannels
queueChannel.writable(),
FrameWriters.makeFrameWriterFactory(
FrameType.ROW_BASED,
allocatorMaker.get(),
new SingleMemoryAllocatorFactory(allocatorMaker.get()),
stageDefinition.getFrameReader().signature(),
// No sortColumns, because FrameChannelMerger generates frames that are sorted all on its own
Collections.emptyList()
),
stageDefinition.getClusterBy(),
stageDefinition.getSortKey(),
null,
-1
);
@ -163,7 +164,7 @@ public class InputChannelsImpl implements InputChannels
return Iterables.getOnlyElement(channels);
} else {
final BlockingQueueFrameChannel queueChannel = BlockingQueueFrameChannel.minimal();
final FrameChannelMuxer muxer = new FrameChannelMuxer(channels, queueChannel.writable());
final FrameChannelMixer muxer = new FrameChannelMixer(channels, queueChannel.writable());
// Discard future, since there is no need to keep it. We aren't interested in its return value. If it fails,
// downstream processors are notified through fail(e) on in-memory channels. If we need to cancel it, we use


@ -22,6 +22,8 @@ package org.apache.druid.msq.indexing.error;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import java.util.Objects;
@ -35,9 +37,14 @@ public class BroadcastTablesTooLargeFault extends BaseMSQFault
@JsonCreator
public BroadcastTablesTooLargeFault(@JsonProperty("maxBroadcastTablesSize") final long maxBroadcastTablesSize)
{
super(CODE,
"Size of the broadcast tables exceed the memory reserved for them (memory reserved for broadcast tables = %d bytes)",
maxBroadcastTablesSize
super(
CODE,
"Size of broadcast tables in JOIN exceeds reserved memory limit "
+ "(memory reserved for broadcast tables = %d bytes). "
+ "Increase available memory, or set %s: %s in query context to use a shuffle-based join.",
maxBroadcastTablesSize,
PlannerContext.CTX_SQL_JOIN_ALGORITHM,
JoinAlgorithm.SORT_MERGE.toString()
);
this.maxBroadcastTablesSize = maxBroadcastTablesSize;
}


@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.indexing.error;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import java.util.List;
import java.util.Objects;
@JsonTypeName(TooManyRowsWithSameKeyFault.CODE)
public class TooManyRowsWithSameKeyFault extends BaseMSQFault
{
static final String CODE = "TooManyRowsWithSameKey";
private final List<Object> key;
private final long numBytes;
private final long maxBytes;
@JsonCreator
public TooManyRowsWithSameKeyFault(
@JsonProperty("key") final List<Object> key,
@JsonProperty("numBytes") final long numBytes,
@JsonProperty("maxBytes") final long maxBytes
)
{
super(
CODE,
"Too many rows with the same key during sort-merge join (bytes buffered = %,d; limit = %,d). Key: %s",
numBytes,
maxBytes,
key
);
this.key = key;
this.numBytes = numBytes;
this.maxBytes = maxBytes;
}
@JsonProperty
public List<Object> getKey()
{
return key;
}
@JsonProperty
public long getNumBytes()
{
return numBytes;
}
@JsonProperty
public long getMaxBytes()
{
return maxBytes;
}
@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
TooManyRowsWithSameKeyFault that = (TooManyRowsWithSameKeyFault) o;
return numBytes == that.numBytes && maxBytes == that.maxBytes && Objects.equals(key, that.key);
}
@Override
public int hashCode()
{
return Objects.hash(super.hashCode(), key, numBytes, maxBytes);
}
}
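For reference, a small sketch, separate from the change itself, showing the fault carrying its key and byte counts through the getters defined above; the wrapper class is hypothetical.

import com.google.common.collect.ImmutableList;
import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;

public class TooManyRowsWithSameKeyFaultExample
{
  public static void main(String[] args)
  {
    // Key values, bytes buffered so far, and the configured limit, matching the constructor above.
    final TooManyRowsWithSameKeyFault fault =
        new TooManyRowsWithSameKeyFault(ImmutableList.of("k1", 2L), 128_000_000L, 100_000_000L);

    System.out.println(fault.getKey());      // [k1, 2]
    System.out.println(fault.getNumBytes()); // 128000000
    System.out.println(fault.getMaxBytes()); // 100000000
  }
}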


@ -19,12 +19,20 @@
package org.apache.druid.msq.input; package org.apache.druid.msq.input;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.channel.ReadableNilFrameChannel;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.counters.CounterTracker;
import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.input.stage.StageInputSlice; import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.StagePartition;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.function.Consumer;
public class InputSlices public class InputSlices
{ {
@ -33,6 +41,10 @@ public class InputSlices
// No instantiation. // No instantiation.
} }
/**
* Combines all {@link StageInputSlice#getPartitions()} from the input slices that are {@link StageInputSlice}.
* Ignores other types of input slices.
*/
public static ReadablePartitions allReadablePartitions(final List<InputSlice> slices) public static ReadablePartitions allReadablePartitions(final List<InputSlice> slices)
{ {
final List<ReadablePartitions> partitionsList = new ArrayList<>(); final List<ReadablePartitions> partitionsList = new ArrayList<>();
@ -46,6 +58,10 @@ public class InputSlices
return ReadablePartitions.combine(partitionsList); return ReadablePartitions.combine(partitionsList);
} }
/**
* Sum of {@link InputSliceReader#numReadableInputs(InputSlice)} across all input slices that are _not_ present
* in "broadcastInputs".
*/
public static int getNumNonBroadcastReadableInputs( public static int getNumNonBroadcastReadableInputs(
final List<InputSlice> slices, final List<InputSlice> slices,
final InputSliceReader reader, final InputSliceReader reader,
@ -62,4 +78,70 @@ public class InputSlices
return numInputs; return numInputs;
} }
/**
* Calls {@link InputSliceReader#attach} on all "slices", which must all be {@link NilInputSlice} or
* {@link StageInputSlice}, and collects like-numbered partitions.
*
* The returned map is keyed by partition number. Each value is a list of inputs of the
* same length as "slices", and in the same order. i.e., the first ReadableInput in each list corresponds to the
* first provided {@link InputSlice}.
*
* "Missing" partitions -- which occur when one slice has no data for a given partition -- are replaced with
* {@link ReadableInput} based on {@link ReadableNilFrameChannel}, with no {@link StagePartition}.
*
* @throws IllegalStateException if any slices are not {@link NilInputSlice} or {@link StageInputSlice}
*/
public static Int2ObjectMap<List<ReadableInput>> attachAndCollectPartitions(
final List<InputSlice> slices,
final InputSliceReader reader,
final CounterTracker counters,
final Consumer<Throwable> warningPublisher
)
{
// Input number -> ReadableInputs.
final List<ReadableInputs> inputsByInputNumber = new ArrayList<>();
for (final InputSlice slice : slices) {
if (slice instanceof NilInputSlice) {
inputsByInputNumber.add(null);
} else if (slice instanceof StageInputSlice) {
final ReadableInputs inputs = reader.attach(inputsByInputNumber.size(), slice, counters, warningPublisher);
inputsByInputNumber.add(inputs);
} else {
throw new ISE("Slice [%s] is not a 'stage' slice", slice);
}
}
// Populate the result map.
final Int2ObjectMap<List<ReadableInput>> retVal = new Int2ObjectRBTreeMap<>();
for (int inputNumber = 0; inputNumber < slices.size(); inputNumber++) {
  // NilInputSlice entries were recorded as null above and have no attached inputs; skip them here.
  if (inputsByInputNumber.get(inputNumber) != null) {
    for (final ReadableInput input : inputsByInputNumber.get(inputNumber)) {
      final int partitionNumber = input.getStagePartition().getPartitionNumber();
      retVal.computeIfAbsent(partitionNumber, ignored -> Arrays.asList(new ReadableInput[slices.size()]))
            .set(inputNumber, input);
    }
  }
}
// Fill in all remaining nulls with inputs based on nil channels.
for (Int2ObjectMap.Entry<List<ReadableInput>> entry : retVal.int2ObjectEntrySet()) {
for (int inputNumber = 0; inputNumber < entry.getValue().size(); inputNumber++) {
if (entry.getValue().get(inputNumber) == null) {
entry.getValue().set(
inputNumber,
ReadableInput.channel(
ReadableNilFrameChannel.INSTANCE,
inputsByInputNumber.get(inputNumber).frameReader(),
null
)
);
}
}
}
return retVal;
}
} }
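A sketch of how a caller might walk the map returned by attachAndCollectPartitions. The walk method and handlePartition callback are hypothetical stand-ins for whatever a join-style processor factory does with each group of like-numbered partitions; only the fastutil entry-set iteration is taken as given.

import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import org.apache.druid.msq.input.ReadableInput;

import java.util.List;
import java.util.function.BiConsumer;

public class PartitionGroupWalker
{
  // For each partition number, hand the per-input list (same order as the original slices) to a
  // caller-supplied handler. Inputs with no data for a partition arrive as nil channels, so the
  // handler should not expect null entries.
  public static void walk(
      final Int2ObjectMap<List<ReadableInput>> partitionedInputs,
      final BiConsumer<Integer, List<ReadableInput>> handlePartition
  )
  {
    for (final Int2ObjectMap.Entry<List<ReadableInput>> entry : partitionedInputs.int2ObjectEntrySet()) {
      handlePartition.accept(entry.getIntKey(), entry.getValue());
    }
  }
}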


@ -40,7 +40,7 @@ public class ReadableInput
private final SegmentWithDescriptor segment; private final SegmentWithDescriptor segment;
@Nullable @Nullable
private final ReadableFrameChannel inputChannel; private final ReadableFrameChannel channel;
@Nullable @Nullable
private final FrameReader frameReader; private final FrameReader frameReader;
@ -56,7 +56,7 @@ public class ReadableInput
) )
{ {
this.segment = segment; this.segment = segment;
this.inputChannel = channel; this.channel = channel;
this.frameReader = frameReader; this.frameReader = frameReader;
this.stagePartition = stagePartition; this.stagePartition = stagePartition;
@ -65,48 +65,107 @@ public class ReadableInput
} }
} }
/**
* Create an input associated with a physical segment.
*
* @param segment the segment
*/
public static ReadableInput segment(final SegmentWithDescriptor segment) public static ReadableInput segment(final SegmentWithDescriptor segment)
{ {
return new ReadableInput(segment, null, null, null); return new ReadableInput(Preconditions.checkNotNull(segment, "segment"), null, null, null);
} }
/**
* Create an input associated with a channel.
*
* @param channel the channel
* @param frameReader reader for the channel
* @param stagePartition stage-partition associated with the channel, if meaningful. May be null if this channel
* does not correspond to any one particular stage-partition.
*/
public static ReadableInput channel( public static ReadableInput channel(
final ReadableFrameChannel inputChannel, final ReadableFrameChannel channel,
final FrameReader frameReader, final FrameReader frameReader,
final StagePartition stagePartition @Nullable final StagePartition stagePartition
) )
{ {
return new ReadableInput(null, inputChannel, frameReader, stagePartition); return new ReadableInput(
null,
Preconditions.checkNotNull(channel, "channel"),
Preconditions.checkNotNull(frameReader, "frameReader"),
stagePartition
);
} }
/**
* Whether this input is a segment (from {@link #segment(SegmentWithDescriptor)}).
*/
public boolean hasSegment() public boolean hasSegment()
{ {
return segment != null; return segment != null;
} }
/**
* Whether this input is a channel (from {@link #channel(ReadableFrameChannel, FrameReader, StagePartition)}).
*/
public boolean hasChannel() public boolean hasChannel()
{ {
return inputChannel != null; return channel != null;
} }
/**
* The segment for this input. Only valid if {@link #hasSegment()}.
*/
public SegmentWithDescriptor getSegment() public SegmentWithDescriptor getSegment()
{ {
return Preconditions.checkNotNull(segment, "segment"); checkIsSegment();
return segment;
} }
/**
* The channel for this input. Only valid if {@link #hasChannel()}.
*/
public ReadableFrameChannel getChannel() public ReadableFrameChannel getChannel()
{ {
return Preconditions.checkNotNull(inputChannel, "channel"); checkIsChannel();
return channel;
} }
/**
* The frame reader for this input. Only valid if {@link #hasChannel()}.
*/
public FrameReader getChannelFrameReader() public FrameReader getChannelFrameReader()
{ {
return Preconditions.checkNotNull(frameReader, "frameReader"); checkIsChannel();
return frameReader;
} }
@Nullable /**
* The stage-partition for this input. Only valid if {@link #hasChannel()}, and if a stage-partition was provided
* during construction. Throws {@link IllegalStateException} if no stage-partition was provided during construction.
*/
public StagePartition getStagePartition() public StagePartition getStagePartition()
{ {
checkIsChannel();
if (stagePartition == null) {
throw new ISE("Stage-partition is not set for this channel");
}
return stagePartition; return stagePartition;
} }
private void checkIsSegment()
{
if (!hasSegment()) {
throw new ISE("Not a segment input; cannot call this method");
}
}
private void checkIsChannel()
{
if (!hasChannel()) {
throw new ISE("Not a channel input; cannot call this method");
}
}
} }
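An illustrative sketch of the channel-style factory and accessor guards above. It assumes FrameReader.create(signature) as the frame-reader factory and uses the nil channel, mirroring how missing partitions are represented; the wrapper class is not part of this change.

import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.ReadableNilFrameChannel;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;

public class ReadableInputExample
{
  public static void main(String[] args)
  {
    final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();

    // Channel-based input with no particular stage-partition.
    final ReadableInput input =
        ReadableInput.channel(ReadableNilFrameChannel.INSTANCE, FrameReader.create(signature), null);

    if (input.hasChannel()) {
      // Safe: these accessors only throw when hasChannel() is false.
      final ReadableFrameChannel channel = input.getChannel();
      final FrameReader reader = input.getChannelFrameReader();
      System.out.println(channel.isFinished() + " / " + reader.signature());
    }

    // input.getStagePartition() would throw ISE here, since no stage-partition was provided.
  }
}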


@ -62,6 +62,8 @@ public class ReadableInputs implements Iterable<ReadableInput>
/** /**
* Returns the {@link ReadableInput} as an Iterator. * Returns the {@link ReadableInput} as an Iterator.
*
* When this instance is channel-based ({@link #isChannelBased()}), inputs are returned in order of partition number.
*/ */
@Override @Override
public Iterator<ReadableInput> iterator() public Iterator<ReadableInput> iterator()


@ -27,6 +27,9 @@ import org.apache.druid.segment.Segment;
import java.io.Closeable; import java.io.Closeable;
import java.util.Objects; import java.util.Objects;
/**
* A holder for a physical segment.
*/
public class SegmentWithDescriptor implements Closeable public class SegmentWithDescriptor implements Closeable
{ {
private final ResourceHolder<? extends Segment> segmentHolder; private final ResourceHolder<? extends Segment> segmentHolder;
@ -41,22 +44,35 @@ public class SegmentWithDescriptor implements Closeable
this.descriptor = Preconditions.checkNotNull(descriptor, "descriptor"); this.descriptor = Preconditions.checkNotNull(descriptor, "descriptor");
} }
/**
* The physical segment.
*
* Named "getOrLoad" because the segment may be held by an eager or lazy resource holder (i.e.
* {@link org.apache.druid.msq.querykit.LazyResourceHolder}). If the resource holder is lazy, the segment is acquired
* as part of the call to this method.
*/
public Segment getOrLoadSegment() public Segment getOrLoadSegment()
{ {
return segmentHolder.get(); return segmentHolder.get();
} }
/**
* The segment descriptor associated with this physical segment.
*/
public SegmentDescriptor getDescriptor()
{
return descriptor;
}
/**
* Release resources used by the physical segment.
*/
@Override @Override
public void close() public void close()
{ {
segmentHolder.close(); segmentHolder.close();
} }
public SegmentDescriptor getDescriptor()
{
return descriptor;
}
@Override @Override
public boolean equals(Object o) public boolean equals(Object o)
{ {


@ -27,6 +27,7 @@ import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -35,54 +36,71 @@ import java.util.Objects;
/** /**
* Shuffle spec that generates up to a certain number of output partitions. Commonly used for shuffles between stages. * Shuffle spec that generates up to a certain number of output partitions. Commonly used for shuffles between stages.
*/ */
public class MaxCountShuffleSpec implements ShuffleSpec public class GlobalSortMaxCountShuffleSpec implements GlobalSortShuffleSpec
{ {
public static final String TYPE = "maxCount";
private final ClusterBy clusterBy; private final ClusterBy clusterBy;
private final int partitions; private final int maxPartitions;
private final boolean aggregate; private final boolean aggregate;
@JsonCreator @JsonCreator
public MaxCountShuffleSpec( public GlobalSortMaxCountShuffleSpec(
@JsonProperty("clusterBy") final ClusterBy clusterBy, @JsonProperty("clusterBy") final ClusterBy clusterBy,
@JsonProperty("partitions") final int partitions, @JsonProperty("partitions") final int maxPartitions,
@JsonProperty("aggregate") final boolean aggregate @JsonProperty("aggregate") final boolean aggregate
) )
{ {
this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy"); this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy");
this.partitions = partitions; this.maxPartitions = maxPartitions;
this.aggregate = aggregate; this.aggregate = aggregate;
if (partitions < 1) { if (maxPartitions < 1) {
throw new IAE("Partition count must be at least 1"); throw new IAE("Partition count must be at least 1");
} }
if (!clusterBy.sortable()) {
throw new IAE("ClusterBy key must be sortable");
}
if (clusterBy.getBucketByCount() > 0) {
// Only GlobalSortTargetSizeShuffleSpec supports bucket-by.
throw new IAE("Cannot bucket with %s partitioning", TYPE);
}
}
@Override
public ShuffleKind kind()
{
return ShuffleKind.GLOBAL_SORT;
} }
@Override @Override
@JsonProperty("aggregate") @JsonProperty("aggregate")
@JsonInclude(JsonInclude.Include.NON_DEFAULT) @JsonInclude(JsonInclude.Include.NON_DEFAULT)
public boolean doesAggregateByClusterKey() public boolean doesAggregate()
{ {
return aggregate; return aggregate;
} }
@Override @Override
public boolean needsStatistics() public boolean mustGatherResultKeyStatistics()
{ {
return partitions > 1 || clusterBy.getBucketByCount() > 0; return maxPartitions > 1 || clusterBy.getBucketByCount() > 0;
} }
@Override @Override
public Either<Long, ClusterByPartitions> generatePartitions( public Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
@Nullable final ClusterByStatisticsCollector collector, @Nullable final ClusterByStatisticsCollector collector,
final int maxNumPartitions final int maxNumPartitions
) )
{ {
if (!needsStatistics()) { if (!mustGatherResultKeyStatistics()) {
return Either.value(ClusterByPartitions.oneUniversalPartition()); return Either.value(ClusterByPartitions.oneUniversalPartition());
} else if (partitions > maxNumPartitions) { } else if (maxPartitions > maxNumPartitions) {
return Either.error((long) partitions); return Either.error((long) maxPartitions);
} else { } else {
final ClusterByPartitions generatedPartitions = collector.generatePartitionsWithMaxCount(partitions); final ClusterByPartitions generatedPartitions = collector.generatePartitionsWithMaxCount(maxPartitions);
if (generatedPartitions.size() <= maxNumPartitions) { if (generatedPartitions.size() <= maxNumPartitions) {
return Either.value(generatedPartitions); return Either.value(generatedPartitions);
} else { } else {
@ -93,15 +111,21 @@ public class MaxCountShuffleSpec implements ShuffleSpec
@Override @Override
@JsonProperty @JsonProperty
public ClusterBy getClusterBy() public ClusterBy clusterBy()
{ {
return clusterBy; return clusterBy;
} }
@JsonProperty @Override
int getPartitions() public int partitionCount()
{ {
return partitions; throw new ISE("Number of partitions not known for [%s].", kind());
}
@JsonProperty("partitions")
public int getMaxPartitions()
{
return maxPartitions;
} }
@Override @Override
@ -113,8 +137,8 @@ public class MaxCountShuffleSpec implements ShuffleSpec
if (o == null || getClass() != o.getClass()) { if (o == null || getClass() != o.getClass()) {
return false; return false;
} }
MaxCountShuffleSpec that = (MaxCountShuffleSpec) o; GlobalSortMaxCountShuffleSpec that = (GlobalSortMaxCountShuffleSpec) o;
return partitions == that.partitions return maxPartitions == that.maxPartitions
&& aggregate == that.aggregate && aggregate == that.aggregate
&& Objects.equals(clusterBy, that.clusterBy); && Objects.equals(clusterBy, that.clusterBy);
} }
@ -122,7 +146,7 @@ public class MaxCountShuffleSpec implements ShuffleSpec
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(clusterBy, partitions, aggregate); return Objects.hash(clusterBy, maxPartitions, aggregate);
} }
@Override @Override
@ -130,7 +154,7 @@ public class MaxCountShuffleSpec implements ShuffleSpec
{ {
return "MaxCountShuffleSpec{" + return "MaxCountShuffleSpec{" +
"clusterBy=" + clusterBy + "clusterBy=" + clusterBy +
", partitions=" + partitions + ", partitions=" + maxPartitions +
", aggregate=" + aggregate + ", aggregate=" + aggregate +
'}'; '}';
} }
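An illustrative sketch, assuming this class sits in org.apache.druid.msq.kernel alongside the other shuffle specs, of the single-partition fast path: with one max partition there is nothing to gather statistics for, so a null collector yields the universal partition.

import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;

public class MaxCountShuffleExample
{
  public static void main(String[] args)
  {
    // The constructor requires a sortable key; KeyOrder.ASCENDING provides one.
    final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("dim", KeyOrder.ASCENDING)), 0);

    // maxPartitions = 1 means mustGatherResultKeyStatistics() is false, so the collector may be null.
    final GlobalSortMaxCountShuffleSpec spec = new GlobalSortMaxCountShuffleSpec(clusterBy, 1, false);
    final Either<Long, ClusterByPartitions> partitions = spec.generatePartitionsForGlobalSort(null, 100);

    System.out.println(partitions.valueOrThrow().size()); // 1
  }
}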


@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import javax.annotation.Nullable;
/**
* Additional methods for {@link ShuffleSpec} of kind {@link ShuffleKind#GLOBAL_SORT}.
*/
public interface GlobalSortShuffleSpec extends ShuffleSpec
{
/**
* Whether {@link #generatePartitionsForGlobalSort} needs a nonnull collector in order to do its work.
*/
boolean mustGatherResultKeyStatistics();
/**
* Generates a set of partitions based on the provided statistics.
*
* Only valid if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}. Otherwise, throws {@link IllegalStateException}.
*
* @param collector must be nonnull if {@link #mustGatherResultKeyStatistics()} is true; ignored otherwise
* @param maxNumPartitions maximum number of partitions to generate
*
* @return either the partition assignment, or (as an error) a number of partitions, greater than maxNumPartitions,
* that would be expected to be created
*
* @throws IllegalStateException if {@link #kind()} is not {@link ShuffleKind#GLOBAL_SORT}.
*/
Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
@Nullable ClusterByStatisticsCollector collector,
int maxNumPartitions
);
}
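A hedged sketch of consuming this interface; partitionsOrNull is a hypothetical helper, and in the actual kernel code the error side is turned into a TooManyPartitionsFault rather than a null return.

import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.msq.kernel.GlobalSortShuffleSpec;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;

import javax.annotation.Nullable;

public class GlobalSortPartitionHelper
{
  // Returns partition boundaries, or null if the spec would need more partitions than allowed.
  @Nullable
  public static ClusterByPartitions partitionsOrNull(
      final GlobalSortShuffleSpec spec,
      @Nullable final ClusterByStatisticsCollector collector,
      final int maxNumPartitions
  )
  {
    final Either<Long, ClusterByPartitions> result =
        spec.generatePartitionsForGlobalSort(collector, maxNumPartitions);

    // The error side carries the number of partitions that would have been required.
    return result.isError() ? null : result.valueOrThrow();
  }
}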


@ -26,6 +26,8 @@ import com.google.common.base.Preconditions;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -36,14 +38,16 @@ import java.util.Objects;
* to a particular {@link #targetSize}. Commonly used when generating segments, which we want to have a certain number * to a particular {@link #targetSize}. Commonly used when generating segments, which we want to have a certain number
* of rows per segment. * of rows per segment.
*/ */
public class TargetSizeShuffleSpec implements ShuffleSpec public class GlobalSortTargetSizeShuffleSpec implements GlobalSortShuffleSpec
{ {
public static final String TYPE = "targetSize";
private final ClusterBy clusterBy; private final ClusterBy clusterBy;
private final long targetSize; private final long targetSize;
private final boolean aggregate; private final boolean aggregate;
@JsonCreator @JsonCreator
public TargetSizeShuffleSpec( public GlobalSortTargetSizeShuffleSpec(
@JsonProperty("clusterBy") final ClusterBy clusterBy, @JsonProperty("clusterBy") final ClusterBy clusterBy,
@JsonProperty("targetSize") final long targetSize, @JsonProperty("targetSize") final long targetSize,
@JsonProperty("aggregate") final boolean aggregate @JsonProperty("aggregate") final boolean aggregate
@ -52,24 +56,40 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy"); this.clusterBy = Preconditions.checkNotNull(clusterBy, "clusterBy");
this.targetSize = targetSize; this.targetSize = targetSize;
this.aggregate = aggregate; this.aggregate = aggregate;
if (!clusterBy.sortable()) {
throw new IAE("ClusterBy key must be sortable");
}
}
@Override
public ShuffleKind kind()
{
return ShuffleKind.GLOBAL_SORT;
} }
@Override @Override
@JsonProperty("aggregate") @JsonProperty("aggregate")
@JsonInclude(JsonInclude.Include.NON_DEFAULT) @JsonInclude(JsonInclude.Include.NON_DEFAULT)
public boolean doesAggregateByClusterKey() public boolean doesAggregate()
{ {
return aggregate; return aggregate;
} }
@Override @Override
public boolean needsStatistics() public boolean mustGatherResultKeyStatistics()
{ {
return true; return true;
} }
@Override @Override
public Either<Long, ClusterByPartitions> generatePartitions( public int partitionCount()
{
throw new ISE("Number of partitions not known for [%s].", kind());
}
@Override
public Either<Long, ClusterByPartitions> generatePartitionsForGlobalSort(
@Nullable final ClusterByStatisticsCollector collector, @Nullable final ClusterByStatisticsCollector collector,
final int maxNumPartitions final int maxNumPartitions
) )
@ -90,13 +110,13 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
@Override @Override
@JsonProperty @JsonProperty
public ClusterBy getClusterBy() public ClusterBy clusterBy()
{ {
return clusterBy; return clusterBy;
} }
@JsonProperty @JsonProperty
long getTargetSize() long targetSize()
{ {
return targetSize; return targetSize;
} }
@ -110,7 +130,7 @@ public class TargetSizeShuffleSpec implements ShuffleSpec
if (o == null || getClass() != o.getClass()) { if (o == null || getClass() != o.getClass()) {
return false; return false;
} }
TargetSizeShuffleSpec that = (TargetSizeShuffleSpec) o; GlobalSortTargetSizeShuffleSpec that = (GlobalSortTargetSizeShuffleSpec) o;
return targetSize == that.targetSize && aggregate == that.aggregate && Objects.equals(clusterBy, that.clusterBy); return targetSize == that.targetSize && aggregate == that.aggregate && Objects.equals(clusterBy, that.clusterBy);
} }


@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.java.util.common.IAE;
public class HashShuffleSpec implements ShuffleSpec
{
public static final String TYPE = "hash";
private final ClusterBy clusterBy;
private final int numPartitions;
@JsonCreator
public HashShuffleSpec(
@JsonProperty("clusterBy") final ClusterBy clusterBy,
@JsonProperty("partitions") final int numPartitions
)
{
this.clusterBy = clusterBy;
this.numPartitions = numPartitions;
if (clusterBy.getBucketByCount() > 0) {
// Only GlobalSortTargetSizeShuffleSpec supports bucket-by.
throw new IAE("Cannot bucket with %s partitioning (clusterBy = %s)", TYPE, clusterBy);
}
}
@Override
public ShuffleKind kind()
{
return clusterBy.sortable() && !clusterBy.isEmpty() ? ShuffleKind.HASH_LOCAL_SORT : ShuffleKind.HASH;
}
@Override
@JsonProperty
public ClusterBy clusterBy()
{
return clusterBy;
}
@Override
public boolean doesAggregate()
{
return false;
}
@Override
@JsonProperty("partitions")
public int partitionCount()
{
return numPartitions;
}
}
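An illustrative sketch of how the key's sort order drives kind(): a key with KeyOrder.NONE hashes without sorting, while a sortable key adds the local per-partition sort. The wrapper class is illustrative only.

import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.msq.kernel.HashShuffleSpec;

public class HashShuffleKindExample
{
  public static void main(String[] args)
  {
    // Hash key with no sort order: plain hash partitioning.
    final ClusterBy hashOnly = new ClusterBy(ImmutableList.of(new KeyColumn("k", KeyOrder.NONE)), 0);
    System.out.println(new HashShuffleSpec(hashOnly, 4).kind()); // HASH

    // Sortable key: hash partitioning plus a sort within each partition.
    final ClusterBy sorted = new ClusterBy(ImmutableList.of(new KeyColumn("k", KeyOrder.ASCENDING)), 0);
    System.out.println(new HashShuffleSpec(sorted, 4).kind()); // HASH_LOCAL_SORT
  }
}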


@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel;
import com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.druid.frame.key.ClusterBy;
/**
* Shuffle spec that generates a single, unsorted partition.
*/
public class MixShuffleSpec implements ShuffleSpec
{
public static final String TYPE = "mix";
private static final MixShuffleSpec INSTANCE = new MixShuffleSpec();
private MixShuffleSpec()
{
}
@JsonCreator
public static MixShuffleSpec instance()
{
return INSTANCE;
}
@Override
public ShuffleKind kind()
{
return ShuffleKind.MIX;
}
@Override
public ClusterBy clusterBy()
{
return ClusterBy.none();
}
@Override
public boolean doesAggregate()
{
return false;
}
@Override
public int partitionCount()
{
return 1;
}
@Override
public boolean equals(Object obj)
{
return obj != null && this.getClass().equals(obj.getClass());
}
@Override
public int hashCode()
{
return 0;
}
@Override
public String toString()
{
return "MixShuffleSpec{}";
}
}
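A tiny check of the singleton above, included only as an illustration.

import org.apache.druid.msq.kernel.MixShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleKind;

public class MixShuffleExample
{
  public static void main(String[] args)
  {
    // A mix shuffle funnels everything into one unsorted partition.
    final MixShuffleSpec spec = MixShuffleSpec.instance();
    System.out.println(spec.kind() == ShuffleKind.MIX); // true
    System.out.println(spec.partitionCount());          // 1
    System.out.println(spec.clusterBy().isEmpty());     // true
  }
}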


@ -20,6 +20,7 @@
package org.apache.druid.msq.kernel; package org.apache.druid.msq.kernel;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import org.apache.druid.java.util.common.ISE;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -77,6 +78,17 @@ public class QueryDefinitionBuilder
return stageBuilders.stream().mapToInt(StageDefinitionBuilder::getStageNumber).max().orElse(-1) + 1; return stageBuilders.stream().mapToInt(StageDefinitionBuilder::getStageNumber).max().orElse(-1) + 1;
} }
public StageDefinitionBuilder getStageBuilder(final int stageNumber)
{
for (final StageDefinitionBuilder stageBuilder : stageBuilders) {
if (stageBuilder.getStageNumber() == stageNumber) {
return stageBuilder;
}
}
throw new ISE("No such stage [%s]", stageNumber);
}
public QueryDefinition build() public QueryDefinition build()
{ {
final List<StageDefinition> stageDefinitions = final List<StageDefinition> stageDefinitions =


@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.kernel;
public enum ShuffleKind
{
/**
* Put all data in a single partition, with no sorting and no statistics gathering.
*/
MIX(false, false),
/**
* Partition using hash codes, with no sorting.
*
* This kind of shuffle supports pipelining: producer and consumer stages can run at the same time.
*/
HASH(true, false),
/**
* Partition using hash codes, with each partition internally sorted.
*
* Each worker partitions its outputs according to hash code of the cluster key, and does a local sort of its
* own outputs.
*
* Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before
* consumer stages can run.
*/
HASH_LOCAL_SORT(true, true),
/**
* Partition using a distributed global sort.
*
* First, each worker reads its input fully and feeds statistics into a
* {@link org.apache.druid.msq.statistics.ClusterByStatisticsCollector}. The controller merges those statistics,
* generating final {@link org.apache.druid.frame.key.ClusterByPartitions}. Then, workers fully sort and partition
* their outputs along those lines.
*
* Consumers (workers in the next stage downstream) do an N-way merge of the already-sorted and already-partitioned
* output files from each worker.
*
* Due to the need to sort outputs, this shuffle mechanism cannot be pipelined. Producer stages must finish before
* consumer stages can run.
*/
GLOBAL_SORT(false, true);
private final boolean hash;
private final boolean sort;
ShuffleKind(boolean hash, boolean sort)
{
this.hash = hash;
this.sort = sort;
}
/**
* Whether this shuffle does hash-partitioning.
*/
public boolean isHash()
{
return hash;
}
/**
* Whether this shuffle sorts within partitions. (If true, it may, or may not, also sort globally.)
*/
public boolean isSort()
{
return sort;
}
}
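Going by the javadocs above, only the sorting kinds block pipelining; MIX and HASH leave data unsorted and can stream to consumers. A small illustrative helper, not part of the change:

import org.apache.druid.msq.kernel.ShuffleKind;

public class ShuffleKindPipelining
{
  // Rough rule derived from the enum's documentation: sorting forces producers to finish first.
  public static boolean canPipeline(final ShuffleKind kind)
  {
    return !kind.isSort();
  }

  public static void main(String[] args)
  {
    for (ShuffleKind kind : ShuffleKind.values()) {
      System.out.println(kind + ": hash=" + kind.isHash() + ", sort=" + kind.isSort()
                         + ", pipelined=" + canPipeline(kind));
    }
  }
}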


@ -22,11 +22,6 @@ package org.apache.druid.msq.kernel;
import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.java.util.common.Either;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import javax.annotation.Nullable;
/** /**
* Describes how outputs of a stage are shuffled. Property of {@link StageDefinition}. * Describes how outputs of a stage are shuffled. Property of {@link StageDefinition}.
@ -36,37 +31,46 @@ import javax.annotation.Nullable;
*/ */
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonSubTypes(value = { @JsonSubTypes(value = {
@JsonSubTypes.Type(name = "maxCount", value = MaxCountShuffleSpec.class), @JsonSubTypes.Type(name = MixShuffleSpec.TYPE, value = MixShuffleSpec.class),
@JsonSubTypes.Type(name = "targetSize", value = TargetSizeShuffleSpec.class) @JsonSubTypes.Type(name = HashShuffleSpec.TYPE, value = HashShuffleSpec.class),
@JsonSubTypes.Type(name = GlobalSortMaxCountShuffleSpec.TYPE, value = GlobalSortMaxCountShuffleSpec.class),
@JsonSubTypes.Type(name = GlobalSortTargetSizeShuffleSpec.TYPE, value = GlobalSortTargetSizeShuffleSpec.class)
}) })
public interface ShuffleSpec public interface ShuffleSpec
{ {
/** /**
* Clustering key that will determine how data are partitioned during the shuffle. * The nature of this shuffle: hash vs. range based partitioning; whether the data are sorted or not.
*/
ClusterBy getClusterBy();
/**
* Whether this stage aggregates by the clustering key or not.
*/
boolean doesAggregateByClusterKey();
/**
* Whether {@link #generatePartitions} needs a nonnull collector.
*/
boolean needsStatistics();
/**
* Generates a set of partitions based on the provided statistics.
* *
* @param collector must be nonnull if {@link #needsStatistics()} is true; may be null otherwise * If this method returns {@link ShuffleKind#GLOBAL_SORT}, then this spec is also an instance of
* @param maxNumPartitions maximum number of partitions to generate * {@link GlobalSortShuffleSpec}, and additional methods are available.
*
* @return either the partition assignment, or (as an error) a number of partitions, greater than maxNumPartitions,
* that would be expected to be created
*/ */
Either<Long, ClusterByPartitions> generatePartitions( ShuffleKind kind();
@Nullable ClusterByStatisticsCollector collector,
int maxNumPartitions /**
); * Partitioning key for the shuffle.
*
* If {@link #kind()} is {@link ShuffleKind#HASH}, data are partitioned using a hash of this key, but not sorted.
*
* If {@link #kind()} is {@link ShuffleKind#HASH_LOCAL_SORT}, data are partitioned using a hash of this key, and
* sorted within each partition.
*
* If {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}, data are partitioned using ranges of this key, and are
* sorted within each partition; therefore, the data are also globally sorted.
*/
ClusterBy clusterBy();
/**
* Whether this stage aggregates by the {@link #clusterBy()} key.
*/
boolean doesAggregate();
/**
* Number of partitions, if known.
*
* Partition count is always known if {@link #kind()} is {@link ShuffleKind#MIX}, {@link ShuffleKind#HASH}, or
* {@link ShuffleKind#HASH_LOCAL_SORT}. It is not known if {@link #kind()} is {@link ShuffleKind#GLOBAL_SORT}.
*
* @throws IllegalStateException if kind is {@link ShuffleKind#GLOBAL_SORT}
*/
int partitionCount();
} }
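An illustrative sketch of dispatching on kind(): the GLOBAL_SORT case downcasts per the contract above, while the other kinds can report a concrete partition count. The describe helper is hypothetical.

import org.apache.druid.msq.kernel.GlobalSortShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleKind;
import org.apache.druid.msq.kernel.ShuffleSpec;

public class ShuffleSpecInspector
{
  // For GLOBAL_SORT the partition count is not known up front, so report the statistics
  // requirement instead of calling partitionCount().
  public static String describe(final ShuffleSpec spec)
  {
    if (spec.kind() == ShuffleKind.GLOBAL_SORT) {
      final GlobalSortShuffleSpec globalSort = (GlobalSortShuffleSpec) spec;
      return "global sort on " + spec.clusterBy()
             + (globalSort.mustGatherResultKeyStatistics() ? " (needs key statistics)" : "");
    } else {
      return spec.kind() + " into " + spec.partitionCount() + " partition(s) on " + spec.clusterBy();
    }
  }
}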


@ -27,9 +27,16 @@ import com.google.common.base.Suppliers;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets; import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.allocation.MemoryAllocatorFactory;
import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.read.FrameReader; import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.java.util.common.Either; import org.apache.druid.java.util.common.Either;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
@ -41,9 +48,9 @@ import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -64,7 +71,7 @@ import java.util.function.Supplier;
* Each stage has a {@link ShuffleSpec} describing the shuffle that occurs as part of the stage. The shuffle spec is * Each stage has a {@link ShuffleSpec} describing the shuffle that occurs as part of the stage. The shuffle spec is
* optional: if none is provided, then the {@link FrameProcessorFactory} directly writes to output partitions. If a * optional: if none is provided, then the {@link FrameProcessorFactory} directly writes to output partitions. If a
* shuffle spec is provided, then the {@link FrameProcessorFactory} is expected to sort each output frame individually * shuffle spec is provided, then the {@link FrameProcessorFactory} is expected to sort each output frame individually
* according to {@link ShuffleSpec#getClusterBy()}. The execution system handles the rest, including sorting data across * according to {@link ShuffleSpec#clusterBy()}. The execution system handles the rest, including sorting data across
* frames and producing the appropriate output partitions. * frames and producing the appropriate output partitions.
* <p> * <p>
* The rarely-used parameter {@link #getShuffleCheckHasMultipleValues()} controls whether the execution system * The rarely-used parameter {@link #getShuffleCheckHasMultipleValues()} controls whether the execution system
@ -128,7 +135,7 @@ public class StageDefinition
this.maxInputBytesPerWorker = maxInputBytesPerWorker == null ? this.maxInputBytesPerWorker = maxInputBytesPerWorker == null ?
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER : maxInputBytesPerWorker; Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER : maxInputBytesPerWorker;
if (shuffleSpec != null && shuffleSpec.needsStatistics() && shuffleSpec.getClusterBy().getColumns().isEmpty()) { if (mustGatherResultKeyStatistics() && shuffleSpec.clusterBy().getColumns().isEmpty()) {
throw new IAE("Cannot shuffle with spec [%s] and nil clusterBy", shuffleSpec); throw new IAE("Cannot shuffle with spec [%s] and nil clusterBy", shuffleSpec);
} }
@ -157,7 +164,7 @@ public class StageDefinition
.broadcastInputs(stageDef.getBroadcastInputNumbers()) .broadcastInputs(stageDef.getBroadcastInputNumbers())
.processorFactory(stageDef.getProcessorFactory()) .processorFactory(stageDef.getProcessorFactory())
.signature(stageDef.getSignature()) .signature(stageDef.getSignature())
.shuffleSpec(stageDef.getShuffleSpec().orElse(null)) .shuffleSpec(stageDef.doesShuffle() ? stageDef.getShuffleSpec() : null)
.maxWorkerCount(stageDef.getMaxWorkerCount()) .maxWorkerCount(stageDef.getMaxWorkerCount())
.shuffleCheckHasMultipleValues(stageDef.getShuffleCheckHasMultipleValues()); .shuffleCheckHasMultipleValues(stageDef.getShuffleCheckHasMultipleValues());
} }
@ -212,16 +219,25 @@ public class StageDefinition
public boolean doesSortDuringShuffle() public boolean doesSortDuringShuffle()
{ {
if (shuffleSpec == null) { if (shuffleSpec == null || shuffleSpec.clusterBy().isEmpty()) {
return false; return false;
} else { } else {
return !shuffleSpec.getClusterBy().getColumns().isEmpty() || shuffleSpec.needsStatistics(); return shuffleSpec.clusterBy().sortable();
} }
} }
public Optional<ShuffleSpec> getShuffleSpec() /**
* Returns the {@link ShuffleSpec} for this stage, if {@link #doesShuffle()}.
*
* @throws IllegalStateException if this stage does not shuffle
*/
public ShuffleSpec getShuffleSpec()
{ {
return Optional.ofNullable(shuffleSpec); if (shuffleSpec == null) {
throw new IllegalStateException("Stage does not shuffle");
}
return shuffleSpec;
} }
/** /**
@ -229,7 +245,25 @@ public class StageDefinition
*/ */
public ClusterBy getClusterBy() public ClusterBy getClusterBy()
{ {
return shuffleSpec != null ? shuffleSpec.getClusterBy() : ClusterBy.none(); if (shuffleSpec != null) {
return shuffleSpec.clusterBy();
} else {
return ClusterBy.none();
}
}
/**
* Returns the key used for sorting each individual partition, or an empty list if partitions are unsorted.
*/
public List<KeyColumn> getSortKey()
{
final ClusterBy clusterBy = getClusterBy();
if (clusterBy.sortable()) {
return clusterBy.getColumns();
} else {
return Collections.emptyList();
}
} }
@Nullable @Nullable
@ -285,40 +319,77 @@ public class StageDefinition
*/ */
public boolean mustGatherResultKeyStatistics() public boolean mustGatherResultKeyStatistics()
{ {
return shuffleSpec != null && shuffleSpec.needsStatistics(); return shuffleSpec != null
&& shuffleSpec.kind() == ShuffleKind.GLOBAL_SORT
&& ((GlobalSortShuffleSpec) shuffleSpec).mustGatherResultKeyStatistics();
} }
public Either<Long, ClusterByPartitions> generatePartitionsForShuffle( public Either<Long, ClusterByPartitions> generatePartitionBoundariesForShuffle(
@Nullable ClusterByStatisticsCollector collector @Nullable ClusterByStatisticsCollector collector
) )
{ {
if (shuffleSpec == null) { if (shuffleSpec == null) {
throw new ISE("No shuffle for stage[%d]", getStageNumber()); throw new ISE("No shuffle for stage[%d]", getStageNumber());
} else if (shuffleSpec.kind() != ShuffleKind.GLOBAL_SORT) {
throw new ISE(
"Shuffle of kind [%s] cannot generate partition boundaries for stage[%d]",
shuffleSpec.kind(),
getStageNumber()
);
} else if (mustGatherResultKeyStatistics() && collector == null) { } else if (mustGatherResultKeyStatistics() && collector == null) {
throw new ISE("Statistics required, but not gathered for stage[%d]", getStageNumber()); throw new ISE("Statistics required, but not gathered for stage[%d]", getStageNumber());
} else if (!mustGatherResultKeyStatistics() && collector != null) { } else if (!mustGatherResultKeyStatistics() && collector != null) {
throw new ISE("Statistics gathered, but not required for stage[%d]", getStageNumber()); throw new ISE("Statistics gathered, but not required for stage[%d]", getStageNumber());
} else { } else {
return shuffleSpec.generatePartitions(collector, MAX_PARTITIONS); return ((GlobalSortShuffleSpec) shuffleSpec).generatePartitionsForGlobalSort(collector, MAX_PARTITIONS);
} }
} }
public ClusterByStatisticsCollector createResultKeyStatisticsCollector(final int maxRetainedBytes) public ClusterByStatisticsCollector createResultKeyStatisticsCollector(final int maxRetainedBytes)
{ {
if (!mustGatherResultKeyStatistics()) { if (!mustGatherResultKeyStatistics()) {
throw new ISE("No statistics needed"); throw new ISE("No statistics needed for stage[%d]", getStageNumber());
} }
return ClusterByStatisticsCollectorImpl.create( return ClusterByStatisticsCollectorImpl.create(
shuffleSpec.getClusterBy(), shuffleSpec.clusterBy(),
signature, signature,
maxRetainedBytes, maxRetainedBytes,
PARTITION_STATS_MAX_BUCKETS, PARTITION_STATS_MAX_BUCKETS,
shuffleSpec.doesAggregateByClusterKey(), shuffleSpec.doesAggregate(),
shuffleCheckHasMultipleValues shuffleCheckHasMultipleValues
); );
} }
/**
* Create the {@link FrameWriterFactory} that must be used by {@link #getProcessorFactory()}.
*
* Calls {@link MemoryAllocatorFactory#newAllocator()} for each frame.
*/
public FrameWriterFactory createFrameWriterFactory(final MemoryAllocatorFactory memoryAllocatorFactory)
{
return FrameWriters.makeFrameWriterFactory(
FrameType.ROW_BASED,
memoryAllocatorFactory,
signature,
// Main processor does not sort when there is a hash going on, even if isSort = true. This is because
// FrameChannelHashPartitioner is expected to be attached to the processor and do the sorting. We don't
// want to double-sort.
doesShuffle() && !shuffleSpec.kind().isHash() ? getClusterBy().getColumns() : Collections.emptyList()
);
}
/**
* Create the {@link FrameWriterFactory} that must be used by {@link #getProcessorFactory()}.
*
* Re-uses the same {@link MemoryAllocator} for each frame.
*/
public FrameWriterFactory createFrameWriterFactory(final MemoryAllocator allocator)
{
return createFrameWriterFactory(new SingleMemoryAllocatorFactory(allocator));
}
public FrameReader getFrameReader() public FrameReader getFrameReader()
{ {
return frameReader.get(); return frameReader.get();


@ -118,6 +118,11 @@ public class StageDefinitionBuilder
return stageNumber; return stageNumber;
} }
public RowSignature getSignature()
{
return signature;
}
public StageDefinition build(final String queryId) public StageDefinition build(final String queryId)
{ {
return new StageDefinition( return new StageDefinition(


@ -43,6 +43,9 @@ import org.apache.druid.msq.input.InputSpecSlicer;
import org.apache.druid.msq.input.stage.ReadablePartition; import org.apache.druid.msq.input.stage.ReadablePartition;
import org.apache.druid.msq.input.stage.ReadablePartitions; import org.apache.druid.msq.input.stage.ReadablePartitions;
import org.apache.druid.msq.input.stage.StageInputSlice; import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.GlobalSortShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleKind;
import org.apache.druid.msq.kernel.ShuffleSpec;
import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
@ -556,7 +559,8 @@ class ControllerStageTracker
if (workers.isEmpty()) { if (workers.isEmpty()) {
// generate partition boundaries since all work is finished for the time chunk // generate partition boundaries since all work is finished for the time chunk
ClusterByStatisticsCollector collector = timeChunkToCollector.get(tc); ClusterByStatisticsCollector collector = timeChunkToCollector.get(tc);
Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(collector); Either<Long, ClusterByPartitions> countOrPartitions =
stageDef.generatePartitionBoundariesForShuffle(collector);
totalPartitionCount += getPartitionCountFromEither(countOrPartitions); totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
if (totalPartitionCount > stageDef.getMaxPartitionCount()) { if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount())); failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
@ -689,8 +693,8 @@ class ControllerStageTracker
); );
} }
if (resultPartitions == null) { if (resultPartitions == null) {
Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionsForShuffle(timeChunkToCollector.get( final ClusterByStatisticsCollector collector = timeChunkToCollector.get(STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE);
STATIC_TIME_CHUNK_FOR_PARALLEL_MERGE)); Either<Long, ClusterByPartitions> countOrPartitions = stageDef.generatePartitionBoundariesForShuffle(collector);
totalPartitionCount += getPartitionCountFromEither(countOrPartitions); totalPartitionCount += getPartitionCountFromEither(countOrPartitions);
if (totalPartitionCount > stageDef.getMaxPartitionCount()) { if (totalPartitionCount > stageDef.getMaxPartitionCount()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount())); failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
@ -840,9 +844,10 @@ class ControllerStageTracker
} }
/** /**
* Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} without using key statistics. * Sets {@link #resultPartitions} (always) and {@link #resultPartitionBoundaries} (if doing a global sort) without
* <p> * using key statistics. Called by the constructor.
* If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method should not be called. *
* If {@link StageDefinition#mustGatherResultKeyStatistics()} is true, this method must not be called.
*/ */
private void generateResultPartitionsAndBoundariesWithoutKeyStatistics() private void generateResultPartitionsAndBoundariesWithoutKeyStatistics()
{ {
@ -856,24 +861,31 @@ class ControllerStageTracker
final int stageNumber = stageDef.getStageNumber(); final int stageNumber = stageDef.getStageNumber();
if (stageDef.doesShuffle()) { if (stageDef.doesShuffle()) {
if (stageDef.mustGatherResultKeyStatistics() && !allPartialKeyInformationFetched()) { final ShuffleSpec shuffleSpec = stageDef.getShuffleSpec();
throw new ISE("Cannot generate result partitions without all worker key statistics");
if (shuffleSpec.kind() == ShuffleKind.GLOBAL_SORT) {
if (((GlobalSortShuffleSpec) shuffleSpec).mustGatherResultKeyStatistics()
&& !allPartialKeyInformationFetched()) {
throw new ISE("Cannot generate result partitions without all worker key statistics");
}
final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
stageDef.generatePartitionBoundariesForShuffle(null);
if (maybeResultPartitionBoundaries.isError()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
return;
}
resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerCount,
resultPartitionBoundaries.size()
);
} else {
resultPartitions = ReadablePartitions.striped(stageNumber, workerCount, shuffleSpec.partitionCount());
} }
final Either<Long, ClusterByPartitions> maybeResultPartitionBoundaries =
stageDef.generatePartitionsForShuffle(null);
if (maybeResultPartitionBoundaries.isError()) {
failForReason(new TooManyPartitionsFault(stageDef.getMaxPartitionCount()));
return;
}
resultPartitionBoundaries = maybeResultPartitionBoundaries.valueOrThrow();
resultPartitions = ReadablePartitions.striped(
stageNumber,
workerCount,
resultPartitionBoundaries.size()
);
} else { } else {
// No reshuffling: retain partitioning from nonbroadcast inputs. // No reshuffling: retain partitioning from nonbroadcast inputs.
final Int2IntSortedMap partitionToWorkerMap = new Int2IntAVLTreeMap(); final Int2IntSortedMap partitionToWorkerMap = new Int2IntAVLTreeMap();


@ -25,6 +25,7 @@ import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.kernel.ShuffleKind;
import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder; import org.apache.druid.msq.kernel.WorkOrder;
@ -70,11 +71,12 @@ public class WorkerStageKernel
this.workOrder = workOrder; this.workOrder = workOrder;
if (workOrder.getStageDefinition().doesShuffle() if (workOrder.getStageDefinition().doesShuffle()
&& workOrder.getStageDefinition().getShuffleSpec().kind() == ShuffleKind.GLOBAL_SORT
&& !workOrder.getStageDefinition().mustGatherResultKeyStatistics()) { && !workOrder.getStageDefinition().mustGatherResultKeyStatistics()) {
// Use valueOrThrow instead of a nicer error collection mechanism, because we really don't expect the // Use valueOrThrow instead of a nicer error collection mechanism, because we really don't expect the
// MAX_PARTITIONS to be exceeded here. It would involve having a shuffleSpec that was statically configured // MAX_PARTITIONS to be exceeded here. It would involve having a shuffleSpec that was statically configured
// to use a huge number of partitions. // to use a huge number of partitions.
resultPartitionBoundaries = workOrder.getStageDefinition().generatePartitionsForShuffle(null).valueOrThrow(); resultPartitionBoundaries = workOrder.getStageDefinition().generatePartitionBoundariesForShuffle(null).valueOrThrow();
} }
} }


@ -23,15 +23,14 @@ import com.google.common.collect.Iterators;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import org.apache.druid.collections.ResourceHolder; import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.channel.ReadableConcatFrameChannel;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannel;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.OutputChannels;
+import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.guava.Sequence;
@@ -48,7 +47,6 @@ import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.kernel.ProcessorsAndChannels;
import org.apache.druid.msq.kernel.StageDefinition;
-import org.apache.druid.segment.column.RowSignature;
import javax.annotation.Nullable;
import java.io.IOException;
@@ -104,7 +102,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFactory
outstandingProcessors = Math.min(totalProcessors, maxOutstandingProcessors);
}
-final AtomicReference<Queue<MemoryAllocator>> allocatorQueueRef =
+final AtomicReference<Queue<FrameWriterFactory>> frameWriterFactoryQueueRef =
new AtomicReference<>(new ArrayDeque<>(outstandingProcessors));
final AtomicReference<Queue<WritableFrameChannel>> channelQueueRef =
new AtomicReference<>(new ArrayDeque<>(outstandingProcessors));
@@ -114,7 +112,9 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFactory
final OutputChannel outputChannel = outputChannelFactory.openChannel(0 /* Partition number doesn't matter */);
outputChannels.add(outputChannel);
channelQueueRef.get().add(outputChannel.getWritableChannel());
-allocatorQueueRef.get().add(outputChannel.getFrameMemoryAllocator());
+frameWriterFactoryQueueRef.get().add(
+stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator())
+);
}
// Read all base inputs in separate processors, one per processor.
@@ -147,9 +147,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFactory
}
}
),
-makeLazyResourceHolder(allocatorQueueRef, ignored -> {}),
-stageDefinition.getSignature(),
-stageDefinition.getClusterBy(),
+makeLazyResourceHolder(frameWriterFactoryQueueRef, ignored -> {}),
frameContext
);
}
@@ -257,9 +255,7 @@ public abstract class BaseLeafFrameProcessorFactory extends BaseFrameProcessorFactory
ReadableInput baseInput,
Int2ObjectMap<ReadableInput> sideChannels,
ResourceHolder<WritableFrameChannel> outputChannelSupplier,
-ResourceHolder<MemoryAllocator> allocatorSupplier,
-RowSignature signature,
-ClusterBy clusterBy,
+ResourceHolder<FrameWriterFactory> frameWriterFactory,
FrameContext providerThingy
);


@@ -23,11 +23,15 @@ import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;
import org.apache.druid.data.input.impl.InlineInputSource;
import org.apache.druid.data.input.impl.JsonInputFormat;
+import org.apache.druid.frame.key.ClusterBy;
+import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.UOE;
@@ -36,11 +40,16 @@ import org.apache.druid.msq.input.NilInputSource;
import org.apache.druid.msq.input.external.ExternalInputSpec;
import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.input.table.TableInputSpec;
+import org.apache.druid.msq.kernel.HashShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
+import org.apache.druid.msq.kernel.StageDefinition;
+import org.apache.druid.msq.kernel.StageDefinitionBuilder;
+import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.JoinDataSource;
+import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.filter.DimFilter;
@@ -52,6 +61,8 @@ import org.apache.druid.segment.column.ColumnHolder;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.parser.DruidSqlInsert;
+import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.joda.time.Interval;
import javax.annotation.Nullable;
@@ -63,9 +74,6 @@ import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
-/**
-* Used by {@link QueryKit} implementations to produce {@link InputSpec} from native {@link DataSource}.
-*/
public class DataSourcePlan
{
/**
@@ -108,6 +116,7 @@ public class DataSourcePlan
public static DataSourcePlan forDataSource(
final QueryKit queryKit,
final String queryId,
+final QueryContext queryContext,
final DataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
@Nullable DimFilter filter,
@@ -135,15 +144,35 @@ public class DataSourcePlan
broadcast
);
} else if (dataSource instanceof JoinDataSource) {
-return forJoin(
-queryKit,
-queryId,
-(JoinDataSource) dataSource,
-querySegmentSpec,
-maxWorkerCount,
-minStageNumber,
-broadcast
-);
+final JoinAlgorithm joinAlgorithm = PlannerContext.getJoinAlgorithm(queryContext);
+
+switch (joinAlgorithm) {
+case BROADCAST:
+return forBroadcastHashJoin(
+queryKit,
+queryId,
+queryContext,
+(JoinDataSource) dataSource,
+querySegmentSpec,
+maxWorkerCount,
+minStageNumber,
+broadcast
+);
+case SORT_MERGE:
+return forSortMergeJoin(
+queryKit,
+queryId,
+(JoinDataSource) dataSource,
+querySegmentSpec,
+maxWorkerCount,
+minStageNumber,
+broadcast
+);
+default:
+throw new UOE("Cannot handle join algorithm [%s]", joinAlgorithm);
+}
} else {
throw new UOE("Cannot handle dataSource [%s]", dataSource);
}
@@ -263,7 +292,7 @@ public class DataSourcePlan
// outermost query, and setting it for the subquery makes us erroneously add bucketing where it doesn't belong.
dataSource.getQuery().withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY),
queryKit,
-ShuffleSpecFactories.subQueryWithMaxWorkerCount(maxWorkerCount),
+ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount),
maxWorkerCount,
minStageNumber
);
@@ -278,9 +307,13 @@ public class DataSourcePlan
);
}
-private static DataSourcePlan forJoin(
+/**
+* Build a plan for broadcast hash-join.
+*/
+private static DataSourcePlan forBroadcastHashJoin(
final QueryKit queryKit,
final String queryId,
+final QueryContext queryContext,
final JoinDataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
final int maxWorkerCount,
@@ -294,11 +327,13 @@
final DataSourcePlan basePlan = forDataSource(
queryKit,
queryId,
+queryContext,
analysis.getBaseDataSource(),
querySegmentSpec,
null, // Don't push query filters down through a join: this needs some work to ensure pruning works properly.
maxWorkerCount,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
broadcast
);
@@ -312,6 +347,7 @@ public class DataSourcePlan
final DataSourcePlan clausePlan = forDataSource(
queryKit,
queryId,
+queryContext,
clause.getDataSource(),
new MultipleIntervalSegmentSpec(Intervals.ONLY_ETERNITY),
null, // Don't push query filters down through a join: this needs some work to ensure pruning works properly.
@@ -341,6 +377,117 @@ public class DataSourcePlan
return new DataSourcePlan(newDataSource, inputSpecs, broadcastInputs, subQueryDefBuilder);
}
/**
* Build a plan for sort-merge join.
*/
private static DataSourcePlan forSortMergeJoin(
final QueryKit queryKit,
final String queryId,
final JoinDataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
final int maxWorkerCount,
final int minStageNumber,
final boolean broadcast
)
{
checkQuerySegmentSpecIsEternity(dataSource, querySegmentSpec);
SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis());
// Partition by keys given by the join condition.
final List<List<KeyColumn>> partitionKeys = SortMergeJoinFrameProcessorFactory.toKeyColumns(
SortMergeJoinFrameProcessorFactory.validateCondition(dataSource.getConditionAnalysis())
);
final QueryDefinitionBuilder subQueryDefBuilder = QueryDefinition.builder();
// Plan the left input.
// We're confident that we can cast dataSource.getLeft() to QueryDataSource, because DruidJoinQueryRel creates
// subqueries when the join algorithm is sortMerge.
final DataSourcePlan leftPlan = forQuery(
queryKit,
queryId,
(QueryDataSource) dataSource.getLeft(),
maxWorkerCount,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
false
);
leftPlan.getSubQueryDefBuilder().ifPresent(subQueryDefBuilder::addAll);
// Plan the right input.
// We're confident that we can cast dataSource.getRight() to QueryDataSource, because DruidJoinQueryRel creates
// subqueries when the join algorithm is sortMerge.
final DataSourcePlan rightPlan = forQuery(
queryKit,
queryId,
(QueryDataSource) dataSource.getRight(),
maxWorkerCount,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
false
);
rightPlan.getSubQueryDefBuilder().ifPresent(subQueryDefBuilder::addAll);
// Build up the left stage.
final StageDefinitionBuilder leftBuilder = subQueryDefBuilder.getStageBuilder(
((StageInputSpec) Iterables.getOnlyElement(leftPlan.getInputSpecs())).getStageNumber()
);
final List<KeyColumn> leftPartitionKey = partitionKeys.get(0);
leftBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(leftPartitionKey, 0), maxWorkerCount));
leftBuilder.signature(QueryKitUtils.sortableSignature(leftBuilder.getSignature(), leftPartitionKey));
// Build up the right stage.
final StageDefinitionBuilder rightBuilder = subQueryDefBuilder.getStageBuilder(
((StageInputSpec) Iterables.getOnlyElement(rightPlan.getInputSpecs())).getStageNumber()
);
final List<KeyColumn> rightPartitionKey = partitionKeys.get(1);
rightBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(rightPartitionKey, 0), maxWorkerCount));
rightBuilder.signature(QueryKitUtils.sortableSignature(rightBuilder.getSignature(), rightPartitionKey));
// Compute join signature.
final RowSignature.Builder joinSignatureBuilder = RowSignature.builder();
for (String leftColumn : leftBuilder.getSignature().getColumnNames()) {
joinSignatureBuilder.add(leftColumn, leftBuilder.getSignature().getColumnType(leftColumn).orElse(null));
}
for (String rightColumn : rightBuilder.getSignature().getColumnNames()) {
joinSignatureBuilder.add(
dataSource.getRightPrefix() + rightColumn,
rightBuilder.getSignature().getColumnType(rightColumn).orElse(null)
);
}
// Build up the join stage.
final int stageNumber = Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber());
subQueryDefBuilder.add(
StageDefinition.builder(stageNumber)
.inputs(
ImmutableList.of(
Iterables.getOnlyElement(leftPlan.getInputSpecs()),
Iterables.getOnlyElement(rightPlan.getInputSpecs())
)
)
.maxWorkerCount(maxWorkerCount)
.signature(joinSignatureBuilder.build())
.processorFactory(
new SortMergeJoinFrameProcessorFactory(
dataSource.getRightPrefix(),
dataSource.getConditionAnalysis(),
dataSource.getJoinType()
)
)
);
return new DataSourcePlan(
new InputNumberDataSource(0),
Collections.singletonList(new StageInputSpec(stageNumber)),
broadcast ? IntOpenHashSet.of(0) : IntSets.emptySet(),
subQueryDefBuilder
);
}
private static DataSource shiftInputNumbers(final DataSource dataSource, final int shift)
{
if (shift < 0) {

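A note on the plumbing above: forDataSource now receives the QueryContext so that JoinDataSource planning can be routed through either forBroadcastHashJoin or forSortMergeJoin. The sketch below is not part of the patch; it only illustrates how the selection is resolved. It assumes the "sqlJoinAlgorithm" context key from the commit description and a QueryContext.of factory method, so treat the exact construction as illustrative.

    // Illustrative only: resolving the join algorithm from the query context.
    import com.google.common.collect.ImmutableMap;
    import org.apache.druid.query.QueryContext;
    import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
    import org.apache.druid.sql.calcite.planner.PlannerContext;

    public class JoinAlgorithmSelectionSketch
    {
      public static void main(String[] args)
      {
        // "sqlJoinAlgorithm" defaults to "broadcast"; "sortMerge" enables the new code path.
        final QueryContext context = QueryContext.of(ImmutableMap.of("sqlJoinAlgorithm", "sortMerge"));
        final JoinAlgorithm algorithm = PlannerContext.getJoinAlgorithm(context);

        // DataSourcePlan.forDataSource switches on this value:
        // BROADCAST -> forBroadcastHashJoin, SORT_MERGE -> forSortMergeJoin, anything else -> UOE.
        System.out.println(algorithm);
      }
    }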

@@ -23,7 +23,8 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.calcite.sql.dialect.CalciteSqlDialect;
import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
@@ -107,8 +108,8 @@ public class QueryKitUtils
if (Granularities.ALL.equals(segmentGranularity)) {
return clusterBy;
} else {
-final List<SortColumn> newColumns = new ArrayList<>(clusterBy.getColumns().size() + 1);
-newColumns.add(new SortColumn(QueryKitUtils.SEGMENT_GRANULARITY_COLUMN, false));
+final List<KeyColumn> newColumns = new ArrayList<>(clusterBy.getColumns().size() + 1);
+newColumns.add(new KeyColumn(QueryKitUtils.SEGMENT_GRANULARITY_COLUMN, KeyOrder.ASCENDING));
newColumns.addAll(clusterBy.getColumns());
return new ClusterBy(newColumns, 1);
}
@@ -153,12 +154,12 @@
*/
public static RowSignature sortableSignature(
final RowSignature signature,
-final List<SortColumn> clusterByColumns
+final List<KeyColumn> clusterByColumns
)
{
final RowSignature.Builder builder = RowSignature.builder();
-for (final SortColumn columnName : clusterByColumns) {
+for (final KeyColumn columnName : clusterByColumns) {
final Optional<ColumnType> columnType = signature.getColumnType(columnName.columnName());
if (!columnType.isPresent()) {
throw new IAE("Column [%s] not present in signature", columnName);
@@ -168,7 +169,7 @@ public class QueryKitUtils
}
final Set<String> clusterByColumnNames =
-clusterByColumns.stream().map(SortColumn::columnName).collect(Collectors.toSet());
+clusterByColumns.stream().map(KeyColumn::columnName).collect(Collectors.toSet());
for (int i = 0; i < signature.size(); i++) {
final String columnName = signature.getColumnName(i);

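The SortColumn to KeyColumn migration above is mostly mechanical, but the KeyOrder argument is what lets the same key type describe both sorted keys and unsorted hash keys (KeyOrder.NONE). A minimal sketch of the new types, using made-up column names and only the constructors visible in this diff:

    import com.google.common.collect.ImmutableList;
    import java.util.List;
    import org.apache.druid.frame.key.ClusterBy;
    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;
    import org.apache.druid.msq.querykit.QueryKitUtils;
    import org.apache.druid.segment.column.ColumnType;
    import org.apache.druid.segment.column.RowSignature;

    public class SortableSignatureSketch
    {
      public static void main(String[] args)
      {
        final RowSignature signature = RowSignature.builder()
            .add("cnt", ColumnType.LONG)
            .add("dim", ColumnType.STRING)
            .build();
        final List<KeyColumn> key = ImmutableList.of(new KeyColumn("dim", KeyOrder.ASCENDING));

        // sortableSignature places the key columns first, then the remaining columns,
        // and throws IAE if a key column is missing from the signature.
        final RowSignature sortable = QueryKitUtils.sortableSignature(signature, key);

        // The same KeyColumn list can back a ClusterBy; hash-only keys would use KeyOrder.NONE.
        final ClusterBy clusterBy = new ClusterBy(key, 0);
        System.out.println(sortable + " / " + clusterBy);
      }
    }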

@@ -19,7 +19,8 @@
package org.apache.druid.msq.querykit;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.MixShuffleSpec;
/**
* Static factory methods for common implementations of {@link ShuffleSpecFactory}.
@@ -32,20 +33,24 @@ public class ShuffleSpecFactories
}
/**
-* Factory that produces a single output partition.
+* Factory that produces a single output partition, which may or may not be sorted.
*/
public static ShuffleSpecFactory singlePartition()
{
-return (clusterBy, aggregate) ->
-new MaxCountShuffleSpec(clusterBy, 1, aggregate);
+return (clusterBy, aggregate) -> {
+if (clusterBy.sortable() && !clusterBy.isEmpty()) {
+return new GlobalSortMaxCountShuffleSpec(clusterBy, 1, aggregate);
+} else {
+return MixShuffleSpec.instance();
+}
+};
}
/**
* Factory that produces a particular number of output partitions.
*/
-public static ShuffleSpecFactory subQueryWithMaxWorkerCount(final int maxWorkerCount)
+public static ShuffleSpecFactory globalSortWithMaxPartitionCount(final int partitions)
{
-return (clusterBy, aggregate) ->
-new MaxCountShuffleSpec(clusterBy, maxWorkerCount, aggregate);
+return (clusterBy, aggregate) -> new GlobalSortMaxCountShuffleSpec(clusterBy, partitions, aggregate);
}
}

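The reworked singlePartition() factory above now distinguishes a sortable, non-empty key (global sort into one partition) from an empty or unsortable key (plain mix shuffle). A hedged usage sketch, assuming ClusterBy.none() and ClusterBy#sortable() behave as the code above implies:

    import com.google.common.collect.ImmutableList;
    import org.apache.druid.frame.key.ClusterBy;
    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.frame.key.KeyOrder;
    import org.apache.druid.msq.kernel.ShuffleSpec;
    import org.apache.druid.msq.querykit.ShuffleSpecFactories;
    import org.apache.druid.msq.querykit.ShuffleSpecFactory;

    public class ShuffleSpecFactoriesSketch
    {
      public static void main(String[] args)
      {
        final ShuffleSpecFactory factory = ShuffleSpecFactories.singlePartition();

        // Non-empty sorted key: expect a GlobalSortMaxCountShuffleSpec with a single partition.
        final ClusterBy sortedKey =
            new ClusterBy(ImmutableList.of(new KeyColumn("dim", KeyOrder.ASCENDING)), 0);
        final ShuffleSpec sorted = factory.build(sortedKey, false);

        // Empty key: expect MixShuffleSpec.instance(), i.e. one unsorted partition.
        final ShuffleSpec mixed = factory.build(ClusterBy.none(), false);

        System.out.println(sorted + " / " + mixed);
      }
    }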

@@ -29,7 +29,7 @@ public interface ShuffleSpecFactory
{
/**
* Build a {@link ShuffleSpec} for given {@link ClusterBy}. The {@code aggregate} flag is used to populate
-* {@link ShuffleSpec#doesAggregateByClusterKey()}.
+* {@link ShuffleSpec#doesAggregate()}.
*/
ShuffleSpec build(ClusterBy clusterBy, boolean aggregate);
}


@@ -21,8 +21,6 @@ package org.apache.druid.msq.querykit.common;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.HeapMemoryAllocator;
import org.apache.druid.frame.channel.FrameWithPartition;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
@@ -33,7 +31,6 @@ import org.apache.druid.frame.processor.ReturnOrAwait;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.Cursor;
@@ -47,8 +44,10 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
private final ReadableFrameChannel inputChannel;
private final WritableFrameChannel outputChannel;
private final FrameReader frameReader;
+private final FrameWriterFactory frameWriterFactory;
private final long offset;
private final long limit;
+private final boolean inputSignatureMatchesOutputSignature;
long rowsProcessedSoFar = 0L;
@@ -56,6 +55,7 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
ReadableFrameChannel inputChannel,
WritableFrameChannel outputChannel,
FrameReader frameReader,
+FrameWriterFactory frameWriterFactory,
long offset,
long limit
)
@@ -63,8 +63,10 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
this.inputChannel = inputChannel;
this.outputChannel = outputChannel;
this.frameReader = frameReader;
+this.frameWriterFactory = frameWriterFactory;
this.offset = offset;
this.limit = limit;
+this.inputSignatureMatchesOutputSignature = frameReader.signature().equals(frameWriterFactory.signature());
if (offset < 0 || limit < 0) {
throw new ISE("Offset and limit must be nonnegative");
@@ -130,31 +132,25 @@ public class OffsetLimitFrameProcessor implements FrameProcessor<Long>
// Offset is past the end of the frame; skip it.
rowsProcessedSoFar += frame.numRows();
return null;
-} else if (startRow == 0 && endRow == frame.numRows()) {
+} else if (startRow == 0
+&& endRow == frame.numRows()
+&& inputSignatureMatchesOutputSignature
+&& frameWriterFactory.frameType().equals(frame.type())) {
+// Want the whole frame; emit it as-is.
rowsProcessedSoFar += frame.numRows();
return frame;
}
final Cursor cursor = FrameProcessors.makeCursor(frame, frameReader);
-// Using an unlimited memory allocator to make sure that atleast a single frame can always be generated
-final HeapMemoryAllocator unlimitedAllocator = HeapMemoryAllocator.unlimited();
long rowsProcessedSoFarInFrame = 0;
-final FrameWriterFactory frameWriterFactory = FrameWriters.makeFrameWriterFactory(
-FrameType.ROW_BASED,
-unlimitedAllocator,
-frameReader.signature(),
-Collections.emptyList()
-);
try (final FrameWriter frameWriter = frameWriterFactory.newFrameWriter(cursor.getColumnSelectorFactory())) {
while (!cursor.isDone() && rowsProcessedSoFarInFrame < endRow) {
if (rowsProcessedSoFarInFrame >= startRow && !frameWriter.addSelection()) {
// Don't retry; it can't work because the allocator is unlimited anyway.
// Also, I don't think this line can be reached, because the allocator is unlimited.
-throw new FrameRowTooLargeException(unlimitedAllocator.capacity());
+throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
cursor.advance();


@@ -25,6 +25,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
+import org.apache.druid.frame.allocation.HeapMemoryAllocator;
import org.apache.druid.frame.channel.ReadableConcatFrameChannel;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannel;
@@ -122,10 +123,12 @@ public class OffsetLimitFrameProcessorFactory extends BaseFrameProcessorFactory
}
// Note: OffsetLimitFrameProcessor does not use allocator from the outputChannel; it uses unlimited instead.
+// This ensures that a single, limited output frame can always be generated from an input frame.
return new OffsetLimitFrameProcessor(
ReadableConcatFrameChannel.open(Iterators.transform(readableInputs.iterator(), ReadableInput::getChannel)),
outputChannel.getWritableChannel(),
readableInputs.frameReader(),
+stageDefinition.createFrameWriterFactory(HeapMemoryAllocator.unlimited()),
offset,
// Limit processor will add limit + offset at various points; must avoid overflow
limit == null ? Long.MAX_VALUE - offset : limit


@@ -0,0 +1,277 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.querykit.common;
import com.fasterxml.jackson.annotation.JacksonInject;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannel;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.OutputChannels;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.msq.counters.CounterTracker;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
import org.apache.druid.msq.input.InputSlices;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.input.stage.StageInputSlice;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.kernel.ProcessorsAndChannels;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.querykit.BaseFrameProcessorFactory;
import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.Equality;
import org.apache.druid.segment.join.JoinConditionAnalysis;
import org.apache.druid.segment.join.JoinType;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
* Factory for {@link SortMergeJoinFrameProcessor}, which does a sort-merge join of two inputs.
*/
@JsonTypeName("sortMergeJoin")
public class SortMergeJoinFrameProcessorFactory extends BaseFrameProcessorFactory
{
private static final int LEFT = 0;
private static final int RIGHT = 1;
private final String rightPrefix;
private final JoinConditionAnalysis condition;
private final JoinType joinType;
public SortMergeJoinFrameProcessorFactory(
final String rightPrefix,
final JoinConditionAnalysis condition,
final JoinType joinType
)
{
this.rightPrefix = Preconditions.checkNotNull(rightPrefix, "rightPrefix");
this.condition = validateCondition(Preconditions.checkNotNull(condition, "condition"));
this.joinType = Preconditions.checkNotNull(joinType, "joinType");
}
@JsonCreator
public static SortMergeJoinFrameProcessorFactory create(
@JsonProperty("rightPrefix") String rightPrefix,
@JsonProperty("condition") String condition,
@JsonProperty("joinType") JoinType joinType,
@JacksonInject ExprMacroTable macroTable
)
{
return new SortMergeJoinFrameProcessorFactory(
StringUtils.nullToEmptyNonDruidDataString(rightPrefix),
JoinConditionAnalysis.forExpression(
Preconditions.checkNotNull(condition, "condition"),
StringUtils.nullToEmptyNonDruidDataString(rightPrefix),
macroTable
),
joinType
);
}
@JsonProperty
public String getRightPrefix()
{
return rightPrefix;
}
@JsonProperty
public String getCondition()
{
return condition.getOriginalExpression();
}
@JsonProperty
public JoinType getJoinType()
{
return joinType;
}
@Override
public ProcessorsAndChannels<FrameProcessor<Long>, Long> makeProcessors(
StageDefinition stageDefinition,
int workerNumber,
List<InputSlice> inputSlices,
InputSliceReader inputSliceReader,
@Nullable Object extra,
OutputChannelFactory outputChannelFactory,
FrameContext frameContext,
int maxOutstandingProcessors,
CounterTracker counters,
Consumer<Throwable> warningPublisher
) throws IOException
{
if (inputSlices.size() != 2 || !inputSlices.stream().allMatch(slice -> slice instanceof StageInputSlice)) {
// Can't hit this unless there was some bug in QueryKit.
throw new ISE("Expected two stage inputs");
}
// Compute key columns.
final List<List<KeyColumn>> keyColumns = toKeyColumns(condition);
// Stitch up the inputs and validate each input channel signature.
// If validateInputFrameSignatures fails, it's a precondition violation: this class somehow got bad inputs.
final Int2ObjectMap<List<ReadableInput>> inputsByPartition = validateInputFrameSignatures(
InputSlices.attachAndCollectPartitions(
inputSlices,
inputSliceReader,
counters,
warningPublisher
),
keyColumns
);
if (inputsByPartition.isEmpty()) {
return new ProcessorsAndChannels<>(Sequences.empty(), OutputChannels.none());
}
// Create output channels.
final Int2ObjectMap<OutputChannel> outputChannels = new Int2ObjectAVLTreeMap<>();
for (int partitionNumber : inputsByPartition.keySet()) {
outputChannels.put(partitionNumber, outputChannelFactory.openChannel(partitionNumber));
}
// Create processors.
final Iterable<FrameProcessor<Long>> processors = Iterables.transform(
inputsByPartition.int2ObjectEntrySet(),
entry -> {
final int partitionNumber = entry.getIntKey();
final List<ReadableInput> readableInputs = entry.getValue();
final OutputChannel outputChannel = outputChannels.get(partitionNumber);
return new SortMergeJoinFrameProcessor(
readableInputs.get(LEFT),
readableInputs.get(RIGHT),
outputChannel.getWritableChannel(),
stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
rightPrefix,
keyColumns,
joinType
);
}
);
return new ProcessorsAndChannels<>(
Sequences.simple(processors),
OutputChannels.wrap(ImmutableList.copyOf(outputChannels.values()))
);
}
/**
* Extracts key columns from a {@link JoinConditionAnalysis}. The returned list has two elements: 0 is the
* left-hand side, 1 is the right-hand side. Each sub-list has one element for each equi-condition.
*
* The condition must have been validated by {@link #validateCondition(JoinConditionAnalysis)}.
*/
public static List<List<KeyColumn>> toKeyColumns(final JoinConditionAnalysis condition)
{
final List<List<KeyColumn>> retVal = new ArrayList<>();
retVal.add(new ArrayList<>()); // Left-side key columns
retVal.add(new ArrayList<>()); // Right-side key columns
for (final Equality equiCondition : condition.getEquiConditions()) {
final String leftColumn = Preconditions.checkNotNull(
equiCondition.getLeftExpr().getBindingIfIdentifier(),
"leftExpr#getBindingIfIdentifier"
);
retVal.get(0).add(new KeyColumn(leftColumn, KeyOrder.ASCENDING));
retVal.get(1).add(new KeyColumn(equiCondition.getRightColumn(), KeyOrder.ASCENDING));
}
return retVal;
}
/**
* Validates that a join condition can be handled by this processor. Returns the condition if it can be handled.
* Throws {@link IllegalArgumentException} if the condition cannot be handled.
*/
public static JoinConditionAnalysis validateCondition(final JoinConditionAnalysis condition)
{
if (condition.isAlwaysTrue()) {
return condition;
}
if (condition.isAlwaysFalse()) {
throw new IAE("Cannot handle constant condition: %s", condition.getOriginalExpression());
}
if (condition.getNonEquiConditions().size() > 0) {
throw new IAE("Cannot handle non-equijoin condition: %s", condition.getOriginalExpression());
}
if (condition.getEquiConditions().stream().anyMatch(c -> !c.getLeftExpr().isIdentifier())) {
throw new IAE(
"Cannot handle equality condition involving left-hand expression: %s",
condition.getOriginalExpression()
);
}
return condition;
}
/**
* Validates that all signatures from {@link InputSlices#attachAndCollectPartitions} are prefixed by the
* provided {@code keyColumns}.
*/
private static Int2ObjectMap<List<ReadableInput>> validateInputFrameSignatures(
final Int2ObjectMap<List<ReadableInput>> inputsByPartition,
final List<List<KeyColumn>> keyColumns
)
{
for (List<ReadableInput> readableInputs : inputsByPartition.values()) {
for (int i = 0; i < readableInputs.size(); i++) {
final ReadableInput readableInput = readableInputs.get(i);
Preconditions.checkState(readableInput.hasChannel(), "readableInput[%s].hasChannel", i);
final RowSignature signature = readableInput.getChannelFrameReader().signature();
for (int j = 0; j < keyColumns.get(i).size(); j++) {
final String columnName = keyColumns.get(i).get(j).columnName();
Preconditions.checkState(
columnName.equals(signature.getColumnName(j)),
"readableInput[%s] column[%s] has expected name[%s]",
i,
j,
columnName
);
}
}
}
return inputsByPartition;
}
}

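To make the key-extraction contract of the new factory concrete, here is a hedged sketch of validateCondition plus toKeyColumns on a single equi-join. The condition string follows Druid's JoinConditionAnalysis convention (right-hand columns carry the join prefix); the column and prefix names are invented for illustration:

    import java.util.List;
    import org.apache.druid.frame.key.KeyColumn;
    import org.apache.druid.math.expr.ExprMacroTable;
    import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
    import org.apache.druid.segment.join.JoinConditionAnalysis;

    public class SortMergeKeyColumnsSketch
    {
      public static void main(String[] args)
      {
        // Left column "x" equi-joined to right column "x", where the right side uses prefix "j.".
        final JoinConditionAnalysis condition =
            JoinConditionAnalysis.forExpression("x == \"j.x\"", "j.", ExprMacroTable.nil());

        // validateCondition rejects non-equi conditions and non-column left-hand expressions;
        // toKeyColumns returns two parallel key lists: index 0 = left side, index 1 = right side.
        final List<List<KeyColumn>> keyColumns = SortMergeJoinFrameProcessorFactory.toKeyColumns(
            SortMergeJoinFrameProcessorFactory.validateCondition(condition)
        );

        System.out.println(keyColumns.get(0)); // left key: x ASCENDING
        System.out.println(keyColumns.get(1)); // right key: x ASCENDING (prefix stripped)
      }
    }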

@@ -22,12 +22,9 @@ package org.apache.druid.msq.querykit.groupby;
import com.fasterxml.jackson.databind.ObjectMapper;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.Frame;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.channel.FrameWithPartition;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.FrameProcessors;
import org.apache.druid.frame.processor.FrameRowTooLargeException;
@@ -35,7 +32,6 @@ import org.apache.druid.frame.processor.ReturnOrAwait;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
-import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.msq.querykit.QueryKitUtils;
import org.apache.druid.query.aggregation.AggregatorFactory;
import org.apache.druid.query.aggregation.PostAggregator;
@@ -68,10 +64,8 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
private final GroupByQuery query;
private final ReadableFrameChannel inputChannel;
private final WritableFrameChannel outputChannel;
-private final MemoryAllocator allocator;
+private final FrameWriterFactory frameWriterFactory;
private final FrameReader frameReader;
-private final RowSignature resultSignature;
-private final ClusterBy clusterBy;
private final ColumnSelectorFactory columnSelectorFactoryForFrameWriter;
private final Comparator<ResultRow> compareFn;
private final BinaryOperator<ResultRow> mergeFn;
@@ -90,10 +84,8 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
final GroupByStrategySelector strategySelector,
final ReadableFrameChannel inputChannel,
final WritableFrameChannel outputChannel,
+final FrameWriterFactory frameWriterFactory,
final FrameReader frameReader,
-final RowSignature resultSignature,
-final ClusterBy clusterBy,
-final MemoryAllocator allocator,
final ObjectMapper jsonMapper
)
{
@@ -101,9 +93,7 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
this.inputChannel = inputChannel;
this.outputChannel = outputChannel;
this.frameReader = frameReader;
-this.resultSignature = resultSignature;
-this.clusterBy = clusterBy;
-this.allocator = allocator;
+this.frameWriterFactory = frameWriterFactory;
this.compareFn = strategySelector.strategize(query).createResultComparator(query);
this.mergeFn = strategySelector.strategize(query).createMergeFn(query);
this.finalizeFn = makeFinalizeFn(query);
@@ -249,10 +239,10 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
outputRow = null;
return true;
} else {
-throw new FrameRowTooLargeException(allocator.capacity());
+throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
} else {
-throw new FrameRowTooLargeException(allocator.capacity());
+throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
}
@@ -269,8 +259,6 @@ public class GroupByPostShuffleFrameProcessor implements FrameProcessor<Long>
private void setUpFrameWriterIfNeeded()
{
if (frameWriter == null) {
-final FrameWriterFactory frameWriterFactory =
-FrameWriters.makeFrameWriterFactory(FrameType.ROW_BASED, allocator, resultSignature, clusterBy.getColumns());
frameWriter = frameWriterFactory.newFrameWriter(columnSelectorFactoryForFrameWriter);
}
}


@@ -24,8 +24,8 @@ import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
-import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
-import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannel;
import org.apache.druid.frame.processor.OutputChannelFactory;
@@ -86,8 +86,8 @@ public class GroupByPostShuffleFrameProcessorFactory extends BaseFrameProcessorFactory
// Expecting a single input slice from some prior stage.
final StageInputSlice slice = (StageInputSlice) Iterables.getOnlyElement(inputSlices);
final GroupByStrategySelector strategySelector = frameContext.groupByStrategySelector();
-final Int2ObjectMap<OutputChannel> outputChannels = new Int2ObjectOpenHashMap<>();
+final Int2ObjectSortedMap<OutputChannel> outputChannels = new Int2ObjectAVLTreeMap<>();
for (final ReadablePartition partition : slice.getPartitions()) {
outputChannels.computeIfAbsent(
partition.getPartitionNumber(),
@@ -115,10 +115,8 @@ public class GroupByPostShuffleFrameProcessorFactory extends BaseFrameProcessorFactory
strategySelector,
readableInput.getChannel(),
outputChannel.getWritableChannel(),
+stageDefinition.createFrameWriterFactory(outputChannel.getFrameMemoryAllocator()),
readableInput.getChannelFrameReader(),
-stageDefinition.getSignature(),
-stageDefinition.getClusterBy(),
-outputChannel.getFrameMemoryAllocator(),
frameContext.jsonMapper()
);
}


@@ -25,19 +25,15 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameProcessor;
-import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
import org.apache.druid.msq.querykit.LazyResourceHolder;
import org.apache.druid.query.groupby.GroupByQuery;
-import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
@JsonTypeName("groupByPreShuffle")
@@ -62,9 +58,7 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcessorFactory
final ReadableInput baseInput,
final Int2ObjectMap<ReadableInput> sideChannels,
final ResourceHolder<WritableFrameChannel> outputChannelHolder,
-final ResourceHolder<MemoryAllocator> allocatorHolder,
-final RowSignature signature,
-final ClusterBy clusterBy,
+final ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
final FrameContext frameContext
)
{
@@ -75,15 +69,7 @@ public class GroupByPreShuffleFrameProcessorFactory extends BaseLeafFrameProcessorFactory
frameContext.groupByStrategySelector(),
new JoinableFactoryWrapper(frameContext.joinableFactory()),
outputChannelHolder,
-new LazyResourceHolder<>(() -> Pair.of(
-FrameWriters.makeFrameWriterFactory(
-FrameType.ROW_BASED,
-allocatorHolder.get(),
-signature,
-clusterBy.getColumns()
-),
-allocatorHolder
-)),
+new LazyResourceHolder<>(() -> Pair.of(frameWriterFactoryHolder.get(), frameWriterFactoryHolder)),
frameContext.memoryParameters().getBroadcastJoinMemory()
);
}


@@ -22,12 +22,13 @@ package org.apache.druid.msq.querykit.groupby;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Preconditions;
import org.apache.druid.frame.key.ClusterBy;
-import org.apache.druid.frame.key.SortColumn;
+import org.apache.druid.frame.key.KeyColumn;
+import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.msq.input.stage.StageInputSpec;
-import org.apache.druid.msq.kernel.MaxCountShuffleSpec;
+import org.apache.druid.msq.kernel.MixShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
import org.apache.druid.msq.kernel.StageDefinition;
@@ -80,6 +81,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
queryKit,
queryId,
+originalQuery.context(),
originalQuery.getDataSource(),
originalQuery.getQuerySegmentSpec(),
originalQuery.getFilter(),
@@ -118,7 +120,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.singlePartition();
shuffleSpecFactoryPostAggregation = ShuffleSpecFactories.singlePartition();
} else if (doOrderBy) {
-shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.subQueryWithMaxWorkerCount(maxWorkerCount);
+shuffleSpecFactoryPreAggregation = ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount);
shuffleSpecFactoryPostAggregation = doLimitOrOffset
? ShuffleSpecFactories.singlePartition()
: resultShuffleSpecFactory;
@@ -162,7 +164,7 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
.inputs(new StageInputSpec(firstStageNumber + 1))
.signature(resultSignature)
.maxWorkerCount(1)
-.shuffleSpec(new MaxCountShuffleSpec(ClusterBy.none(), 1, false))
+.shuffleSpec(MixShuffleSpec.instance())
.processorFactory(
new OffsetLimitFrameProcessorFactory(
limitSpec.getOffset(),
@@ -221,10 +223,10 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
*/
static ClusterBy computeIntermediateClusterBy(final GroupByQuery query)
{
-final List<SortColumn> columns = new ArrayList<>();
+final List<KeyColumn> columns = new ArrayList<>();
for (final DimensionSpec dimension : query.getDimensions()) {
-columns.add(new SortColumn(dimension.getOutputName(), false));
+columns.add(new KeyColumn(dimension.getOutputName(), KeyOrder.ASCENDING));
}
// Note: ignoring time because we assume granularity = all.
@@ -240,13 +242,15 @@ public class GroupByQueryKit implements QueryKit<GroupByQuery>
final DefaultLimitSpec defaultLimitSpec = (DefaultLimitSpec) query.getLimitSpec();
if (!defaultLimitSpec.getColumns().isEmpty()) {
-final List<SortColumn> clusterByColumns = new ArrayList<>();
+final List<KeyColumn> clusterByColumns = new ArrayList<>();
for (final OrderByColumnSpec orderBy : defaultLimitSpec.getColumns()) {
clusterByColumns.add(
-new SortColumn(
+new KeyColumn(
orderBy.getDimension(),
orderBy.getDirection() == OrderByColumnSpec.Direction.DESCENDING
+? KeyOrder.DESCENDING
+: KeyOrder.ASCENDING
)
);
}


@@ -25,19 +25,15 @@ import com.fasterxml.jackson.annotation.JsonTypeName;
import com.google.common.base.Preconditions;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import org.apache.druid.collections.ResourceHolder;
-import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.channel.WritableFrameChannel;
-import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.processor.FrameProcessor;
-import org.apache.druid.frame.write.FrameWriters;
+import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.Pair;
import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.kernel.FrameContext;
import org.apache.druid.msq.querykit.BaseLeafFrameProcessorFactory;
import org.apache.druid.msq.querykit.LazyResourceHolder;
import org.apache.druid.query.scan.ScanQuery;
-import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.join.JoinableFactoryWrapper;
import javax.annotation.Nullable;
@@ -78,9 +74,7 @@ public class ScanQueryFrameProcessorFactory extends BaseLeafFrameProcessorFactory
ReadableInput baseInput,
Int2ObjectMap<ReadableInput> sideChannels,
ResourceHolder<WritableFrameChannel> outputChannelHolder,
-ResourceHolder<MemoryAllocator> allocatorHolder,
-RowSignature signature,
-ClusterBy clusterBy,
+ResourceHolder<FrameWriterFactory> frameWriterFactoryHolder,
FrameContext frameContext
)
{
@@ -90,15 +84,7 @@ public class ScanQueryFrameProcessorFactory extends BaseLeafFrameProcessorFactory
sideChannels,
new JoinableFactoryWrapper(frameContext.joinableFactory()),
outputChannelHolder,
-new LazyResourceHolder<>(() -> Pair.of(
-FrameWriters.makeFrameWriterFactory(
-FrameType.ROW_BASED,
-allocatorHolder.get(),
-signature,
-clusterBy.getColumns()
-),
-allocatorHolder
-)),
+new LazyResourceHolder<>(() -> Pair.of(frameWriterFactoryHolder.get(), frameWriterFactoryHolder)),
runningCountForLimit,
frameContext.memoryParameters().getBroadcastJoinMemory(),
frameContext.jsonMapper()

@ -22,10 +22,11 @@ package org.apache.druid.msq.querykit.scan;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.granularity.Granularity; import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.kernel.MaxCountShuffleSpec; import org.apache.druid.msq.kernel.MixShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.ShuffleSpec;
@ -70,8 +71,8 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
/** /**
* We ignore the resultShuffleSpecFactory in case: * We ignore the resultShuffleSpecFactory in case:
* 1. There is no cluster by * 1. There is no cluster by
* 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec * 2. This is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
*/ */
// No ordering, but there is a limit or an offset. These work by funneling everything through a single partition. // No ordering, but there is a limit or an offset. These work by funneling everything through a single partition.
// So there is no point in forcing any particular partitioning. Since everything is funneled into a single // So there is no point in forcing any particular partitioning. Since everything is funneled into a single
@ -90,6 +91,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource( final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
queryKit, queryKit,
queryId, queryId,
originalQuery.context(),
originalQuery.getDataSource(), originalQuery.getDataSource(),
originalQuery.getQuerySegmentSpec(), originalQuery.getQuerySegmentSpec(),
originalQuery.getFilter(), originalQuery.getFilter(),
@ -112,26 +114,26 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
// 1. There is no cluster by // 1. There is no cluster by
// 2. There is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec // 2. There is an offset which means everything gets funneled into a single partition hence we use MaxCountShuffleSpec
if (queryToRun.getOrderBys().isEmpty() && hasLimitOrOffset) { if (queryToRun.getOrderBys().isEmpty() && hasLimitOrOffset) {
shuffleSpec = new MaxCountShuffleSpec(ClusterBy.none(), 1, false); shuffleSpec = MixShuffleSpec.instance();
signatureToUse = scanSignature; signatureToUse = scanSignature;
} else { } else {
final RowSignature.Builder signatureBuilder = RowSignature.builder().addAll(scanSignature); final RowSignature.Builder signatureBuilder = RowSignature.builder().addAll(scanSignature);
final Granularity segmentGranularity = final Granularity segmentGranularity =
QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext()); QueryKitUtils.getSegmentGranularityFromContext(jsonMapper, queryToRun.getContext());
final List<SortColumn> clusterByColumns = new ArrayList<>(); final List<KeyColumn> clusterByColumns = new ArrayList<>();
// Add regular orderBys. // Add regular orderBys.
for (final ScanQuery.OrderBy orderBy : queryToRun.getOrderBys()) { for (final ScanQuery.OrderBy orderBy : queryToRun.getOrderBys()) {
clusterByColumns.add( clusterByColumns.add(
new SortColumn( new KeyColumn(
orderBy.getColumnName(), orderBy.getColumnName(),
orderBy.getOrder() == ScanQuery.Order.DESCENDING orderBy.getOrder() == ScanQuery.Order.DESCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING
) )
); );
} }
// Add partition boosting column. // Add partition boosting column.
clusterByColumns.add(new SortColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, false)); clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG); signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);
final ClusterBy clusterBy = final ClusterBy clusterBy =
@ -159,7 +161,7 @@ public class ScanQueryKit implements QueryKit<ScanQuery>
.inputs(new StageInputSpec(firstStageNumber)) .inputs(new StageInputSpec(firstStageNumber))
.signature(signatureToUse) .signature(signatureToUse)
.maxWorkerCount(1) .maxWorkerCount(1)
.shuffleSpec(new MaxCountShuffleSpec(ClusterBy.none(), 1, false)) .shuffleSpec(MixShuffleSpec.instance())
.processorFactory( .processorFactory(
new OffsetLimitFrameProcessorFactory( new OffsetLimitFrameProcessorFactory(
queryToRun.getScanRowsOffset(), queryToRun.getScanRowsOffset(),
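A minimal sketch (not part of the commit) of the shuffle-spec selection shown above in ScanQueryKit, using only types that appear in this diff; the class and method names here are hypothetical, and the sorted branch's real ShuffleSpec comes from the caller-provided resultShuffleSpecFactory, which is not shown.

import java.util.ArrayList;
import java.util.List;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.msq.kernel.MixShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleSpec;
import org.apache.druid.query.scan.ScanQuery;

class ScanShuffleSketch
{
  // With a limit/offset and no ORDER BY, everything funnels into a single unsorted partition.
  // Otherwise each ORDER BY column becomes a KeyColumn of a ClusterBy (the replacement for SortColumn).
  static void describeShuffle(final ScanQuery query, final boolean hasLimitOrOffset)
  {
    if (query.getOrderBys().isEmpty() && hasLimitOrOffset) {
      final ShuffleSpec spec = MixShuffleSpec.instance();
      System.out.println(spec);
    } else {
      final List<KeyColumn> keys = new ArrayList<>();
      for (final ScanQuery.OrderBy orderBy : query.getOrderBys()) {
        keys.add(
            new KeyColumn(
                orderBy.getColumnName(),
                orderBy.getOrder() == ScanQuery.Order.DESCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING
            )
        );
      }
      // Zero bucket-by columns; the concrete ShuffleSpec for this branch is built elsewhere.
      System.out.println(new ClusterBy(keys, 0));
    }
  }
}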


@ -23,7 +23,6 @@ import com.google.common.base.Preconditions;
import org.apache.commons.io.IOUtils; import org.apache.commons.io.IOUtils;
import org.apache.druid.frame.channel.ReadableFrameChannel; import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel; import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel;
import org.apache.druid.frame.processor.DurableStorageOutputChannelFactory;
import org.apache.druid.frame.util.DurableStorageUtils; import org.apache.druid.frame.util.DurableStorageUtils;
import org.apache.druid.java.util.common.IOE; import org.apache.druid.java.util.common.IOE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;


@ -17,7 +17,7 @@
* under the License. * under the License.
*/ */
package org.apache.druid.frame.processor; package org.apache.druid.msq.shuffle;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers; import com.google.common.base.Suppliers;
@ -32,6 +32,9 @@ import org.apache.druid.frame.channel.ReadableInputStreamFrameChannel;
import org.apache.druid.frame.channel.WritableFrameFileChannel; import org.apache.druid.frame.channel.WritableFrameFileChannel;
import org.apache.druid.frame.file.FrameFileFooter; import org.apache.druid.frame.file.FrameFileFooter;
import org.apache.druid.frame.file.FrameFileWriter; import org.apache.druid.frame.file.FrameFileWriter;
import org.apache.druid.frame.processor.OutputChannel;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.PartitionedOutputChannel;
import org.apache.druid.frame.util.DurableStorageUtils; import org.apache.druid.frame.util.DurableStorageUtils;
import org.apache.druid.java.util.common.FileUtils; import org.apache.druid.java.util.common.FileUtils;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;


@ -26,7 +26,6 @@ import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.query.QueryContexts; import org.apache.druid.query.QueryContexts;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.Arrays; import java.util.Arrays;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;


@ -62,7 +62,6 @@ import org.apache.druid.sql.calcite.table.RowSignatures;
import org.joda.time.Interval; import org.joda.time.Interval;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;


@ -108,6 +108,7 @@ public class MSQTaskSqlEngine implements SqlEngine
{ {
switch (feature) { switch (feature) {
case ALLOW_BINDABLE_PLAN: case ALLOW_BINDABLE_PLAN:
case ALLOW_BROADCAST_RIGHTY_JOIN:
case TIMESERIES_QUERY: case TIMESERIES_QUERY:
case TOPN_QUERY: case TOPN_QUERY:
case TIME_BOUNDARY_QUERY: case TIME_BOUNDARY_QUERY:


@ -38,7 +38,6 @@ import org.apache.druid.msq.test.MSQTestBase;
import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.test.MSQTestFileUtils;
import org.apache.druid.query.InlineDataSource; import org.apache.druid.query.InlineDataSource;
import org.apache.druid.query.QueryDataSource; import org.apache.druid.query.QueryDataSource;
import org.apache.druid.query.TableDataSource;
import org.apache.druid.query.aggregation.CountAggregatorFactory; import org.apache.druid.query.aggregation.CountAggregatorFactory;
import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory; import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
import org.apache.druid.query.aggregation.FilteredAggregatorFactory; import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
@ -62,6 +61,8 @@ import org.apache.druid.sql.SqlPlanningException;
import org.apache.druid.sql.calcite.expression.DruidExpression; import org.apache.druid.sql.calcite.expression.DruidExpression;
import org.apache.druid.sql.calcite.external.ExternalDataSource; import org.apache.druid.sql.calcite.external.ExternalDataSource;
import org.apache.druid.sql.calcite.filtration.Filtration; import org.apache.druid.sql.calcite.filtration.Filtration;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;
import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException; import org.apache.druid.sql.calcite.planner.UnsupportedSQLQueryException;
import org.apache.druid.sql.calcite.util.CalciteTests; import org.apache.druid.sql.calcite.util.CalciteTests;
import org.hamcrest.CoreMatchers; import org.hamcrest.CoreMatchers;
@ -425,8 +426,25 @@ public class MSQSelectTest extends MSQTestBase
} }
@Test @Test
public void testJoin() public void testBroadcastJoin()
{ {
testJoin(JoinAlgorithm.BROADCAST);
}
@Test
public void testSortMergeJoin()
{
testJoin(JoinAlgorithm.SORT_MERGE);
}
private void testJoin(final JoinAlgorithm joinAlgorithm)
{
final Map<String, Object> queryContext =
ImmutableMap.<String, Object>builder()
.putAll(context)
.put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, joinAlgorithm.toString())
.build();
final RowSignature resultSignature = RowSignature.builder() final RowSignature resultSignature = RowSignature.builder()
.add("dim2", ColumnType.STRING) .add("dim2", ColumnType.STRING)
.add("EXPR$1", ColumnType.DOUBLE) .add("EXPR$1", ColumnType.DOUBLE)
@ -460,7 +478,7 @@ public class MSQSelectTest extends MSQTestBase
.columns("dim2", "m1", "m2") .columns("dim2", "m1", "m2")
.context( .context(
defaultScanQueryContext( defaultScanQueryContext(
context, queryContext,
RowSignature.builder() RowSignature.builder()
.add("dim2", ColumnType.STRING) .add("dim2", ColumnType.STRING)
.add("m1", ColumnType.FLOAT) .add("m1", ColumnType.FLOAT)
@ -470,6 +488,7 @@ public class MSQSelectTest extends MSQTestBase
) )
.limit(10) .limit(10)
.build() .build()
.withOverriddenContext(queryContext)
), ),
new QueryDataSource( new QueryDataSource(
newScanQueryBuilder() newScanQueryBuilder()
@ -479,11 +498,12 @@ public class MSQSelectTest extends MSQTestBase
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST) .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context( .context(
defaultScanQueryContext( defaultScanQueryContext(
context, queryContext,
RowSignature.builder().add("m1", ColumnType.FLOAT).build() RowSignature.builder().add("m1", ColumnType.FLOAT).build()
) )
) )
.build() .build()
.withOverriddenContext(queryContext)
), ),
"j0.", "j0.",
equalsCondition( equalsCondition(
@ -525,10 +545,9 @@ public class MSQSelectTest extends MSQTestBase
new FieldAccessPostAggregator(null, "a0:count") new FieldAccessPostAggregator(null, "a0:count")
) )
) )
) )
) )
.setContext(context) .setContext(queryContext)
.build(); .build();
testSelectQuery() testSelectQuery()
@ -542,164 +561,23 @@ public class MSQSelectTest extends MSQTestBase
.setExpectedMSQSpec( .setExpectedMSQSpec(
MSQSpec.builder() MSQSpec.builder()
.query(query) .query(query)
.columnMappings(new ColumnMappings(ImmutableList.of( .columnMappings(
new ColumnMapping("d0", "dim2"), new ColumnMappings(
new ColumnMapping("a0", "EXPR$1") ImmutableList.of(
))) new ColumnMapping("d0", "dim2"),
new ColumnMapping("a0", "EXPR$1")
)
)
)
.tuningConfig(MSQTuningConfig.defaultConfig()) .tuningConfig(MSQTuningConfig.defaultConfig())
.build() .build()
) )
.setExpectedRowSignature(resultSignature) .setExpectedRowSignature(resultSignature)
.setExpectedResultRows(expectedResults) .setExpectedResultRows(expectedResults)
.setQueryContext(context) .setQueryContext(queryContext)
.setExpectedCountersForStageWorkerChannel( .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().totalFiles(1), 0, 0, "input0")
CounterSnapshotMatcher .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().rows(6).frames(1), 0, 0, "output")
.with().totalFiles(1), .setExpectedCountersForStageWorkerChannel(CounterSnapshotMatcher.with().rows(6).frames(1), 0, 0, "shuffle")
0, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "output"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "shuffle"
)
.verifyResults();
}
@Test
public void testBroadcastJoin()
{
final RowSignature resultSignature = RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("EXPR$1", ColumnType.DOUBLE)
.build();
final ImmutableList<Object[]> expectedResults;
if (NullHandling.sqlCompatible()) {
expectedResults = ImmutableList.of(
new Object[]{null, 4.0},
new Object[]{"", 3.0},
new Object[]{"a", 2.5},
new Object[]{"abc", 5.0}
);
} else {
expectedResults = ImmutableList.of(
new Object[]{null, 3.6666666666666665},
new Object[]{"a", 2.5},
new Object[]{"abc", 5.0}
);
}
final GroupByQuery query =
GroupByQuery.builder()
.setDataSource(
join(
new TableDataSource(CalciteTests.DATASOURCE1),
new QueryDataSource(
newScanQueryBuilder()
.dataSource(CalciteTests.DATASOURCE1)
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("dim2", "m1", "m2")
.context(
defaultScanQueryContext(
context,
RowSignature.builder()
.add("dim2", ColumnType.STRING)
.add("m1", ColumnType.FLOAT)
.add("m2", ColumnType.DOUBLE)
.build()
)
)
.limit(10)
.build()
),
"j0.",
equalsCondition(
DruidExpression.ofColumn(ColumnType.FLOAT, "m1"),
DruidExpression.ofColumn(ColumnType.FLOAT, "j0.m1")
),
JoinType.INNER
)
)
.setInterval(querySegmentSpec(Filtration.eternity()))
.setDimensions(new DefaultDimensionSpec("j0.dim2", "d0", ColumnType.STRING))
.setGranularity(Granularities.ALL)
.setAggregatorSpecs(
useDefault
? aggregators(
new DoubleSumAggregatorFactory("a0:sum", "j0.m2"),
new CountAggregatorFactory("a0:count")
)
: aggregators(
new DoubleSumAggregatorFactory("a0:sum", "j0.m2"),
new FilteredAggregatorFactory(
new CountAggregatorFactory("a0:count"),
not(selector("j0.m2", null, null)),
// Not sure why the name is only set in SQL-compatible null mode. Seems strange.
// May be due to JSON serialization: name is set on the serialized aggregator even
// if it was originally created with no name.
NullHandling.sqlCompatible() ? "a0:count" : null
)
)
)
.setPostAggregatorSpecs(
ImmutableList.of(
new ArithmeticPostAggregator(
"a0",
"quotient",
ImmutableList.of(
new FieldAccessPostAggregator(null, "a0:sum"),
new FieldAccessPostAggregator(null, "a0:count")
)
)
)
)
.setContext(context)
.build();
testSelectQuery()
.setSql(
"SELECT t1.dim2, AVG(t1.m2) FROM "
+ "foo "
+ "INNER JOIN (SELECT * FROM foo LIMIT 10) AS t1 "
+ "ON t1.m1 = foo.m1 "
+ "GROUP BY t1.dim2"
)
.setExpectedMSQSpec(
MSQSpec.builder()
.query(query)
.columnMappings(new ColumnMappings(ImmutableList.of(
new ColumnMapping("d0", "dim2"),
new ColumnMapping("a0", "EXPR$1")
)))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build()
)
.setExpectedRowSignature(resultSignature)
.setExpectedResultRows(expectedResults)
.setQueryContext(context)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().totalFiles(1),
0, 0, "input0"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "output"
)
.setExpectedCountersForStageWorkerChannel(
CounterSnapshotMatcher
.with().rows(6).frames(1),
0, 0, "shuffle"
)
.verifyResults(); .verifyResults();
} }
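A small sketch of how a caller opts into the new join algorithm through the SQL query context, mirroring the queryContext built in the test above; the class and method names are illustrative only, and "baseContext" stands in for whatever context the caller already has.

import com.google.common.collect.ImmutableMap;
import java.util.Map;
import org.apache.druid.sql.calcite.planner.JoinAlgorithm;
import org.apache.druid.sql.calcite.planner.PlannerContext;

class JoinAlgorithmContextSketch
{
  // Returns a copy of the given context that asks the planner to use sort-merge join.
  // Passing JoinAlgorithm.BROADCAST instead keeps the default broadcast behavior.
  static Map<String, Object> withSortMergeJoin(final Map<String, Object> baseContext)
  {
    return ImmutableMap.<String, Object>builder()
                       .putAll(baseContext)
                       .put(PlannerContext.CTX_SQL_JOIN_ALGORITHM, JoinAlgorithm.SORT_MERGE.toString())
                       .build();
  }
}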


@ -32,19 +32,19 @@ public class WorkerMemoryParametersTest
@Test @Test
public void test_oneWorkerInJvm_alone() public void test_oneWorkerInJvm_alone()
{ {
Assert.assertEquals(parameters(1, 41, 224_785_000, 100_650_000, 75_000_000), compute(1_000_000_000, 1, 1, 1, 0)); Assert.assertEquals(params(1, 41, 224_785_000, 100_650_000, 75_000_000), create(1_000_000_000, 1, 1, 1, 0, 0));
Assert.assertEquals(parameters(2, 13, 149_410_000, 66_900_000, 75_000_000), compute(1_000_000_000, 1, 2, 1, 0)); Assert.assertEquals(params(2, 13, 149_410_000, 66_900_000, 75_000_000), create(1_000_000_000, 1, 2, 1, 0, 0));
Assert.assertEquals(parameters(4, 3, 89_110_000, 39_900_000, 75_000_000), compute(1_000_000_000, 1, 4, 1, 0)); Assert.assertEquals(params(4, 3, 89_110_000, 39_900_000, 75_000_000), create(1_000_000_000, 1, 4, 1, 0, 0));
Assert.assertEquals(parameters(3, 2, 48_910_000, 21_900_000, 75_000_000), compute(1_000_000_000, 1, 8, 1, 0)); Assert.assertEquals(params(3, 2, 48_910_000, 21_900_000, 75_000_000), create(1_000_000_000, 1, 8, 1, 0, 0));
Assert.assertEquals(parameters(2, 2, 33_448_460, 14_976_922, 75_000_000), compute(1_000_000_000, 1, 12, 1, 0)); Assert.assertEquals(params(2, 2, 33_448_460, 14_976_922, 75_000_000), create(1_000_000_000, 1, 12, 1, 0, 0));
final MSQException e = Assert.assertThrows( final MSQException e = Assert.assertThrows(
MSQException.class, MSQException.class,
() -> compute(1_000_000_000, 1, 32, 1, 0) () -> create(1_000_000_000, 1, 32, 1, 0, 0)
); );
Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault()); Assert.assertEquals(new NotEnoughMemoryFault(1_588_044_000, 1_000_000_000, 750_000_000, 1, 32), e.getFault());
final MSQFault fault = Assert.assertThrows(MSQException.class, () -> compute(1_000_000_000, 2, 32, 1, 0)) final MSQFault fault = Assert.assertThrows(MSQException.class, () -> create(1_000_000_000, 2, 32, 1, 0, 0))
.getFault(); .getFault();
Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault); Assert.assertEquals(new NotEnoughMemoryFault(2024045333, 1_000_000_000, 750_000_000, 2, 32), fault);
@ -54,12 +54,12 @@ public class WorkerMemoryParametersTest
@Test @Test
public void test_oneWorkerInJvm_twoHundredWorkersInCluster() public void test_oneWorkerInJvm_twoHundredWorkersInCluster()
{ {
Assert.assertEquals(parameters(1, 83, 317_580_000, 142_200_000, 150_000_000), compute(2_000_000_000, 1, 1, 200, 0)); Assert.assertEquals(params(1, 83, 317_580_000, 142_200_000, 150_000_000), create(2_000_000_000, 1, 1, 200, 0, 0));
Assert.assertEquals(parameters(2, 27, 166_830_000, 74_700_000, 150_000_000), compute(2_000_000_000, 1, 2, 200, 0)); Assert.assertEquals(params(2, 27, 166_830_000, 74_700_000, 150_000_000), create(2_000_000_000, 1, 2, 200, 0, 0));
final MSQException e = Assert.assertThrows( final MSQException e = Assert.assertThrows(
MSQException.class, MSQException.class,
() -> compute(1_000_000_000, 1, 4, 200, 0) () -> create(1_000_000_000, 1, 4, 200, 0, 0)
); );
Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault()); Assert.assertEquals(new TooManyWorkersFault(200, 109), e.getFault());
@ -68,39 +68,69 @@ public class WorkerMemoryParametersTest
@Test @Test
public void test_fourWorkersInJvm_twoHundredWorkersInCluster() public void test_fourWorkersInJvm_twoHundredWorkersInCluster()
{ {
Assert.assertEquals( Assert.assertEquals(params(1, 150, 679_380_000, 304_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 0, 0));
parameters(1, 150, 679_380_000, 304_200_000, 168_750_000), Assert.assertEquals(params(2, 62, 543_705_000, 243_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 0, 0));
compute(9_000_000_000L, 4, 1, 200, 0) Assert.assertEquals(params(4, 22, 374_111_250, 167_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 0, 0));
); Assert.assertEquals(params(4, 14, 204_517_500, 91_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 0, 0));
Assert.assertEquals( Assert.assertEquals(params(4, 8, 68_842_500, 30_825_000, 168_750_000), create(9_000_000_000L, 4, 16, 200, 0, 0));
parameters(2, 62, 543_705_000, 243_450_000, 168_750_000),
compute(9_000_000_000L, 4, 2, 200, 0)
);
Assert.assertEquals(
parameters(4, 22, 374_111_250, 167_512_500, 168_750_000),
compute(9_000_000_000L, 4, 4, 200, 0)
);
Assert.assertEquals(parameters(4, 14, 204_517_500, 91_575_000, 168_750_000), compute(9_000_000_000L, 4, 8, 200, 0));
Assert.assertEquals(parameters(4, 8, 68_842_500, 30_825_000, 168_750_000), compute(9_000_000_000L, 4, 16, 200, 0));
final MSQException e = Assert.assertThrows( final MSQException e = Assert.assertThrows(
MSQException.class, MSQException.class,
() -> compute(8_000_000_000L, 4, 32, 200, 0) () -> create(8_000_000_000L, 4, 32, 200, 0, 0)
); );
Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault()); Assert.assertEquals(new TooManyWorkersFault(200, 124), e.getFault());
// Make sure 107 actually works. (Verify the error message above.) // Make sure 124 actually works, and 125 doesn't. (Verify the error message above.)
Assert.assertEquals(parameters(4, 3, 28_140_000, 12_600_000, 150_000_000), compute(8_000_000_000L, 4, 32, 107, 0)); Assert.assertEquals(params(4, 3, 16_750_000, 7_500_000, 150_000_000), create(8_000_000_000L, 4, 32, 124, 0, 0));
final MSQException e2 = Assert.assertThrows(
MSQException.class,
() -> create(8_000_000_000L, 4, 32, 125, 0, 0)
);
Assert.assertEquals(new TooManyWorkersFault(125, 124), e2.getFault());
} }
@Test @Test
public void test_oneWorkerInJvm_negativeUsableMemory() public void test_fourWorkersInJvm_twoHundredWorkersInCluster_hashPartitions()
{ {
Exception e = Assert.assertThrows( Assert.assertEquals(
IllegalArgumentException.class, params(1, 150, 545_380_000, 244_200_000, 168_750_000), create(9_000_000_000L, 4, 1, 200, 200, 0));
() -> WorkerMemoryParameters.createInstance(100, -50, 1, 32, 1) Assert.assertEquals(
params(2, 62, 409_705_000, 183_450_000, 168_750_000), create(9_000_000_000L, 4, 2, 200, 200, 0));
Assert.assertEquals(
params(4, 22, 240_111_250, 107_512_500, 168_750_000), create(9_000_000_000L, 4, 4, 200, 200, 0));
Assert.assertEquals(
params(4, 14, 70_517_500, 31_575_000, 168_750_000), create(9_000_000_000L, 4, 8, 200, 200, 0));
final MSQException e = Assert.assertThrows(
MSQException.class,
() -> create(9_000_000_000L, 4, 16, 200, 200, 0)
); );
Assert.assertEquals(new TooManyWorkersFault(200, 138), e.getFault());
// Make sure 138 actually works, and 139 doesn't. (Verify the error message above.)
Assert.assertEquals(params(4, 8, 17_922_500, 8_025_000, 168_750_000), create(9_000_000_000L, 4, 16, 138, 138, 0));
final MSQException e2 = Assert.assertThrows(
MSQException.class,
() -> create(9_000_000_000L, 4, 16, 139, 139, 0)
);
Assert.assertEquals(new TooManyWorkersFault(139, 138), e2.getFault());
}
@Test
public void test_oneWorkerInJvm_oneByteUsableMemory()
{
final MSQException e = Assert.assertThrows(
MSQException.class,
() -> WorkerMemoryParameters.createInstance(1, 1, 1, 32, 1, 1)
);
Assert.assertEquals(new NotEnoughMemoryFault(554669334, 1, 1, 1, 1), e.getFault());
} }
@Test @Test
@ -109,7 +139,7 @@ public class WorkerMemoryParametersTest
EqualsVerifier.forClass(WorkerMemoryParameters.class).usingGetClass().verify(); EqualsVerifier.forClass(WorkerMemoryParameters.class).usingGetClass().verify();
} }
private static WorkerMemoryParameters parameters( private static WorkerMemoryParameters params(
final int superSorterMaxActiveProcessors, final int superSorterMaxActiveProcessors,
final int superSorterMaxChannelsPerProcessor, final int superSorterMaxChannelsPerProcessor,
final long appenderatorMemory, final long appenderatorMemory,
@ -126,11 +156,12 @@ public class WorkerMemoryParametersTest
); );
} }
private static WorkerMemoryParameters compute( private static WorkerMemoryParameters create(
final long maxMemoryInJvm, final long maxMemoryInJvm,
final int numWorkersInJvm, final int numWorkersInJvm,
final int numProcessingThreadsInJvm, final int numProcessingThreadsInJvm,
final int numInputWorkers, final int numInputWorkers,
final int numHashOutputPartitions,
final int totalLookUpFootprint final int totalLookUpFootprint
) )
{ {
@ -139,6 +170,7 @@ public class WorkerMemoryParametersTest
numWorkersInJvm, numWorkersInJvm,
numProcessingThreadsInJvm, numProcessingThreadsInJvm,
numInputWorkers, numInputWorkers,
numHashOutputPartitions,
totalLookUpFootprint totalLookUpFootprint
); );
} }


@ -31,6 +31,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
public class MSQFaultSerdeTest public class MSQFaultSerdeTest
{ {
@ -69,6 +70,7 @@ public class MSQFaultSerdeTest
assertFaultSerde(new TooManyClusteredByColumnsFault(10, 8, 1)); assertFaultSerde(new TooManyClusteredByColumnsFault(10, 8, 1));
assertFaultSerde(new TooManyInputFilesFault(15, 10, 5)); assertFaultSerde(new TooManyInputFilesFault(15, 10, 5));
assertFaultSerde(new TooManyPartitionsFault(10)); assertFaultSerde(new TooManyPartitionsFault(10));
assertFaultSerde(new TooManyRowsWithSameKeyFault(Arrays.asList("foo", 123), 1, 2));
assertFaultSerde(new TooManyWarningsFault(10, "the error")); assertFaultSerde(new TooManyWarningsFault(10, "the error"));
assertFaultSerde(new TooManyWorkersFault(10, 5)); assertFaultSerde(new TooManyWorkersFault(10, 5));
assertFaultSerde(new TooManyAttemptsForWorker(2, "taskId", 1, "rootError")); assertFaultSerde(new TooManyAttemptsForWorker(2, "taskId", 1, "rootError"));


@ -24,7 +24,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.indexer.TaskState; import org.apache.druid.indexer.TaskState;
import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter; import org.apache.druid.indexing.common.SingleFileTaskReportFileWriter;
import org.apache.druid.indexing.common.TaskReport; import org.apache.druid.indexing.common.TaskReport;
@ -35,7 +36,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.msq.indexing.error.MSQErrorReport; import org.apache.druid.msq.indexing.error.MSQErrorReport;
import org.apache.druid.msq.indexing.error.TooManyColumnsFault; import org.apache.druid.msq.indexing.error.TooManyColumnsFault;
import org.apache.druid.msq.kernel.MaxCountShuffleSpec; import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
@ -65,8 +66,8 @@ public class MSQTaskReportTest
.builder(0) .builder(0)
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L)) .processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L))
.shuffleSpec( .shuffleSpec(
new MaxCountShuffleSpec( new GlobalSortMaxCountShuffleSpec(
new ClusterBy(ImmutableList.of(new SortColumn("s", false)), 0), new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0),
2, 2,
false false
) )
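A hedged construction example for the renamed shuffle spec, matching how this test and the ones below build it; the inline comments describe the arguments only as far as the test code itself shows them.

import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
import org.apache.druid.msq.kernel.ShuffleSpec;

class ShuffleSpecExample
{
  static ShuffleSpec globalSortByS()
  {
    return new GlobalSortMaxCountShuffleSpec(
        new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0), // ascending sort key, no bucketing
        2,     // partition count limit, as the "MaxCount" in the class name suggests
        false  // boolean flag passed as false throughout the tests in this diff
    );
  }
}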


@ -23,7 +23,8 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.msq.guice.MSQIndexingModule; import org.apache.druid.msq.guice.MSQIndexingModule;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory; import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.segment.TestHelper; import org.apache.druid.segment.TestHelper;
@ -45,8 +46,8 @@ public class QueryDefinitionTest
.builder(0) .builder(0)
.processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L)) .processorFactory(new OffsetLimitFrameProcessorFactory(0, 1L))
.shuffleSpec( .shuffleSpec(
new MaxCountShuffleSpec( new GlobalSortMaxCountShuffleSpec(
new ClusterBy(ImmutableList.of(new SortColumn("s", false)), 0), new ClusterBy(ImmutableList.of(new KeyColumn("s", KeyOrder.ASCENDING)), 0),
2, 2,
false false
) )


@ -23,7 +23,8 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet;
import nl.jqno.equalsverifier.EqualsVerifier; import nl.jqno.equalsverifier.EqualsVerifier;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.exec.Limits; import org.apache.druid.msq.exec.Limits;
import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpec;
@ -60,7 +61,7 @@ public class StageDefinitionTest
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
); );
Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null)); Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionBoundariesForShuffle(null));
} }
@Test @Test
@ -72,16 +73,19 @@ public class StageDefinitionTest
ImmutableSet.of(), ImmutableSet.of(),
new OffsetLimitFrameProcessorFactory(0, 1L), new OffsetLimitFrameProcessorFactory(0, 1L),
RowSignature.empty(), RowSignature.empty(),
new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 1), 2, false), new GlobalSortMaxCountShuffleSpec(
new ClusterBy(ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)), 0),
2,
false
),
1, 1,
false, false,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
); );
Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionsForShuffle(null)); Assert.assertThrows(ISE.class, () -> stageDefinition.generatePartitionBoundariesForShuffle(null));
} }
@Test @Test
public void testGeneratePartitionsForNonNullShuffleWithNonNullCollector() public void testGeneratePartitionsForNonNullShuffleWithNonNullCollector()
{ {
@ -91,7 +95,11 @@ public class StageDefinitionTest
ImmutableSet.of(), ImmutableSet.of(),
new OffsetLimitFrameProcessorFactory(0, 1L), new OffsetLimitFrameProcessorFactory(0, 1L),
RowSignature.empty(), RowSignature.empty(),
new MaxCountShuffleSpec(new ClusterBy(ImmutableList.of(new SortColumn("test", false)), 0), 1, false), new GlobalSortMaxCountShuffleSpec(
new ClusterBy(ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)), 0),
1,
false
),
1, 1,
false, false,
Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER Limits.DEFAULT_MAX_INPUT_BYTES_PER_WORKER
@ -99,8 +107,8 @@ public class StageDefinitionTest
Assert.assertThrows( Assert.assertThrows(
ISE.class, ISE.class,
() -> stageDefinition.generatePartitionsForShuffle(ClusterByStatisticsCollectorImpl.create(new ClusterBy( () -> stageDefinition.generatePartitionBoundariesForShuffle(ClusterByStatisticsCollectorImpl.create(new ClusterBy(
ImmutableList.of(new SortColumn("test", false)), ImmutableList.of(new KeyColumn("test", KeyOrder.ASCENDING)),
1 1
), RowSignature.builder().add("test", ColumnType.STRING).build(), 1000, 100, false, false)) ), RowSignature.builder().add("test", ColumnType.STRING).build(), 1000, 100, false, false))
); );


@ -22,13 +22,14 @@ package org.apache.druid.msq.kernel.controller;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.SortColumn; import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.msq.input.InputSpec; import org.apache.druid.msq.input.InputSpec;
import org.apache.druid.msq.input.stage.StageInputSpec; import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.kernel.FrameProcessorFactory; import org.apache.druid.msq.kernel.FrameProcessorFactory;
import org.apache.druid.msq.kernel.MaxCountShuffleSpec; import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
import org.apache.druid.msq.kernel.ShuffleSpec; import org.apache.druid.msq.kernel.ShuffleSpec;
@ -115,10 +116,10 @@ public class MockQueryDefinitionBuilder
ShuffleSpec shuffleSpec; ShuffleSpec shuffleSpec;
if (shuffling) { if (shuffling) {
shuffleSpec = new MaxCountShuffleSpec( shuffleSpec = new GlobalSortMaxCountShuffleSpec(
new ClusterBy( new ClusterBy(
ImmutableList.of( ImmutableList.of(
new SortColumn(SHUFFLE_KEY_COLUMN, false) new KeyColumn(SHUFFLE_KEY_COLUMN, KeyOrder.ASCENDING)
), ),
0 0
), ),


@ -22,19 +22,18 @@ package org.apache.druid.msq.querykit.scan;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.MoreExecutors;
import it.unimi.dsi.fastutil.ints.Int2ObjectMaps; import it.unimi.dsi.fastutil.ints.Int2ObjectMaps;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.collections.ResourceHolder; import org.apache.druid.collections.ResourceHolder;
import org.apache.druid.frame.Frame; import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType; import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.ArenaMemoryAllocator; import org.apache.druid.frame.allocation.ArenaMemoryAllocator;
import org.apache.druid.frame.allocation.HeapMemoryAllocator; import org.apache.druid.frame.allocation.HeapMemoryAllocator;
import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
import org.apache.druid.frame.channel.BlockingQueueFrameChannel; import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel; import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.processor.FrameProcessorExecutor; import org.apache.druid.frame.processor.FrameProcessorExecutor;
import org.apache.druid.frame.read.FrameReader; import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.testutil.FrameSequenceBuilder; import org.apache.druid.frame.testutil.FrameSequenceBuilder;
import org.apache.druid.frame.testutil.FrameTestUtil; import org.apache.druid.frame.testutil.FrameTestUtil;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory; import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.frame.write.FrameWriters; import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.jackson.DefaultObjectMapper; import org.apache.druid.jackson.DefaultObjectMapper;
@ -46,10 +45,10 @@ import org.apache.druid.msq.input.ReadableInput;
import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.StagePartition; import org.apache.druid.msq.kernel.StagePartition;
import org.apache.druid.msq.querykit.LazyResourceHolder; import org.apache.druid.msq.querykit.LazyResourceHolder;
import org.apache.druid.msq.test.LimitedFrameWriterFactory;
import org.apache.druid.query.Druids; import org.apache.druid.query.Druids;
import org.apache.druid.query.scan.ScanQuery; import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.query.spec.MultipleIntervalSegmentSpec; import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.TestIndex; import org.apache.druid.segment.TestIndex;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter; import org.apache.druid.segment.incremental.IncrementalIndexStorageAdapter;
@ -117,10 +116,10 @@ public class ScanQueryFrameProcessorTest extends InitializedNullHandlingTest
final StagePartition stagePartition = new StagePartition(new StageId("query", 0), 0); final StagePartition stagePartition = new StagePartition(new StageId("query", 0), 0);
// Limit output frames to 1 row to ensure we test edge cases // Limit output frames to 1 row to ensure we test edge cases
final FrameWriterFactory frameWriterFactory = limitedFrameWriterFactory( final FrameWriterFactory frameWriterFactory = new LimitedFrameWriterFactory(
FrameWriters.makeFrameWriterFactory( FrameWriters.makeFrameWriterFactory(
FrameType.ROW_BASED, FrameType.ROW_BASED,
HeapMemoryAllocator.unlimited(), new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
signature, signature,
Collections.emptyList() Collections.emptyList()
), ),
@ -171,72 +170,4 @@ public class ScanQueryFrameProcessorTest extends InitializedNullHandlingTest
Assert.assertEquals(adapter.getNumRows(), (long) retVal.get()); Assert.assertEquals(adapter.getNumRows(), (long) retVal.get());
} }
/**
* Wraps a {@link FrameWriterFactory}, creating a new factory that returns {@link FrameWriter} which write
* a limited number of rows.
*/
private static FrameWriterFactory limitedFrameWriterFactory(final FrameWriterFactory baseFactory, final int rowLimit)
{
return new FrameWriterFactory()
{
@Override
public FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory)
{
return new LimitedFrameWriter(baseFactory.newFrameWriter(columnSelectorFactory), rowLimit);
}
@Override
public long allocatorCapacity()
{
return baseFactory.allocatorCapacity();
}
};
}
private static class LimitedFrameWriter implements FrameWriter
{
private final FrameWriter baseWriter;
private final int rowLimit;
public LimitedFrameWriter(FrameWriter baseWriter, int rowLimit)
{
this.baseWriter = baseWriter;
this.rowLimit = rowLimit;
}
@Override
public boolean addSelection()
{
if (baseWriter.getNumRows() >= rowLimit) {
return false;
} else {
return baseWriter.addSelection();
}
}
@Override
public int getNumRows()
{
return baseWriter.getNumRows();
}
@Override
public long getTotalSize()
{
return baseWriter.getTotalSize();
}
@Override
public long writeTo(WritableMemory memory, long position)
{
return baseWriter.writeTo(memory, position);
}
@Override
public void close()
{
baseWriter.close();
}
}
} }


@ -17,8 +17,9 @@
* under the License. * under the License.
*/ */
package org.apache.druid.frame.processor; package org.apache.druid.msq.shuffle;
import org.apache.druid.frame.processor.OutputChannelFactoryTest;
import org.apache.druid.storage.local.LocalFileStorageConnector; import org.apache.druid.storage.local.LocalFileStorageConnector;
import org.junit.ClassRule; import org.junit.ClassRule;
import org.junit.rules.TemporaryFolder; import org.junit.rules.TemporaryFolder;


@ -28,10 +28,11 @@ import com.google.common.math.LongMath;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils; import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.RowKeyReader; import org.apache.druid.frame.key.RowKeyReader;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
@ -73,16 +74,20 @@ public class ClusterByStatisticsCollectorImplTest extends InitializedNullHandlin
.build(); .build();
private static final ClusterBy CLUSTER_BY_X = new ClusterBy( private static final ClusterBy CLUSTER_BY_X = new ClusterBy(
ImmutableList.of(new SortColumn("x", false)), ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)),
0 0
); );
private static final ClusterBy CLUSTER_BY_XY_BUCKET_BY_X = new ClusterBy( private static final ClusterBy CLUSTER_BY_XY_BUCKET_BY_X = new ClusterBy(
ImmutableList.of(new SortColumn("x", false), new SortColumn("y", false)), ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING), new KeyColumn("y", KeyOrder.ASCENDING)),
1 1
); );
private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy( private static final ClusterBy CLUSTER_BY_XYZ_BUCKET_BY_X = new ClusterBy(
ImmutableList.of(new SortColumn("x", false), new SortColumn("y", false), new SortColumn("z", false)), ImmutableList.of(
new KeyColumn("x", KeyOrder.ASCENDING),
new KeyColumn("y", KeyOrder.ASCENDING),
new KeyColumn("z", KeyOrder.ASCENDING)
),
1 1
); );


@ -22,9 +22,10 @@ package org.apache.druid.msq.statistics;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling; import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils; import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
@ -38,7 +39,7 @@ import java.util.NoSuchElementException;
public class DelegateOrMinKeyCollectorTest public class DelegateOrMinKeyCollectorTest
{ {
private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0); private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
private final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build(); private final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();
private final Comparator<RowKey> comparator = clusterBy.keyComparator(); private final Comparator<RowKey> comparator = clusterBy.keyComparator();


@ -24,8 +24,9 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.hamcrest.MatcherAssert; import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers; import org.hamcrest.Matchers;
@ -40,7 +41,7 @@ import java.util.NoSuchElementException;
public class DistinctKeyCollectorTest public class DistinctKeyCollectorTest
{ {
private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0); private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
private final Comparator<RowKey> comparator = clusterBy.keyComparator(); private final Comparator<RowKey> comparator = clusterBy.keyComparator();
private final int numKeys = 500_000; private final int numKeys = 500_000;


@ -23,9 +23,10 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils; import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
@ -289,8 +290,8 @@ public class KeyCollectorTestUtils
private static RowKey createSingleLongKey(final long n) private static RowKey createSingleLongKey(final long n)
{ {
final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build(); final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();
final List<SortColumn> sortColumns = ImmutableList.of(new SortColumn("x", false)); final List<KeyColumn> keyColumns = ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING));
final RowSignature keySignature = KeyTestUtils.createKeySignature(sortColumns, signature); final RowSignature keySignature = KeyTestUtils.createKeySignature(keyColumns, signature);
return KeyTestUtils.createKey(keySignature, n); return KeyTestUtils.createKey(keySignature, n);
} }
} }


@ -24,9 +24,10 @@ import org.apache.druid.common.config.NullHandling;
import org.apache.druid.frame.key.ClusterBy; import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartition; import org.apache.druid.frame.key.ClusterByPartition;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
import org.apache.druid.frame.key.KeyTestUtils; import org.apache.druid.frame.key.KeyTestUtils;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.key.SortColumn;
import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.Pair;
import org.apache.druid.segment.column.ColumnType; import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature; import org.apache.druid.segment.column.RowSignature;
@ -41,7 +42,7 @@ import java.util.NoSuchElementException;
public class QuantilesSketchKeyCollectorTest public class QuantilesSketchKeyCollectorTest
{ {
private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0); private final ClusterBy clusterBy = new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0);
private final Comparator<RowKey> comparator = clusterBy.keyComparator(); private final Comparator<RowKey> comparator = clusterBy.keyComparator();
private final int numKeys = 500_000; private final int numKeys = 500_000;
@ -167,12 +168,12 @@ public class QuantilesSketchKeyCollectorTest
@Test @Test
public void testAverageKeyLength() public void testAverageKeyLength()
{ {
final QuantilesSketchKeyCollector collector = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector(); final QuantilesSketchKeyCollector collector =
QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
final QuantilesSketchKeyCollector other = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector(); final QuantilesSketchKeyCollector other = QuantilesSketchKeyCollectorFactory.create(clusterBy).newKeyCollector();
RowSignature smallKeySignature = KeyTestUtils.createKeySignature( RowSignature smallKeySignature = KeyTestUtils.createKeySignature(
new ClusterBy(ImmutableList.of(new SortColumn("x", false)), 0).getColumns(), new ClusterBy(ImmutableList.of(new KeyColumn("x", KeyOrder.ASCENDING)), 0).getColumns(),
RowSignature.builder().add("x", ColumnType.LONG).build() RowSignature.builder().add("x", ColumnType.LONG).build()
); );
RowKey smallKey = KeyTestUtils.createKey(smallKeySignature, 1L); RowKey smallKey = KeyTestUtils.createKey(smallKeySignature, 1L);
@ -180,11 +181,12 @@ public class QuantilesSketchKeyCollectorTest
RowSignature largeKeySignature = KeyTestUtils.createKeySignature( RowSignature largeKeySignature = KeyTestUtils.createKeySignature(
new ClusterBy( new ClusterBy(
ImmutableList.of( ImmutableList.of(
new SortColumn("x", false), new KeyColumn("x", KeyOrder.ASCENDING),
new SortColumn("y", false), new KeyColumn("y", KeyOrder.ASCENDING),
new SortColumn("z", false) new KeyColumn("z", KeyOrder.ASCENDING)
), ),
0).getColumns(), 0
).getColumns(),
RowSignature.builder() RowSignature.builder()
.add("x", ColumnType.LONG) .add("x", ColumnType.LONG)
.add("y", ColumnType.LONG) .add("y", ColumnType.LONG)
@ -201,7 +203,11 @@ public class QuantilesSketchKeyCollectorTest
Assert.assertEquals(largeKey.estimatedObjectSizeBytes(), other.getAverageKeyLength(), 0); Assert.assertEquals(largeKey.estimatedObjectSizeBytes(), other.getAverageKeyLength(), 0);
collector.addAll(other); collector.addAll(other);
Assert.assertEquals((smallKey.estimatedObjectSizeBytes() * 3 + largeKey.estimatedObjectSizeBytes() * 5) / 8.0, collector.getAverageKeyLength(), 0); Assert.assertEquals(
(smallKey.estimatedObjectSizeBytes() * 3 + largeKey.estimatedObjectSizeBytes() * 5) / 8.0,
collector.getAverageKeyLength(),
0
);
} }
@Test @Test


@ -39,8 +39,6 @@ import org.junit.Ignore;
*/ */
public class CalciteSelectQueryTestMSQ extends CalciteQueryTest public class CalciteSelectQueryTestMSQ extends CalciteQueryTest
{ {
private MSQTestOverlordServiceClient indexingServiceClient;
private TestGroupByBuffers groupByBuffers; private TestGroupByBuffers groupByBuffers;
@Before @Before
@ -76,9 +74,10 @@ public class CalciteSelectQueryTestMSQ extends CalciteQueryTest
2, 2,
10, 10,
2, 2,
0,
0 0
); );
indexingServiceClient = new MSQTestOverlordServiceClient( final MSQTestOverlordServiceClient indexingServiceClient = new MSQTestOverlordServiceClient(
queryJsonMapper, queryJsonMapper,
injector, injector,
new MSQTestTaskActionClient(queryJsonMapper), new MSQTestTaskActionClient(queryJsonMapper),


@ -0,0 +1,114 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.msq.test;
import org.apache.datasketches.memory.WritableMemory;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.column.RowSignature;
public class LimitedFrameWriterFactory implements FrameWriterFactory
{
private final FrameWriterFactory baseFactory;
private final int rowLimit;
/**
* Wraps a {@link FrameWriterFactory}, creating a new factory that returns {@link FrameWriter} which write
* a limited number of rows.
*/
public LimitedFrameWriterFactory(FrameWriterFactory baseFactory, int rowLimit)
{
this.baseFactory = baseFactory;
this.rowLimit = rowLimit;
}
@Override
public FrameWriter newFrameWriter(ColumnSelectorFactory columnSelectorFactory)
{
return new LimitedFrameWriter(baseFactory.newFrameWriter(columnSelectorFactory), rowLimit);
}
@Override
public long allocatorCapacity()
{
return baseFactory.allocatorCapacity();
}
@Override
public RowSignature signature()
{
return baseFactory.signature();
}
@Override
public FrameType frameType()
{
return baseFactory.frameType();
}
private static class LimitedFrameWriter implements FrameWriter
{
private final FrameWriter baseWriter;
private final int rowLimit;
public LimitedFrameWriter(FrameWriter baseWriter, int rowLimit)
{
this.baseWriter = baseWriter;
this.rowLimit = rowLimit;
}
@Override
public boolean addSelection()
{
if (baseWriter.getNumRows() >= rowLimit) {
return false;
} else {
return baseWriter.addSelection();
}
}
@Override
public int getNumRows()
{
return baseWriter.getNumRows();
}
@Override
public long getTotalSize()
{
return baseWriter.getTotalSize();
}
@Override
public long writeTo(WritableMemory memory, long position)
{
return baseWriter.writeTo(memory, position);
}
@Override
public void close()
{
baseWriter.close();
}
}
}
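A usage sketch for this wrapper, modeled on the ScanQueryFrameProcessorTest change earlier in the diff; the row limit of 1 and the heap allocator come from that test, while the single-column signature here is a made-up example.

import java.util.Collections;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.HeapMemoryAllocator;
import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.msq.test.LimitedFrameWriterFactory;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;

class LimitedFrameWriterFactoryExample
{
  static FrameWriterFactory oneRowPerFrame()
  {
    final RowSignature signature = RowSignature.builder().add("x", ColumnType.LONG).build();
    return new LimitedFrameWriterFactory(
        FrameWriters.makeFrameWriterFactory(
            FrameType.ROW_BASED,
            new SingleMemoryAllocatorFactory(HeapMemoryAllocator.unlimited()),
            signature,
            Collections.emptyList() // no key columns
        ),
        1 // each frame stops accepting rows after a single row, exercising frame-boundary edge cases
    );
  }
}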


@ -286,6 +286,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
2, 2,
10, 10,
2, 2,
1,
0 0
) )
); );


@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.allocation;
/**
* Creates {@link ArenaMemoryAllocator} on each call to {@link #newAllocator()}.
*/
public class ArenaMemoryAllocatorFactory implements MemoryAllocatorFactory
{
private final int capacity;
public ArenaMemoryAllocatorFactory(final int capacity)
{
this.capacity = capacity;
}
@Override
public MemoryAllocator newAllocator()
{
return ArenaMemoryAllocator.createOnHeap(capacity);
}
@Override
public long allocatorCapacity()
{
return capacity;
}
}
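A minimal sketch of this factory's contract: each call hands back a fresh on-heap arena of the configured capacity. The 1,000,000-byte capacity below is an arbitrary example value.

import org.apache.druid.frame.allocation.ArenaMemoryAllocatorFactory;
import org.apache.druid.frame.allocation.MemoryAllocator;
import org.apache.druid.frame.allocation.MemoryAllocatorFactory;

class ArenaFactoryExample
{
  static void demo()
  {
    final MemoryAllocatorFactory factory = new ArenaMemoryAllocatorFactory(1_000_000);
    final MemoryAllocator first = factory.newAllocator();  // a new on-heap arena
    final MemoryAllocator second = factory.newAllocator(); // another, independent arena
    System.out.println(factory.allocatorCapacity());       // 1000000, the capacity each allocator reports
  }
}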


@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.allocation;
/**
* Factory for {@link MemoryAllocator}.
*
* Used by {@link org.apache.druid.frame.write.FrameWriters#makeFrameWriterFactory} to create
* {@link org.apache.druid.frame.write.FrameWriterFactory}.
*/
public interface MemoryAllocatorFactory
{
/**
* Returns a new allocator with capacity {@link #allocatorCapacity()}.
*/
MemoryAllocator newAllocator();
/**
* Capacity of allocators returned by {@link #newAllocator()}.
*/
long allocatorCapacity();
}


@ -22,8 +22,8 @@ package org.apache.druid.frame.allocation;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
/** /**
* Reference to a particular region of some {@link Memory}. This is used because it is cheaper to create than * Reference to a particular region of some {@link Memory}. This is used because it is cheaper to reuse this object
* calling {@link Memory#region}. * rather than calling {@link Memory#region} for each row.
* *
* Not immutable. The pointed-to range may change as this object gets reused. * Not immutable. The pointed-to range may change as this object gets reused.
*/ */
@ -39,8 +39,8 @@ public class MemoryRange<T extends Memory>
} }
/** /**
* Returns the underlying memory *without* clipping it to this particular range. Callers must remember to continue * Returns the underlying memory *without* clipping it to this particular range. Callers must apply the offset
* applying the offset given by {@link #start} and capacity given by {@link #length}. * given by {@link #start} and capacity given by {@link #length}.
*/ */
public T memory() public T memory()
{ {
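Since memory() is deliberately un-clipped, every read must be offset by hand. A hypothetical helper illustrating that, assuming the start() and length() accessors referenced by the Javadoc above; it is a sketch, not code from this commit.

import org.apache.datasketches.memory.Memory;
import org.apache.druid.frame.allocation.MemoryRange;

class MemoryRangeExample
{
  // Sums the bytes covered by the range. Note the explicit range.start() offset on every access,
  // because memory() returns the full underlying Memory rather than a clipped region.
  static long sumBytes(final MemoryRange<Memory> range)
  {
    long sum = 0;
    for (long i = 0; i < range.length(); i++) {
      sum += range.memory().getByte(range.start() + i);
    }
    return sum;
  }
}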


@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.allocation;
import org.apache.druid.java.util.common.ISE;
/**
* Wraps a single {@link MemoryAllocator}.
*
* The same instance is returned on each call to {@link #newAllocator()}, after validating that it is 100% free.
* Calling {@link #newAllocator()} before freeing all previously-allocated memory leads to an IllegalStateException.
*/
public class SingleMemoryAllocatorFactory implements MemoryAllocatorFactory
{
private final MemoryAllocator allocator;
private final long capacity;
public SingleMemoryAllocatorFactory(final MemoryAllocator allocator)
{
this.allocator = allocator;
this.capacity = allocator.capacity();
}
@Override
public MemoryAllocator newAllocator()
{
// Allocators are reused, which allows each call to "newAllocator" to use the same arena (if it's arena-based).
// Just need to validate that it has actually been closed out prior to handing it to someone else.
if (allocator.available() != allocator.capacity()) {
throw new ISE("Allocator in use");
}
return allocator;
}
@Override
public long allocatorCapacity()
{
return capacity;
}
}
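A minimal usage sketch (assumed, not from this patch): the factory can back several successive frame writers with one arena, provided every allocation is released before the next call to newAllocator().

// Illustrative only; the capacity value is arbitrary.
final MemoryAllocatorFactory factory =
    new SingleMemoryAllocatorFactory(ArenaMemoryAllocator.createOnHeap(1_000_000));

final MemoryAllocator allocator = factory.newAllocator(); // OK: the arena is 100% free
// ... allocate, write a frame, then release everything that was allocated ...
final MemoryAllocator again = factory.newAllocator();     // OK only once fully freed;
                                                           // otherwise throws ISE("Allocator in use")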

View File

@ -172,8 +172,6 @@ public class BlockingQueueFrameChannel
// If this happens, it's a bug, potentially due to incorrectly using this class with multiple writers. // If this happens, it's a bug, potentially due to incorrectly using this class with multiple writers.
throw new ISE("Could not write error to channel"); throw new ISE("Could not write error to channel");
} }
close();
} }
} }
@ -181,8 +179,8 @@ public class BlockingQueueFrameChannel
public void close() public void close()
{ {
synchronized (lock) { synchronized (lock) {
if (isFinished()) { if (isClosed()) {
throw new ISE("Already done"); throw new ISE("Already closed");
} }
if (!queue.offer(END_MARKER)) { if (!queue.offer(END_MARKER)) {
@ -193,6 +191,15 @@ public class BlockingQueueFrameChannel
notifyReader(); notifyReader();
} }
} }
@Override
public boolean isClosed()
{
synchronized (lock) {
final Optional<Either<Throwable, FrameWithPartition>> lastElement = queue.peekLast();
return lastElement != null && END_MARKER.equals(lastElement);
}
}
} }
private class Readable implements ReadableFrameChannel private class Readable implements ReadableFrameChannel

View File

@ -92,9 +92,16 @@ public class ComposingWritableFrameChannel implements WritableFrameChannel
{ {
if (currentIndex < channels.size()) { if (currentIndex < channels.size()) {
channels.get(currentIndex).get().close(); channels.get(currentIndex).get().close();
currentIndex = channels.size();
} }
} }
@Override
public boolean isClosed()
{
return currentIndex == channels.size();
}
@Override @Override
public ListenableFuture<?> writabilityFuture() public ListenableFuture<?> writabilityFuture()
{ {

View File

@ -73,6 +73,11 @@ public interface WritableFrameChannel extends Closeable
@Override @Override
void close() throws IOException; void close() throws IOException;
/**
* Whether {@link #close()} has been called on this channel.
*/
boolean isClosed();
/** /**
* Returns a future that resolves when {@link #write} is able to receive a new frame without blocking or throwing * Returns a future that resolves when {@link #write} is able to receive a new frame without blocking or throwing
* an exception. The future never resolves to an exception. * an exception. The future never resolves to an exception.
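For callers, the new method supports an idempotent shutdown pattern, for example (illustrative only, not from this patch):

// Close the channel only if it has not been closed already.
if (!channel.isClosed()) {
  channel.close();
}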

View File

@ -32,6 +32,7 @@ import java.io.IOException;
public class WritableFrameFileChannel implements WritableFrameChannel public class WritableFrameFileChannel implements WritableFrameChannel
{ {
private final FrameFileWriter writer; private final FrameFileWriter writer;
private boolean closed;
public WritableFrameFileChannel(final FrameFileWriter writer) public WritableFrameFileChannel(final FrameFileWriter writer)
{ {
@ -55,6 +56,13 @@ public class WritableFrameFileChannel implements WritableFrameChannel
public void close() throws IOException public void close() throws IOException
{ {
writer.close(); writer.close();
closed = true;
}
@Override
public boolean isClosed()
{
return closed;
} }
@Override @Override

View File

@ -83,6 +83,12 @@ public class ComplexFieldReader implements FieldReader
return DimensionSelector.constant(null, extractionFn); return DimensionSelector.constant(null, extractionFn);
} }
@Override
public boolean isNull(Memory memory, long position)
{
return memory.getByte(position) == ComplexFieldWriter.NULL_BYTE;
}
@Override @Override
public boolean isComparable() public boolean isComparable()
{ {

View File

@ -66,6 +66,12 @@ public class DoubleFieldReader implements FieldReader
); );
} }
@Override
public boolean isNull(Memory memory, long position)
{
return memory.getByte(position) == DoubleFieldWriter.NULL_BYTE;
}
@Override @Override
public boolean isComparable() public boolean isComparable()
{ {

View File

@ -52,6 +52,11 @@ public interface FieldReader
@Nullable ExtractionFn extractionFn @Nullable ExtractionFn extractionFn
); );
/**
* Whether the provided memory position points to a null value.
*/
boolean isNull(Memory memory, long position);
/** /**
* Whether this field is comparable. Comparable fields can be compared as unsigned bytes. * Whether this field is comparable. Comparable fields can be compared as unsigned bytes.
*/ */

View File

@ -66,6 +66,12 @@ public class FloatFieldReader implements FieldReader
); );
} }
@Override
public boolean isNull(Memory memory, long position)
{
return memory.getByte(position) == FloatFieldWriter.NULL_BYTE;
}
@Override @Override
public boolean isComparable() public boolean isComparable()
{ {

View File

@ -66,6 +66,12 @@ public class LongFieldReader implements FieldReader
); );
} }
@Override
public boolean isNull(Memory memory, long position)
{
return memory.getByte(position) == LongFieldWriter.NULL_BYTE;
}
@Override @Override
public boolean isComparable() public boolean isComparable()
{ {

View File

@ -87,6 +87,16 @@ public class StringFieldReader implements FieldReader
return new Selector(memory, fieldPointer, extractionFn, false); return new Selector(memory, fieldPointer, extractionFn, false);
} }
@Override
public boolean isNull(Memory memory, long position)
{
final byte nullByte = memory.getByte(position);
assert nullByte == StringFieldWriter.NULL_BYTE || nullByte == StringFieldWriter.NOT_NULL_BYTE;
return nullByte == StringFieldWriter.NULL_BYTE
&& memory.getByte(position + 1) == StringFieldWriter.VALUE_TERMINATOR
&& memory.getByte(position + 2) == StringFieldWriter.ROW_TERMINATOR;
}
@Override @Override
public boolean isComparable() public boolean isComparable()
{ {

View File

@ -23,6 +23,7 @@ import com.google.common.primitives.Ints;
import it.unimi.dsi.fastutil.ints.IntArrayList; import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList; import it.unimi.dsi.fastutil.ints.IntList;
import org.apache.druid.frame.read.FrameReaderUtils; import org.apache.druid.frame.read.FrameReaderUtils;
import org.apache.druid.java.util.common.IAE;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator; import java.util.Comparator;
@ -49,7 +50,7 @@ public class ByteRowKeyComparator implements Comparator<byte[]>
this.ascDescRunLengths = ascDescRunLengths; this.ascDescRunLengths = ascDescRunLengths;
} }
public static ByteRowKeyComparator create(final List<SortColumn> keyColumns) public static ByteRowKeyComparator create(final List<KeyColumn> keyColumns)
{ {
return new ByteRowKeyComparator( return new ByteRowKeyComparator(
computeFirstFieldPosition(keyColumns.size()), computeFirstFieldPosition(keyColumns.size()),
@ -74,18 +75,24 @@ public class ByteRowKeyComparator implements Comparator<byte[]>
* *
* Public so {@link FrameComparisonWidgetImpl} can use it. * Public so {@link FrameComparisonWidgetImpl} can use it.
*/ */
public static int[] computeAscDescRunLengths(final List<SortColumn> sortColumns) public static int[] computeAscDescRunLengths(final List<KeyColumn> keyColumns)
{ {
final IntList ascDescRunLengths = new IntArrayList(4); final IntList ascDescRunLengths = new IntArrayList(4);
boolean descending = false; KeyOrder order = KeyOrder.ASCENDING;
int runLength = 0; int runLength = 0;
for (final SortColumn column : sortColumns) { for (final KeyColumn column : keyColumns) {
if (column.descending() != descending) { if (column.order() == KeyOrder.NONE) {
throw new IAE("Key must be sortable");
}
if (column.order() != order) {
ascDescRunLengths.add(runLength); ascDescRunLengths.add(runLength);
runLength = 0; runLength = 0;
descending = !descending;
// Invert "order".
order = order == KeyOrder.ASCENDING ? KeyOrder.DESCENDING : KeyOrder.ASCENDING;
} }
runLength++; runLength++;
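A worked illustration of the run-length encoding (example values assumed, not from the patch): consecutive key columns with the same direction are grouped into one run, and unsorted (NONE) columns are rejected.

// Illustrative only.
final List<KeyColumn> key = Arrays.asList(
    new KeyColumn("a", KeyOrder.ASCENDING),
    new KeyColumn("b", KeyOrder.ASCENDING),
    new KeyColumn("c", KeyOrder.DESCENDING)
);

// computeAscDescRunLengths(key) yields [2, 1]: a run of two ascending fields, then one descending field.
// Any column with KeyOrder.NONE throws IAE("Key must be sortable").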

View File

@ -43,12 +43,13 @@ import java.util.Objects;
*/ */
public class ClusterBy public class ClusterBy
{ {
private final List<SortColumn> columns; private final List<KeyColumn> columns;
private final int bucketByCount; private final int bucketByCount;
private final boolean sortable;
@JsonCreator @JsonCreator
public ClusterBy( public ClusterBy(
@JsonProperty("columns") List<SortColumn> columns, @JsonProperty("columns") List<KeyColumn> columns,
@JsonProperty("bucketByCount") int bucketByCount @JsonProperty("bucketByCount") int bucketByCount
) )
{ {
@ -58,6 +59,21 @@ public class ClusterBy
if (bucketByCount < 0 || bucketByCount > columns.size()) { if (bucketByCount < 0 || bucketByCount > columns.size()) {
throw new IAE("Invalid bucketByCount [%d]", bucketByCount); throw new IAE("Invalid bucketByCount [%d]", bucketByCount);
} }
// Key must be either entirely sortable or entirely unsortable. If empty, treat it as sortable.
boolean sortable = true;
for (int i = 0; i < columns.size(); i++) {
final KeyColumn column = columns.get(i);
if (i == 0) {
sortable = column.order().sortable();
} else if (sortable != column.order().sortable()) {
throw new IAE("Cannot mix sortable and unsortable key columns");
}
}
this.sortable = sortable;
} }
/** /**
@ -72,7 +88,7 @@ public class ClusterBy
* The columns that comprise this key, in order. * The columns that comprise this key, in order.
*/ */
@JsonProperty @JsonProperty
public List<SortColumn> getColumns() public List<KeyColumn> getColumns()
{ {
return columns; return columns;
} }
@ -86,7 +102,7 @@ public class ClusterBy
* *
* Will always be less than, or equal to, the size of {@link #getColumns()}. * Will always be less than, or equal to, the size of {@link #getColumns()}.
* *
* Not relevant when a ClusterBy instance is used as an ordering key rather than a partitioning key. * Only relevant when a ClusterBy instance is used as a partitioning key.
*/ */
@JsonProperty @JsonProperty
@JsonInclude(JsonInclude.Include.NON_DEFAULT) @JsonInclude(JsonInclude.Include.NON_DEFAULT)
@ -95,6 +111,22 @@ public class ClusterBy
return bucketByCount; return bucketByCount;
} }
/**
* Whether this key is empty.
*/
public boolean isEmpty()
{
return columns.isEmpty();
}
/**
* Whether this key is sortable. Empty keys (with no columns) are considered sortable.
*/
public boolean sortable()
{
return sortable;
}
/** /**
* Create a reader for keys for this instance. * Create a reader for keys for this instance.
* *
@ -105,8 +137,8 @@ public class ClusterBy
{ {
final RowSignature.Builder newSignature = RowSignature.builder(); final RowSignature.Builder newSignature = RowSignature.builder();
for (final SortColumn sortColumn : columns) { for (final KeyColumn keyColumn : columns) {
final String columnName = sortColumn.columnName(); final String columnName = keyColumn.columnName();
final ColumnCapabilities capabilities = inspector.getColumnCapabilities(columnName); final ColumnCapabilities capabilities = inspector.getColumnCapabilities(columnName);
final ColumnType columnType = final ColumnType columnType =
Preconditions.checkNotNull(capabilities, "Type for column [%s]", columnName).toColumnType(); Preconditions.checkNotNull(capabilities, "Type for column [%s]", columnName).toColumnType();
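To illustrate the new invariant (values assumed, not from the patch): a ClusterBy is either a sort key (all columns ASCENDING or DESCENDING) or a hash key (all columns NONE), never a mix.

// Illustrative only.
final ClusterBy sortKey = new ClusterBy(
    Arrays.asList(new KeyColumn("x", KeyOrder.ASCENDING), new KeyColumn("y", KeyOrder.DESCENDING)),
    0
);
// sortKey.sortable() == true

final ClusterBy hashKey = new ClusterBy(
    Collections.singletonList(new KeyColumn("x", KeyOrder.NONE)),
    0
);
// hashKey.sortable() == false

// Mixing NONE with ASCENDING/DESCENDING throws IAE("Cannot mix sortable and unsortable key columns").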

View File

@ -34,6 +34,11 @@ public interface FrameComparisonWidget
*/ */
RowKey readKey(int row); RowKey readKey(int row);
/**
* Whether a particular row has a null field in its comparison key.
*/
boolean isPartiallyNullKey(int row);
/** /**
* Compare a specific row of this frame to the provided key. The key must have been created with sortColumns * Compare a specific row of this frame to the provided key. The key must have been created with sortColumns
* that match the ones used to create this widget, or else results are undefined. * that match the ones used to create this widget, or else results are undefined.
@ -42,7 +47,7 @@ public interface FrameComparisonWidget
/** /**
* Compare a specific row of this frame to a specific row of another frame. The other frame must have the same * Compare a specific row of this frame to a specific row of another frame. The other frame must have the same
* signature, or else results are undefined. The other frame may be the same object as this frame; for example, * sort key, or else results are undefined. The other frame may be the same object as this frame; for example,
* this is used by {@link org.apache.druid.frame.write.FrameSort} to sort frames in-place. * this is used by {@link org.apache.druid.frame.write.FrameSort} to sort frames in-place.
*/ */
int compare(int row, FrameComparisonWidget otherWidget, int otherRow); int compare(int row, FrameComparisonWidget otherWidget, int otherRow);
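A sketch of how a join processor might use these methods (hypothetical flow; the actual SortMergeJoinFrameProcessor lives elsewhere in this patch):

// Illustrative only. keyColumns is a List<KeyColumn> matching the frame's sort key.
final FrameComparisonWidget widget = frameReader.makeComparisonWidget(frame, keyColumns);

if (!widget.isPartiallyNullKey(row)) {
  final int cmp = widget.compare(row, otherWidget, otherRow);
  if (cmp == 0) {
    // Keys match and contain no nulls: emit joined rows.
  }
}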

View File

@ -23,10 +23,13 @@ import com.google.common.primitives.Ints;
import org.apache.datasketches.memory.Memory; import org.apache.datasketches.memory.Memory;
import org.apache.druid.frame.Frame; import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType; import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.field.FieldReader;
import org.apache.druid.frame.read.FrameReader; import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.read.FrameReaderUtils; import org.apache.druid.frame.read.FrameReaderUtils;
import org.apache.druid.frame.write.FrameWriterUtils; import org.apache.druid.frame.write.FrameWriterUtils;
import org.apache.druid.frame.write.RowBasedFrameWriter; import org.apache.druid.frame.write.RowBasedFrameWriter;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.column.RowSignature;
import java.util.List; import java.util.List;
@ -39,28 +42,30 @@ import java.util.List;
public class FrameComparisonWidgetImpl implements FrameComparisonWidget public class FrameComparisonWidgetImpl implements FrameComparisonWidget
{ {
private final Frame frame; private final Frame frame;
private final FrameReader frameReader; private final RowSignature signature;
private final Memory rowOffsetRegion; private final Memory rowOffsetRegion;
private final Memory dataRegion; private final Memory dataRegion;
private final int keyFieldCount; private final int keyFieldCount;
private final List<FieldReader> keyFieldReaders;
private final long firstFieldPosition; private final long firstFieldPosition;
private final int[] ascDescRunLengths; private final int[] ascDescRunLengths;
private FrameComparisonWidgetImpl( private FrameComparisonWidgetImpl(
final Frame frame, final Frame frame,
final FrameReader frameReader, final RowSignature signature,
final Memory rowOffsetRegion, final Memory rowOffsetRegion,
final Memory dataRegion, final Memory dataRegion,
final int keyFieldCount, final List<FieldReader> keyFieldReaders,
final long firstFieldPosition, final long firstFieldPosition,
final int[] ascDescRunLengths final int[] ascDescRunLengths
) )
{ {
this.frame = frame; this.frame = frame;
this.frameReader = frameReader; this.signature = signature;
this.rowOffsetRegion = rowOffsetRegion; this.rowOffsetRegion = rowOffsetRegion;
this.dataRegion = dataRegion; this.dataRegion = dataRegion;
this.keyFieldCount = keyFieldCount; this.keyFieldCount = keyFieldReaders.size();
this.keyFieldReaders = keyFieldReaders;
this.firstFieldPosition = firstFieldPosition; this.firstFieldPosition = firstFieldPosition;
this.ascDescRunLengths = ascDescRunLengths; this.ascDescRunLengths = ascDescRunLengths;
} }
@ -68,41 +73,46 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
/** /**
* Create a {@link FrameComparisonWidget} for the given frame. * Create a {@link FrameComparisonWidget} for the given frame.
* *
* Only possible for frames of type {@link FrameType#ROW_BASED}. The provided sortColumns must be a prefix * Only possible for frames of type {@link FrameType#ROW_BASED}. The provided keyColumns must be a prefix
* of {@link FrameReader#signature()}. * of {@link FrameReader#signature()}.
* *
* @param frame frame, must be {@link FrameType#ROW_BASED} * @param frame frame, must be {@link FrameType#ROW_BASED}
* @param frameReader reader for the frame * @param signature signature for the frame
* @param sortColumns columns to sort by * @param keyColumns columns to sort by
* @param keyColumnReaders readers for key columns
*/ */
public static FrameComparisonWidgetImpl create( public static FrameComparisonWidgetImpl create(
final Frame frame, final Frame frame,
final FrameReader frameReader, final RowSignature signature,
final List<SortColumn> sortColumns final List<KeyColumn> keyColumns,
final List<FieldReader> keyColumnReaders
) )
{ {
FrameWriterUtils.verifySortColumns(sortColumns, frameReader.signature()); FrameWriterUtils.verifySortColumns(keyColumns, signature);
if (keyColumnReaders.size() != keyColumns.size()) {
throw new ISE("Mismatched lengths for keyColumnReaders and keyColumns");
}
return new FrameComparisonWidgetImpl( return new FrameComparisonWidgetImpl(
FrameType.ROW_BASED.ensureType(frame), FrameType.ROW_BASED.ensureType(frame),
frameReader, signature,
frame.region(RowBasedFrameWriter.ROW_OFFSET_REGION), frame.region(RowBasedFrameWriter.ROW_OFFSET_REGION),
frame.region(RowBasedFrameWriter.ROW_DATA_REGION), frame.region(RowBasedFrameWriter.ROW_DATA_REGION),
sortColumns.size(), keyColumnReaders,
ByteRowKeyComparator.computeFirstFieldPosition(frameReader.signature().size()), ByteRowKeyComparator.computeFirstFieldPosition(signature.size()),
ByteRowKeyComparator.computeAscDescRunLengths(sortColumns) ByteRowKeyComparator.computeAscDescRunLengths(keyColumns)
); );
} }
@Override @Override
public RowKey readKey(int row) public RowKey readKey(int row)
{ {
final int keyFieldPointersEndInRow = keyFieldCount * Integer.BYTES;
if (keyFieldCount == 0) { if (keyFieldCount == 0) {
return RowKey.empty(); return RowKey.empty();
} }
final int keyFieldPointersEndInRow = keyFieldCount * Integer.BYTES;
final long rowPosition = getRowPositionInDataRegion(row); final long rowPosition = getRowPositionInDataRegion(row);
final int keyEndInRow = final int keyEndInRow =
dataRegion.getInt(rowPosition + (long) (keyFieldCount - 1) * Integer.BYTES); dataRegion.getInt(rowPosition + (long) (keyFieldCount - 1) * Integer.BYTES);
@ -110,7 +120,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
final long keyLength = keyEndInRow - firstFieldPosition; final long keyLength = keyEndInRow - firstFieldPosition;
final byte[] keyBytes = new byte[Ints.checkedCast(keyFieldPointersEndInRow + keyEndInRow - firstFieldPosition)]; final byte[] keyBytes = new byte[Ints.checkedCast(keyFieldPointersEndInRow + keyEndInRow - firstFieldPosition)];
final int headerSizeAdjustment = (frameReader.signature().size() - keyFieldCount) * Integer.BYTES; final int headerSizeAdjustment = (signature.size() - keyFieldCount) * Integer.BYTES;
for (int i = 0; i < keyFieldCount; i++) { for (int i = 0; i < keyFieldCount; i++) {
final int fieldEndPosition = dataRegion.getInt(rowPosition + ((long) Integer.BYTES * i)); final int fieldEndPosition = dataRegion.getInt(rowPosition + ((long) Integer.BYTES * i));
final int adjustedFieldEndPosition = fieldEndPosition - headerSizeAdjustment; final int adjustedFieldEndPosition = fieldEndPosition - headerSizeAdjustment;
@ -127,6 +137,28 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
return RowKey.wrap(keyBytes); return RowKey.wrap(keyBytes);
} }
@Override
public boolean isPartiallyNullKey(int row)
{
if (keyFieldCount == 0) {
return false;
}
final long rowPosition = getRowPositionInDataRegion(row);
long keyFieldPosition = rowPosition + (long) signature.size() * Integer.BYTES;
for (int i = 0; i < keyFieldCount; i++) {
final boolean isNull = keyFieldReaders.get(i).isNull(dataRegion, keyFieldPosition);
if (isNull) {
return true;
} else {
keyFieldPosition = rowPosition + dataRegion.getInt(rowPosition + (long) i * Integer.BYTES);
}
}
return false;
}
@Override @Override
public int compare(int row, RowKey key) public int compare(int row, RowKey key)
{ {
@ -187,7 +219,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
final long otherRowPosition = otherWidgetImpl.getRowPositionInDataRegion(otherRow); final long otherRowPosition = otherWidgetImpl.getRowPositionInDataRegion(otherRow);
long comparableBytesStartPositionInRow = firstFieldPosition; long comparableBytesStartPositionInRow = firstFieldPosition;
long otherComparableBytesStartPositionInRow = firstFieldPosition; long otherComparableBytesStartPositionInRow = otherWidgetImpl.firstFieldPosition;
boolean ascending = true; boolean ascending = true;
int field = 0; int field = 0;
@ -240,8 +272,7 @@ public class FrameComparisonWidgetImpl implements FrameComparisonWidget
long getFieldEndPositionInRow(final long rowPosition, final int fieldNumber) long getFieldEndPositionInRow(final long rowPosition, final int fieldNumber)
{ {
assert fieldNumber >= 0 && fieldNumber < frameReader.signature().size(); assert fieldNumber >= 0 && fieldNumber < signature.size();
return dataRegion.getInt(rowPosition + (long) fieldNumber * Integer.BYTES); return dataRegion.getInt(rowPosition + (long) fieldNumber * Integer.BYTES);
} }

View File

@ -20,7 +20,6 @@
package org.apache.druid.frame.key; package org.apache.druid.frame.key;
import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.StringUtils;
@ -28,17 +27,17 @@ import org.apache.druid.java.util.common.StringUtils;
import java.util.Objects; import java.util.Objects;
/** /**
* Represents a component of an order-by key. * Represents a component of a hash or sorting key.
*/ */
public class SortColumn public class KeyColumn
{ {
private final String columnName; private final String columnName;
private final boolean descending; private final KeyOrder order;
@JsonCreator @JsonCreator
public SortColumn( public KeyColumn(
@JsonProperty("columnName") String columnName, @JsonProperty("columnName") String columnName,
@JsonProperty("descending") boolean descending @JsonProperty("order") KeyOrder order
) )
{ {
if (columnName == null || columnName.isEmpty()) { if (columnName == null || columnName.isEmpty()) {
@ -46,7 +45,7 @@ public class SortColumn
} }
this.columnName = columnName; this.columnName = columnName;
this.descending = descending; this.order = order;
} }
@JsonProperty @JsonProperty
@ -56,10 +55,9 @@ public class SortColumn
} }
@JsonProperty @JsonProperty
@JsonInclude(JsonInclude.Include.NON_DEFAULT) public KeyOrder order()
public boolean descending()
{ {
return descending; return order;
} }
@Override @Override
@ -71,19 +69,19 @@ public class SortColumn
if (o == null || getClass() != o.getClass()) { if (o == null || getClass() != o.getClass()) {
return false; return false;
} }
SortColumn that = (SortColumn) o; KeyColumn that = (KeyColumn) o;
return descending == that.descending && Objects.equals(columnName, that.columnName); return order == that.order && Objects.equals(columnName, that.columnName);
} }
@Override @Override
public int hashCode() public int hashCode()
{ {
return Objects.hash(columnName, descending); return Objects.hash(columnName, order);
} }
@Override @Override
public String toString() public String toString()
{ {
return StringUtils.format("%s%s", columnName, descending ? " DESC" : ""); return StringUtils.format("%s%s", columnName, order == KeyOrder.NONE ? "" : " " + order);
} }
} }

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.key;
/**
* Ordering associated with a {@link KeyColumn}.
*/
public enum KeyOrder
{
/**
* Not ordered.
*
* Used when the key serves only non-sorting purposes, such as hash partitioning without an accompanying sort.
*/
NONE(false),
/**
* Ordered ascending.
*
* Note that a sortable key order does not necessarily imply range-based partitioning: we may be using
* hash-based partitioning with each partition sorted internally by the key.
*/
ASCENDING(true),
/**
* Ordered descending.
*
* Note that a sortable key order does not necessarily imply range-based partitioning: we may be using
* hash-based partitioning with each partition sorted internally by the key.
*/
DESCENDING(true);
private final boolean sortable;
KeyOrder(boolean sortable)
{
this.sortable = sortable;
}
public boolean sortable()
{
return sortable;
}
}

View File

@ -36,7 +36,7 @@ public class RowKeyComparator implements Comparator<RowKey>
this.byteRowKeyComparatorDelegate = byteRowKeyComparatorDelegate; this.byteRowKeyComparatorDelegate = byteRowKeyComparatorDelegate;
} }
public static RowKeyComparator create(final List<SortColumn> keyColumns) public static RowKeyComparator create(final List<KeyColumn> keyColumns)
{ {
return new RowKeyComparator(ByteRowKeyComparator.create(keyColumns)); return new RowKeyComparator(ByteRowKeyComparator.create(keyColumns));
} }

View File

@ -38,10 +38,10 @@ public class BlockingQueueOutputChannelFactory implements OutputChannelFactory
public OutputChannel openChannel(final int partitionNumber) public OutputChannel openChannel(final int partitionNumber)
{ {
final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal(); final BlockingQueueFrameChannel channel = BlockingQueueFrameChannel.minimal();
return OutputChannel.pair( return OutputChannel.immediatelyReadablePair(
channel.writable(), channel.writable(),
ArenaMemoryAllocator.createOnHeap(frameSize), ArenaMemoryAllocator.createOnHeap(frameSize),
channel::readable, channel.readable(),
partitionNumber partitionNumber
); );
} }
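For contrast (a hypothetical sketch; the supplier name is an assumption): factories whose readable side only becomes valid after writing completes keep the original deferred form of OutputChannel.pair.

// Illustrative only: the readable channel is supplied lazily and is expected to be used
// only after the writable side has been closed.
return OutputChannel.pair(
    writableChannel,
    allocator,
    () -> openReadableChannel(partitionNumber), // hypothetical supplier, valid after close()
    partitionNumber
);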

View File

@ -0,0 +1,348 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.frame.processor;
import com.google.common.collect.ImmutableList;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.datasketches.memory.Memory;
import org.apache.druid.frame.Frame;
import org.apache.druid.frame.FrameType;
import org.apache.druid.frame.allocation.MemoryRange;
import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.read.FrameReaderUtils;
import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.frame.write.FrameWriterUtils;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.query.dimension.DimensionSpec;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.ColumnValueSelector;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.DimensionSelector;
import org.apache.druid.segment.LongColumnSelector;
import org.apache.druid.segment.VirtualColumn;
import org.apache.druid.segment.VirtualColumns;
import org.apache.druid.segment.column.ColumnCapabilities;
import org.apache.druid.segment.column.ColumnCapabilitiesImpl;
import org.apache.druid.segment.column.ColumnType;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
/**
* Processor that hash-partitions rows from any number of input channels, and writes partitioned frames to output
* channels.
*
* Input frames must be {@link FrameType#ROW_BASED}, and input signature must be the same as output signature.
* This processor hashes each row using {@link Memory#xxHash64} with a seed of {@link #HASH_SEED}.
*/
public class FrameChannelHashPartitioner implements FrameProcessor<Long>
{
private static final String PARTITION_COLUMN_NAME =
StringUtils.format("%s_part", FrameWriterUtils.RESERVED_FIELD_PREFIX);
private static final long HASH_SEED = 0;
private final List<ReadableFrameChannel> inputChannels;
private final List<WritableFrameChannel> outputChannels;
private final FrameReader frameReader;
private final int keyFieldCount;
private final FrameWriterFactory frameWriterFactory;
private final IntSet awaitSet;
private Cursor cursor;
private LongSupplier cursorRowPartitionNumberSupplier;
private long rowsWritten;
// Indirection allows FrameWriters to follow "cursor" even when it is replaced with a new instance.
private final MultiColumnSelectorFactory cursorColumnSelectorFactory;
private final FrameWriter[] frameWriters;
public FrameChannelHashPartitioner(
final List<ReadableFrameChannel> inputChannels,
final List<WritableFrameChannel> outputChannels,
final FrameReader frameReader,
final int keyFieldCount,
final FrameWriterFactory frameWriterFactory
)
{
this.inputChannels = inputChannels;
this.outputChannels = outputChannels;
this.frameReader = frameReader;
this.keyFieldCount = keyFieldCount;
this.frameWriterFactory = frameWriterFactory;
this.awaitSet = FrameProcessors.rangeSet(inputChannels.size());
this.frameWriters = new FrameWriter[outputChannels.size()];
this.cursorColumnSelectorFactory = new MultiColumnSelectorFactory(
Collections.singletonList(() -> cursor.getColumnSelectorFactory()),
frameReader.signature()
).withRowMemoryAndSignatureColumns();
if (!frameReader.signature().equals(frameWriterFactory.signature())) {
throw new IAE("Input signature does not match output signature");
}
}
@Override
public List<ReadableFrameChannel> inputChannels()
{
return inputChannels;
}
@Override
public List<WritableFrameChannel> outputChannels()
{
return outputChannels;
}
@Override
public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException
{
if (cursor == null) {
readNextFrame(readableInputs);
}
if (cursor != null) {
processCursor();
}
if (cursor != null) {
return ReturnOrAwait.runAgain();
} else if (awaitSet.isEmpty()) {
flushFrameWriters();
return ReturnOrAwait.returnObject(rowsWritten);
} else {
return ReturnOrAwait.awaitAny(awaitSet);
}
}
@Override
public void cleanup() throws IOException
{
FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriters);
}
private void processCursor() throws IOException
{
assert cursor != null;
while (!cursor.isDone()) {
final int partition = (int) cursorRowPartitionNumberSupplier.getAsLong();
final FrameWriter frameWriter = getOrCreateFrameWriter(partition);
if (frameWriter.addSelection()) {
cursor.advance();
} else if (frameWriter.getNumRows() > 0) {
writeFrame(partition);
return;
} else {
throw new FrameRowTooLargeException(frameWriterFactory.allocatorCapacity());
}
}
cursor = null;
cursorRowPartitionNumberSupplier = null;
}
private void readNextFrame(final IntSet readableInputs)
{
if (cursor != null) {
throw new ISE("Already reading a frame");
}
final IntSet readySet = new IntAVLTreeSet(readableInputs);
for (int channelNumber : readableInputs) {
final ReadableFrameChannel channel = inputChannels.get(channelNumber);
if (channel.isFinished()) {
awaitSet.remove(channelNumber);
readySet.remove(channelNumber);
}
}
if (!readySet.isEmpty()) {
// Read a random channel: avoid biasing towards lower-numbered channels.
final int channelNumber = FrameProcessors.selectRandom(readySet);
final ReadableFrameChannel channel = inputChannels.get(channelNumber);
if (!channel.isFinished()) {
// Need row-based frame so we can hash memory directly.
final Frame frame = FrameType.ROW_BASED.ensureType(channel.read());
final HashPartitionVirtualColumn hashPartitionVirtualColumn =
new HashPartitionVirtualColumn(PARTITION_COLUMN_NAME, frameReader, keyFieldCount, outputChannels.size());
cursor = FrameProcessors.makeCursor(
frame,
frameReader,
VirtualColumns.create(Collections.singletonList(hashPartitionVirtualColumn))
);
cursorRowPartitionNumberSupplier =
cursor.getColumnSelectorFactory().makeColumnValueSelector(PARTITION_COLUMN_NAME)::getLong;
}
}
}
private void flushFrameWriters() throws IOException
{
for (int i = 0; i < frameWriters.length; i++) {
if (frameWriters[i] != null) {
writeFrame(i);
}
}
}
private FrameWriter getOrCreateFrameWriter(final int partition)
{
if (frameWriters[partition] == null) {
frameWriters[partition] = frameWriterFactory.newFrameWriter(cursorColumnSelectorFactory);
}
return frameWriters[partition];
}
private void writeFrame(final int partition) throws IOException
{
if (frameWriters[partition] == null || frameWriters[partition].getNumRows() == 0) {
throw new ISE("Nothing to write for partition [%,d]", partition);
}
final Frame frame = Frame.wrap(frameWriters[partition].toByteArray());
outputChannels.get(partition).write(frame);
frameWriters[partition].close();
frameWriters[partition] = null;
rowsWritten += frame.numRows();
}
/**
* Virtual column that provides a hash code of the {@link FrameType#ROW_BASED} frame row that is wrapped in
* the provided {@link ColumnSelectorFactory}, using {@link FrameReaderUtils#makeRowMemorySupplier}.
*/
private static class HashPartitionVirtualColumn implements VirtualColumn
{
private final String name;
private final FrameReader frameReader;
private final int keyFieldCount;
private final int partitionCount;
public HashPartitionVirtualColumn(
final String name,
final FrameReader frameReader,
final int keyFieldCount,
final int partitionCount
)
{
this.name = name;
this.frameReader = frameReader;
this.keyFieldCount = keyFieldCount;
this.partitionCount = partitionCount;
}
@Override
public String getOutputName()
{
return name;
}
@Override
public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec, ColumnSelectorFactory factory)
{
throw new UnsupportedOperationException();
}
@Override
public ColumnValueSelector<?> makeColumnValueSelector(String columnName, ColumnSelectorFactory factory)
{
final Supplier<MemoryRange<Memory>> rowMemorySupplier =
FrameReaderUtils.makeRowMemorySupplier(factory, frameReader.signature());
final int frameFieldCount = frameReader.signature().size();
return new LongColumnSelector()
{
@Override
public long getLong()
{
if (keyFieldCount == 0) {
return 0;
}
final MemoryRange<Memory> rowMemoryRange = rowMemorySupplier.get();
final Memory memory = rowMemoryRange.memory();
final long keyStartPosition = (long) Integer.BYTES * frameFieldCount;
final long keyEndPosition =
memory.getInt(rowMemoryRange.start() + (long) Integer.BYTES * (keyFieldCount - 1));
final int keyLength = (int) (keyEndPosition - keyStartPosition);
final long hash = memory.xxHash64(rowMemoryRange.start() + keyStartPosition, keyLength, HASH_SEED);
return Math.abs(hash % partitionCount);
}
@Override
public boolean isNull()
{
return false;
}
@Override
public void inspectRuntimeShape(RuntimeShapeInspector inspector)
{
// Nothing to do.
}
};
}
@Override
public ColumnCapabilities capabilities(String columnName)
{
return ColumnCapabilitiesImpl.createSimpleNumericColumnCapabilities(ColumnType.LONG).setHasNulls(false);
}
@Override
public List<String> requiredColumns()
{
return ImmutableList.of(
FrameColumnSelectorFactory.ROW_MEMORY_COLUMN,
FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN
);
}
@Override
public boolean usesDotNotation()
{
return false;
}
@Override
public byte[] getCacheKey()
{
throw new UnsupportedOperationException();
}
}
}
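A hypothetical wiring sketch (variable names assumed, not from this patch) showing how the partitioner is constructed:

// Partition one input channel into outputChannels.size() hash partitions, hashing over the
// first two key fields of the row signature. The processor's result is the number of rows written.
final FrameProcessor<Long> partitioner = new FrameChannelHashPartitioner(
    Collections.singletonList(inputChannel),
    outputChannels,
    frameReader,
    2,                   // keyFieldCount
    frameWriterFactory   // must share the input signature
);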

View File

@ -27,20 +27,17 @@ import org.apache.druid.frame.Frame;
import org.apache.druid.frame.channel.FrameWithPartition; import org.apache.druid.frame.channel.FrameWithPartition;
import org.apache.druid.frame.channel.ReadableFrameChannel; import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel; import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.ClusterByPartitions; import org.apache.druid.frame.key.ClusterByPartitions;
import org.apache.druid.frame.key.FrameComparisonWidget; import org.apache.druid.frame.key.FrameComparisonWidget;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.RowKey; import org.apache.druid.frame.key.RowKey;
import org.apache.druid.frame.read.FrameReader; import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
import org.apache.druid.frame.write.FrameWriter; import org.apache.druid.frame.write.FrameWriter;
import org.apache.druid.frame.write.FrameWriterFactory; import org.apache.druid.frame.write.FrameWriterFactory;
import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.ISE; import org.apache.druid.java.util.common.ISE;
import org.apache.druid.segment.ColumnSelectorFactory; import org.apache.druid.segment.ColumnSelectorFactory;
import org.apache.druid.segment.Cursor; import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.column.ColumnType;
import org.apache.druid.segment.column.RowSignature;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import java.io.IOException; import java.io.IOException;
@ -57,7 +54,7 @@ import java.util.function.Supplier;
* Frames from input channels must be {@link org.apache.druid.frame.FrameType#ROW_BASED}. Output frames will * Frames from input channels must be {@link org.apache.druid.frame.FrameType#ROW_BASED}. Output frames will
* be row-based as well. * be row-based as well.
* *
* For unsorted output, use {@link FrameChannelMuxer} instead. * For unsorted output, use {@link FrameChannelMixer} instead.
*/ */
public class FrameChannelMerger implements FrameProcessor<Long> public class FrameChannelMerger implements FrameProcessor<Long>
{ {
@ -66,7 +63,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
private final List<ReadableFrameChannel> inputChannels; private final List<ReadableFrameChannel> inputChannels;
private final WritableFrameChannel outputChannel; private final WritableFrameChannel outputChannel;
private final FrameReader frameReader; private final FrameReader frameReader;
private final ClusterBy clusterBy; private final List<KeyColumn> sortKey;
private final ClusterByPartitions partitions; private final ClusterByPartitions partitions;
private final IntPriorityQueue priorityQueue; private final IntPriorityQueue priorityQueue;
private final FrameWriterFactory frameWriterFactory; private final FrameWriterFactory frameWriterFactory;
@ -83,7 +80,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
final FrameReader frameReader, final FrameReader frameReader,
final WritableFrameChannel outputChannel, final WritableFrameChannel outputChannel,
final FrameWriterFactory frameWriterFactory, final FrameWriterFactory frameWriterFactory,
final ClusterBy clusterBy, final List<KeyColumn> sortKey,
@Nullable final ClusterByPartitions partitions, @Nullable final ClusterByPartitions partitions,
final long rowLimit final long rowLimit
) )
@ -102,11 +99,15 @@ public class FrameChannelMerger implements FrameProcessor<Long>
throw new IAE("Partitions must all abut each other"); throw new IAE("Partitions must all abut each other");
} }
if (!sortKey.stream().allMatch(keyColumn -> keyColumn.order().sortable())) {
throw new IAE("Key is not sortable");
}
this.inputChannels = inputChannels; this.inputChannels = inputChannels;
this.outputChannel = outputChannel; this.outputChannel = outputChannel;
this.frameReader = frameReader; this.frameReader = frameReader;
this.frameWriterFactory = frameWriterFactory; this.frameWriterFactory = frameWriterFactory;
this.clusterBy = clusterBy; this.sortKey = sortKey;
this.partitions = partitionsToUse; this.partitions = partitionsToUse;
this.rowLimit = rowLimit; this.rowLimit = rowLimit;
this.currentFrames = new FramePlus[inputChannels.size()]; this.currentFrames = new FramePlus[inputChannels.size()];
@ -127,18 +128,10 @@ public class FrameChannelMerger implements FrameProcessor<Long>
frameColumnSelectorFactorySuppliers.add(() -> currentFrames[frameNumber].cursor.getColumnSelectorFactory()); frameColumnSelectorFactorySuppliers.add(() -> currentFrames[frameNumber].cursor.getColumnSelectorFactory());
} }
this.mergedColumnSelectorFactory = this.mergedColumnSelectorFactory = new MultiColumnSelectorFactory(
new MultiColumnSelectorFactory( frameColumnSelectorFactorySuppliers,
frameColumnSelectorFactorySuppliers, frameReader.signature()
).withRowMemoryAndSignatureColumns();
// Include ROW_SIGNATURE_COLUMN, ROW_MEMORY_COLUMN to potentially enable direct row memory copying.
// If these columns don't actually exist in the underlying column selector factories, they'll be ignored.
RowSignature.builder()
.addAll(frameReader.signature())
.add(FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN, ColumnType.UNKNOWN_COMPLEX)
.add(FrameColumnSelectorFactory.ROW_MEMORY_COLUMN, ColumnType.UNKNOWN_COMPLEX)
.build()
);
} }
@Override @Override
@ -244,7 +237,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
if (channel.canRead()) { if (channel.canRead()) {
// Read next frame from this channel. // Read next frame from this channel.
final Frame frame = channel.read(); final Frame frame = channel.read();
currentFrames[currentChannel] = new FramePlus(frame, frameReader, clusterBy); currentFrames[currentChannel] = new FramePlus(frame, frameReader, sortKey);
priorityQueue.enqueue(currentChannel); priorityQueue.enqueue(currentChannel);
} else if (channel.isFinished()) { } else if (channel.isFinished()) {
// Done reading this channel. Fall through and continue with other channels. // Done reading this channel. Fall through and continue with other channels.
@ -281,7 +274,7 @@ public class FrameChannelMerger implements FrameProcessor<Long>
if (channel.canRead()) { if (channel.canRead()) {
final Frame frame = channel.read(); final Frame frame = channel.read();
currentFrames[i] = new FramePlus(frame, frameReader, clusterBy); currentFrames[i] = new FramePlus(frame, frameReader, sortKey);
priorityQueue.enqueue(i); priorityQueue.enqueue(i);
} else if (!channel.isFinished()) { } else if (!channel.isFinished()) {
await.add(i); await.add(i);
@ -301,10 +294,10 @@ public class FrameChannelMerger implements FrameProcessor<Long>
private final FrameComparisonWidget comparisonWidget; private final FrameComparisonWidget comparisonWidget;
private int rowNumber; private int rowNumber;
private FramePlus(Frame frame, FrameReader frameReader, ClusterBy clusterBy) private FramePlus(Frame frame, FrameReader frameReader, List<KeyColumn> sortKey)
{ {
this.cursor = FrameProcessors.makeCursor(frame, frameReader); this.cursor = FrameProcessors.makeCursor(frame, frameReader);
this.comparisonWidget = frameReader.makeComparisonWidget(frame, clusterBy.getColumns()); this.comparisonWidget = frameReader.makeComparisonWidget(frame, sortKey);
this.rowNumber = 0; this.rowNumber = 0;
} }

View File

@ -19,8 +19,7 @@
package org.apache.druid.frame.processor; package org.apache.druid.frame.processor;
import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSet;
import org.apache.druid.frame.Frame; import org.apache.druid.frame.Frame;
import org.apache.druid.frame.channel.ReadableFrameChannel; import org.apache.druid.frame.channel.ReadableFrameChannel;
@ -29,7 +28,6 @@ import org.apache.druid.frame.channel.WritableFrameChannel;
import java.io.IOException; import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
/** /**
* Processor that merges frames from inputChannels into a single outputChannel. No sorting is done: input frames are * Processor that merges frames from inputChannels into a single outputChannel. No sorting is done: input frames are
@ -37,21 +35,22 @@ import java.util.concurrent.ThreadLocalRandom;
* *
* For sorted output, use {@link FrameChannelMerger} instead. * For sorted output, use {@link FrameChannelMerger} instead.
*/ */
public class FrameChannelMuxer implements FrameProcessor<Long> public class FrameChannelMixer implements FrameProcessor<Long>
{ {
private final List<ReadableFrameChannel> inputChannels; private final List<ReadableFrameChannel> inputChannels;
private final WritableFrameChannel outputChannel; private final WritableFrameChannel outputChannel;
private final IntSet remainingChannels = new IntOpenHashSet(); private final IntSet awaitSet;
private long rowsRead = 0L; private long rowsRead = 0L;
public FrameChannelMuxer( public FrameChannelMixer(
final List<ReadableFrameChannel> inputChannels, final List<ReadableFrameChannel> inputChannels,
final WritableFrameChannel outputChannel final WritableFrameChannel outputChannel
) )
{ {
this.inputChannels = inputChannels; this.inputChannels = inputChannels;
this.outputChannel = outputChannel; this.outputChannel = outputChannel;
this.awaitSet = FrameProcessors.rangeSet(inputChannels.size());
} }
@Override @Override
@ -69,39 +68,33 @@ public class FrameChannelMuxer implements FrameProcessor<Long>
@Override @Override
public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException public ReturnOrAwait<Long> runIncrementally(final IntSet readableInputs) throws IOException
{ {
if (remainingChannels.isEmpty()) { final IntSet readySet = new IntAVLTreeSet(readableInputs);
// First run.
for (int i = 0; i < inputChannels.size(); i++) { for (int channelNumber : readableInputs) {
final ReadableFrameChannel channel = inputChannels.get(i); final ReadableFrameChannel channel = inputChannels.get(channelNumber);
if (!channel.isFinished()) {
remainingChannels.add(i); if (channel.isFinished()) {
} awaitSet.remove(channelNumber);
readySet.remove(channelNumber);
} }
} }
if (!readableInputs.isEmpty()) { if (!readySet.isEmpty()) {
// Avoid biasing towards lower-numbered channels. // Read a random channel: avoid biasing towards lower-numbered channels.
final int channelIdx = ThreadLocalRandom.current().nextInt(readableInputs.size()); final int channelNumber = FrameProcessors.selectRandom(readySet);
final ReadableFrameChannel channel = inputChannels.get(channelNumber);
int i = 0; if (!channel.isFinished()) {
for (IntIterator iterator = readableInputs.iterator(); iterator.hasNext(); i++) { final Frame frame = channel.read();
final int channelNumber = iterator.nextInt(); outputChannel.write(frame);
final ReadableFrameChannel channel = inputChannels.get(channelNumber); rowsRead += frame.numRows();
if (channel.isFinished()) {
remainingChannels.remove(channelNumber);
} else if (i == channelIdx) {
final Frame frame = channel.read();
outputChannel.write(frame);
rowsRead += frame.numRows();
}
} }
} }
if (remainingChannels.isEmpty()) { if (awaitSet.isEmpty()) {
return ReturnOrAwait.returnObject(rowsRead); return ReturnOrAwait.returnObject(rowsRead);
} else { } else {
return ReturnOrAwait.awaitAny(remainingChannels); return ReturnOrAwait.awaitAny(awaitSet);
} }
} }

View File

@ -71,7 +71,7 @@ public interface FrameProcessor<T>
* *
* Implementations typically call {@link ReadableFrameChannel#close()} and * Implementations typically call {@link ReadableFrameChannel#close()} and
* {@link WritableFrameChannel#close()} on all input and output channels, as well as releasing any additional * {@link WritableFrameChannel#close()} on all input and output channels, as well as releasing any additional
* resources that may be held. * resources that may be held, such as {@link org.apache.druid.frame.write.FrameWriter}.
* *
* In cases of cancellation, this method may be called even if {@link #runIncrementally} has not yet returned a * In cases of cancellation, this method may be called even if {@link #runIncrementally} has not yet returned a
* result via {@link ReturnOrAwait#returnObject}. * result via {@link ReturnOrAwait#returnObject}.
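An illustrative cleanup() implementation in the spirit of the processors in this patch (the frameWriter field is assumed):

@Override
public void cleanup() throws IOException
{
  // Closes all input and output channels, plus any additional Closeables such as the frame writer.
  FrameProcessors.closeAll(inputChannels(), outputChannels(), frameWriter);
}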

View File

@ -20,23 +20,28 @@
package org.apache.druid.frame.processor; package org.apache.druid.frame.processor;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.ints.IntAVLTreeSet;
import it.unimi.dsi.fastutil.ints.IntIterator;
import it.unimi.dsi.fastutil.ints.IntSet; import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSortedSet;
import org.apache.druid.frame.Frame; import org.apache.druid.frame.Frame;
import org.apache.druid.frame.channel.ReadableFrameChannel; import org.apache.druid.frame.channel.ReadableFrameChannel;
import org.apache.druid.frame.channel.WritableFrameChannel; import org.apache.druid.frame.channel.WritableFrameChannel;
import org.apache.druid.frame.read.FrameReader; import org.apache.druid.frame.read.FrameReader;
import org.apache.druid.frame.segment.FrameCursor;
import org.apache.druid.frame.segment.FrameStorageAdapter; import org.apache.druid.frame.segment.FrameStorageAdapter;
import org.apache.druid.java.util.common.IAE;
import org.apache.druid.java.util.common.Intervals; import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Yielders; import org.apache.druid.java.util.common.guava.Yielders;
import org.apache.druid.java.util.common.io.Closer; import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.segment.Cursor;
import org.apache.druid.segment.VirtualColumns; import org.apache.druid.segment.VirtualColumns;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
public class FrameProcessors public class FrameProcessors
@ -92,17 +97,64 @@ public class FrameProcessors
return new FrameProcessorWithBaggage(); return new FrameProcessorWithBaggage();
} }
public static Cursor makeCursor(final Frame frame, final FrameReader frameReader) /**
* Returns a {@link FrameCursor} for the provided {@link Frame}, allowing both sequential and random access.
*/
public static FrameCursor makeCursor(final Frame frame, final FrameReader frameReader)
{ {
// Safe to never close the Sequence that the Cursor comes from, because it does not do anything when it is closed. return makeCursor(frame, frameReader, VirtualColumns.EMPTY);
}
/**
* Returns a {@link FrameCursor} for the provided {@link Frame} and {@link VirtualColumns}, allowing both sequential
* and random access.
*/
public static FrameCursor makeCursor(
final Frame frame,
final FrameReader frameReader,
final VirtualColumns virtualColumns
)
{
// Safe to never close the Sequence that the FrameCursor comes from, because it does not need to be closed.
// Refer to FrameStorageAdapter#makeCursors. // Refer to FrameStorageAdapter#makeCursors.
return Yielders.each( return (FrameCursor) Yielders.each(
new FrameStorageAdapter(frame, frameReader, Intervals.ETERNITY) new FrameStorageAdapter(frame, frameReader, Intervals.ETERNITY)
.makeCursors(null, Intervals.ETERNITY, VirtualColumns.EMPTY, Granularities.ALL, false, null) .makeCursors(null, Intervals.ETERNITY, virtualColumns, Granularities.ALL, false, null)
).get(); ).get();
} }
/**
* Creates a mutable sorted set containing the ints from 0 (inclusive) to "size" (exclusive).
*
* @throws IllegalArgumentException if size is negative
*/
public static IntSortedSet rangeSet(final int size)
{
if (size < 0) {
throw new IAE("Size must be nonnegative");
}
final IntSortedSet set = new IntAVLTreeSet();
for (int i = 0; i < size; i++) {
set.add(i);
}
return set;
}
/**
* Selects a random element from a set of ints.
*/
public static int selectRandom(final IntSet ints)
{
final int idx = ThreadLocalRandom.current().nextInt(ints.size());
final IntIterator iterator = ints.iterator();
iterator.skip(idx);
return iterator.nextInt();
}
/** /**
* Helper method for implementing {@link FrameProcessor#cleanup()}. * Helper method for implementing {@link FrameProcessor#cleanup()}.
* *
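Together, rangeSet and selectRandom implement "pick a random ready channel", as used by the mixer and hash partitioner above. Illustrative use (values assumed):

// {0, 1, 2}: one entry per input channel still being awaited.
final IntSortedSet await = FrameProcessors.rangeSet(3);

// Uniformly random element of the set; avoids biasing towards lower-numbered channels.
final int chosen = FrameProcessors.selectRandom(await);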

View File

@@ -20,16 +20,18 @@
 package org.apache.druid.frame.processor;

 import com.google.common.base.Predicate;
+import org.apache.druid.frame.segment.row.FrameColumnSelectorFactory;
 import org.apache.druid.query.dimension.DimensionSpec;
 import org.apache.druid.query.filter.ValueMatcher;
 import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
-import org.apache.druid.segment.ColumnInspector;
 import org.apache.druid.segment.ColumnSelectorFactory;
 import org.apache.druid.segment.ColumnValueSelector;
 import org.apache.druid.segment.DimensionSelector;
 import org.apache.druid.segment.DimensionSelectorUtils;
 import org.apache.druid.segment.IdLookup;
 import org.apache.druid.segment.column.ColumnCapabilities;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.segment.column.RowSignature;
 import org.apache.druid.segment.data.IndexedInts;

 import javax.annotation.Nullable;
@@ -44,17 +46,17 @@ import java.util.function.Supplier;
 public class MultiColumnSelectorFactory implements ColumnSelectorFactory
 {
   private final List<Supplier<ColumnSelectorFactory>> factorySuppliers;
-  private final ColumnInspector columnInspector;
+  private final RowSignature signature;

   private int currentFactory = 0;

   public MultiColumnSelectorFactory(
       final List<Supplier<ColumnSelectorFactory>> factorySuppliers,
-      final ColumnInspector columnInspector
+      final RowSignature signature
   )
   {
     this.factorySuppliers = factorySuppliers;
-    this.columnInspector = columnInspector;
+    this.signature = signature;
   }

   public void setCurrentFactory(final int currentFactory)
@@ -62,6 +64,24 @@ public class MultiColumnSelectorFactory implements ColumnSelectorFactory
     this.currentFactory = currentFactory;
   }

+  /**
+   * Create a copy that includes {@link FrameColumnSelectorFactory#ROW_SIGNATURE_COLUMN} and
+   * {@link FrameColumnSelectorFactory#ROW_MEMORY_COLUMN} to potentially enable direct row memory copying. If these
+   * columns don't actually exist in the underlying column selector factories, they'll be ignored, so it's OK to
+   * use this method even if the columns may not exist.
+   */
+  public MultiColumnSelectorFactory withRowMemoryAndSignatureColumns()
+  {
+    return new MultiColumnSelectorFactory(
+        factorySuppliers,
+        RowSignature.builder()
+                    .addAll(signature)
+                    .add(FrameColumnSelectorFactory.ROW_SIGNATURE_COLUMN, ColumnType.UNKNOWN_COMPLEX)
+                    .add(FrameColumnSelectorFactory.ROW_MEMORY_COLUMN, ColumnType.UNKNOWN_COMPLEX)
+                    .build()
+    );
+  }
+
   @Override
   public DimensionSelector makeDimensionSelector(DimensionSpec dimensionSpec)
   {
@@ -235,6 +255,6 @@ public class MultiColumnSelectorFactory implements ColumnSelectorFactory
   @Override
   public ColumnCapabilities getColumnCapabilities(String column)
   {
-    return columnInspector.getColumnCapabilities(column);
+    return signature.getColumnCapabilities(column);
   }
 }
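
A hedged sketch of how the new copy method might be used. The helper below and its parameter wiring are illustrative, not part of the patch; only the constructor and withRowMemoryAndSignatureColumns() come from the diff above.

// Sketch only: wraps per-input selector factories so a downstream frame writer can
// attempt direct row-memory copies. The helper itself is hypothetical; imports
// (java.util.List, java.util.function.Supplier, and the segment classes) follow the file above.
static MultiColumnSelectorFactory makeCopyCapableFactory(
    final List<Supplier<ColumnSelectorFactory>> factorySuppliers,
    final RowSignature signature
)
{
  // The copy also reports ROW_SIGNATURE_COLUMN and ROW_MEMORY_COLUMN; if the underlying
  // factories don't actually provide those columns, they are simply ignored.
  return new MultiColumnSelectorFactory(factorySuppliers, signature)
      .withRowMemoryAndSignatureColumns();
}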

File: OutputChannel.java
@@ -47,18 +47,21 @@ public class OutputChannel
   @Nullable
   private final MemoryAllocator frameMemoryAllocator;
   private final Supplier<ReadableFrameChannel> readableChannelSupplier;
+  private final boolean readableChannelUsableWhileWriting;
   private final int partitionNumber;

   private OutputChannel(
       @Nullable final WritableFrameChannel writableChannel,
       @Nullable final MemoryAllocator frameMemoryAllocator,
       final Supplier<ReadableFrameChannel> readableChannelSupplier,
+      final boolean readableChannelUsableWhileWriting,
       final int partitionNumber
   )
   {
     this.writableChannel = writableChannel;
     this.frameMemoryAllocator = frameMemoryAllocator;
     this.readableChannelSupplier = readableChannelSupplier;
+    this.readableChannelUsableWhileWriting = readableChannelUsableWhileWriting;
     this.partitionNumber = partitionNumber;

     if (partitionNumber < 0 && partitionNumber != FrameWithPartition.NO_PARTITION) {
@@ -67,7 +70,7 @@ public class OutputChannel
   }

   /**
-   * Creates an output channel pair.
+   * Creates an output channel pair, where the readable channel is not usable until writing is complete.
    *
    * @param writableChannel      writable channel for producer
    * @param frameMemoryAllocator memory allocator for producer to use while writing frames to the channel
@@ -86,17 +89,71 @@ public class OutputChannel
         Preconditions.checkNotNull(writableChannel, "writableChannel"),
         Preconditions.checkNotNull(frameMemoryAllocator, "frameMemoryAllocator"),
         readableChannelSupplier,
+        false,
         partitionNumber
     );
   }

+  /**
+   * Creates an output channel pair, where the readable channel is usable before writing is complete.
+   *
+   * @param writableChannel      writable channel for producer
+   * @param frameMemoryAllocator memory allocator for producer to use while writing frames to the channel
+   * @param readableChannel      readable channel for consumer
+   * @param partitionNumber      partition number, if any; may be {@link FrameWithPartition#NO_PARTITION} if unknown
+   */
+  public static OutputChannel immediatelyReadablePair(
+      final WritableFrameChannel writableChannel,
+      final MemoryAllocator frameMemoryAllocator,
+      final ReadableFrameChannel readableChannel,
+      final int partitionNumber
+  )
+  {
+    return new OutputChannel(
+        Preconditions.checkNotNull(writableChannel, "writableChannel"),
+        Preconditions.checkNotNull(frameMemoryAllocator, "frameMemoryAllocator"),
+        () -> readableChannel,
+        true,
+        partitionNumber
+    );
+  }
+
+  /**
+   * Creates a read-only output channel.
+   *
+   * @param readableChannel readable channel for consumer.
+   * @param partitionNumber partition number, if any; may be {@link FrameWithPartition#NO_PARTITION} if unknown
+   */
+  public static OutputChannel readOnly(
+      final ReadableFrameChannel readableChannel,
+      final int partitionNumber
+  )
+  {
+    return readOnly(() -> readableChannel, partitionNumber);
+  }
+
+  /**
+   * Creates a read-only output channel.
+   *
+   * @param readableChannelSupplier readable channel for consumer. May be called multiple times, so you should wrap this
+   *                                in {@link Suppliers#memoize} if needed.
+   * @param partitionNumber         partition number, if any; may be {@link FrameWithPartition#NO_PARTITION} if unknown
+   */
+  public static OutputChannel readOnly(
+      final Supplier<ReadableFrameChannel> readableChannelSupplier,
+      final int partitionNumber
+  )
+  {
+    return new OutputChannel(null, null, readableChannelSupplier, true, partitionNumber);
+  }
+
   /**
    * Create a nil output channel, representing a processor that writes nothing. It is not actually writable, but
    * provides a way for downstream processors to read nothing.
    */
   public static OutputChannel nil(final int partitionNumber)
   {
-    return new OutputChannel(null, null, () -> ReadableNilFrameChannel.INSTANCE, partitionNumber);
+    return new OutputChannel(null, null, () -> ReadableNilFrameChannel.INSTANCE, true, partitionNumber);
   }

   /**
@@ -126,10 +183,23 @@ public class OutputChannel
   /**
    * Returns the readable channel of this pair. This readable channel may, or may not, be usable before the
    * writable channel is closed. It depends on whether the channel pair was created in a stream-capable manner or not.
+   * Check {@link #isReadableChannelReady()} to find out.
    */
   public ReadableFrameChannel getReadableChannel()
   {
-    return readableChannelSupplier.get();
+    if (isReadableChannelReady()) {
+      return readableChannelSupplier.get();
+    } else {
+      throw new ISE("Readable channel is not ready");
+    }
+  }
+
+  /**
+   * Whether {@link #getReadableChannel()} is ready to use.
+   */
+  public boolean isReadableChannelReady()
+  {
+    return readableChannelUsableWhileWriting || writableChannel == null || writableChannel.isClosed();
   }

   public Supplier<ReadableFrameChannel> getReadableChannelSupplier()
@@ -151,6 +221,7 @@ public class OutputChannel
         mapFn.apply(writableChannel),
         frameMemoryAllocator,
         readableChannelSupplier,
+        readableChannelUsableWhileWriting,
         partitionNumber
     );
   }
@@ -162,6 +233,6 @@ public class OutputChannel
    */
   public OutputChannel readOnly()
   {
-    return new OutputChannel(null, null, readableChannelSupplier, partitionNumber);
+    return OutputChannel.readOnly(readableChannelSupplier, partitionNumber);
   }
 }
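
A minimal sketch, not from the patch, of how a consumer might use the new readiness check before reading. The polling loop and helper name are assumptions of this example.

// Sketch only: getReadableChannel() now throws ISE when the pair is not stream-capable
// and the writable side is still open, so a caller can check readiness first.
static ReadableFrameChannel awaitReadable(final OutputChannel channel) throws InterruptedException
{
  while (!channel.isReadableChannelReady()) {
    // Real callers would be event-driven; a short sleep keeps the sketch simple.
    Thread.sleep(10);
  }
  return channel.getReadableChannel();
}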

File: OutputChannels.java
@@ -22,6 +22,7 @@ package org.apache.druid.frame.processor;
 import it.unimi.dsi.fastutil.ints.Int2ObjectRBTreeMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectSortedMap;
 import it.unimi.dsi.fastutil.ints.IntSortedSet;
+import org.apache.druid.frame.channel.ReadableFrameChannel;

 import java.util.ArrayList;
 import java.util.Collections;
@@ -89,6 +90,14 @@ public class OutputChannels
     return outputChannels;
   }

+  /**
+   * Returns all channels, as readable channels.
+   */
+  public List<ReadableFrameChannel> getAllReadableChannels()
+  {
+    return outputChannels.stream().map(OutputChannel::getReadableChannel).collect(Collectors.toList());
+  }
+
   /**
    * Returns channels for which {@link OutputChannel#getPartitionNumber()} returns {@code partitionNumber}.
    */
@@ -111,4 +120,15 @@
   {
     return wrapReadOnly(outputChannels);
   }
+
+  public boolean areReadableChannelsReady()
+  {
+    for (final OutputChannel outputChannel : outputChannels) {
+      if (!outputChannel.isReadableChannelReady()) {
+        return false;
+      }
+    }
+
+    return true;
+  }
 }
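
A hedged sketch of using the two new OutputChannels methods together. The Optional-based contract is an assumption of this example, not of the API; imports (java.util.List, java.util.Optional, and the frame channel classes) follow the file above.

// Sketch only: read every channel once the whole set reports ready.
static Optional<List<ReadableFrameChannel>> readAllIfReady(final OutputChannels channels)
{
  if (channels.areReadableChannelsReady()) {
    return Optional.of(channels.getAllReadableChannels());
  } else {
    // Not ready yet; the caller is expected to retry later.
    return Optional.empty();
  }
}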

File: ReturnOrAwait.java
@@ -19,7 +19,6 @@
 package org.apache.druid.frame.processor;

-import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
 import it.unimi.dsi.fastutil.ints.IntSet;
 import it.unimi.dsi.fastutil.ints.IntSets;
 import org.apache.druid.java.util.common.IAE;
@@ -203,13 +202,7 @@ public class ReturnOrAwait<T>
     } else if (count == 1) {
       return RANGE_SET_ONE;
     } else {
-      final IntSet retVal = new IntOpenHashSet();
-
-      for (int i = 0; i < count; i++) {
-        retVal.add(i);
-      }
-
-      return retVal;
+      return IntSets.fromTo(0, count);
     }
   }
 }
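
For reference, fastutil's IntSets.fromTo(from, to) is half-open, so fromTo(0, count) contains exactly {0, ..., count - 1}, the same set the removed loop built, without allocating a hash set. A tiny illustrative check:

import it.unimi.dsi.fastutil.ints.IntSet;
import it.unimi.dsi.fastutil.ints.IntSets;

public class FromToExample
{
  public static void main(String[] args)
  {
    final IntSet range = IntSets.fromTo(0, 4);
    System.out.println(range.size());      // 4
    System.out.println(range.contains(3)); // true
    System.out.println(range.contains(4)); // false: the upper bound is exclusive
  }
}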

File: SuperSorter.java
@@ -38,7 +38,8 @@ import it.unimi.dsi.fastutil.longs.LongSortedSet;
 import org.apache.druid.common.guava.FutureUtils;
 import org.apache.druid.frame.Frame;
 import org.apache.druid.frame.FrameType;
-import org.apache.druid.frame.allocation.MemoryAllocator;
+import org.apache.druid.frame.allocation.MemoryAllocatorFactory;
+import org.apache.druid.frame.allocation.SingleMemoryAllocatorFactory;
 import org.apache.druid.frame.channel.BlockingQueueFrameChannel;
 import org.apache.druid.frame.channel.FrameWithPartition;
 import org.apache.druid.frame.channel.PartitionedReadableFrameChannel;
@@ -47,6 +48,7 @@ import org.apache.druid.frame.channel.WritableFrameChannel;
 import org.apache.druid.frame.file.FrameFile;
 import org.apache.druid.frame.key.ClusterBy;
 import org.apache.druid.frame.key.ClusterByPartitions;
+import org.apache.druid.frame.key.KeyColumn;
 import org.apache.druid.frame.read.FrameReader;
 import org.apache.druid.frame.write.FrameWriters;
 import org.apache.druid.java.util.common.IAE;
@@ -118,7 +120,7 @@ public class SuperSorter
   private final List<ReadableFrameChannel> inputChannels;
   private final FrameReader frameReader;
-  private final ClusterBy clusterBy;
+  private final List<KeyColumn> sortKey;
   private final ListenableFuture<ClusterByPartitions> outputPartitionsFuture;
   private final FrameProcessorExecutor exec;
   private final OutputChannelFactory outputChannelFactory;
@@ -185,7 +187,7 @@ public class SuperSorter
    *                                 {@link ClusterBy#getColumns()}, or else sorting will not produce correct
    *                                 output.
    * @param frameReader              frame reader for the input channels
-   * @param clusterBy                desired sorting order
+   * @param sortKey                  desired sorting order
    * @param outputPartitionsFuture   a future that resolves to the desired output partitions. Sorting will block
    *                                 prior to writing out final outputs until this future resolves. However, the
    *                                 sorter will be able to read all inputs even if this future is unresolved.
@@ -208,7 +210,7 @@ public class SuperSorter
   public SuperSorter(
       final List<ReadableFrameChannel> inputChannels,
       final FrameReader frameReader,
-      final ClusterBy clusterBy,
+      final List<KeyColumn> sortKey,
       final ListenableFuture<ClusterByPartitions> outputPartitionsFuture,
       final FrameProcessorExecutor exec,
       final OutputChannelFactory outputChannelFactory,
@@ -222,7 +224,7 @@ public class SuperSorter
   {
     this.inputChannels = inputChannels;
     this.frameReader = frameReader;
-    this.clusterBy = clusterBy;
+    this.sortKey = sortKey;
     this.outputPartitionsFuture = outputPartitionsFuture;
     this.exec = exec;
     this.outputChannelFactory = outputChannelFactory;
@@ -593,22 +595,22 @@ public class SuperSorter
   {
     try {
       final WritableFrameChannel writableChannel;
-      final MemoryAllocator frameAllocator;
+      final MemoryAllocatorFactory frameAllocatorFactory;
       String levelAndRankKey = mergerOutputFileName(level, rank);

       if (totalMergingLevels != UNKNOWN_LEVEL && level == totalMergingLevels - 1) {
         final int intRank = Ints.checkedCast(rank);
         final OutputChannel outputChannel = outputChannelFactory.openChannel(intRank);
         outputChannels.set(intRank, outputChannel.readOnly());
+        frameAllocatorFactory = new SingleMemoryAllocatorFactory(outputChannel.getFrameMemoryAllocator());
         writableChannel = outputChannel.getWritableChannel();
-        frameAllocator = outputChannel.getFrameMemoryAllocator();
       } else {
         PartitionedOutputChannel partitionedOutputChannel = intermediateOutputChannelFactory.openPartitionedChannel(
             levelAndRankKey,
             true
         );
         writableChannel = partitionedOutputChannel.getWritableChannel();
-        frameAllocator = partitionedOutputChannel.getFrameMemoryAllocator();
+        frameAllocatorFactory = new SingleMemoryAllocatorFactory(partitionedOutputChannel.getFrameMemoryAllocator());
         levelAndRankToReadableChannelMap.put(levelAndRankKey, partitionedOutputChannel);
       }
@@ -619,12 +621,12 @@ public class SuperSorter
           writableChannel,
           FrameWriters.makeFrameWriterFactory(
               FrameType.ROW_BASED, // Row-based frames are generally preferred as inputs to mergers
-              frameAllocator,
+              frameAllocatorFactory,
               frameReader.signature(),
               // No sortColumns, because FrameChannelMerger generates frames that are sorted all on its own
               Collections.emptyList()
           ),
-          clusterBy,
+          sortKey,
           partitions,
           rowLimit
       );
@@ -830,7 +832,6 @@ public class SuperSorter
     return StringUtils.format("merged.%d.%d", level, rank);
   }

-
   /**
    * Returns a string encapsulating the current state of this object.
    */
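
A hedged sketch of the allocator change above: where a single MemoryAllocator used to be handed to the frame writer factory, it is now wrapped in a SingleMemoryAllocatorFactory, which is the form FrameWriters.makeFrameWriterFactory accepts after this patch. The helper below is illustrative only.

// Sketch only: adapts a channel's single allocator to the MemoryAllocatorFactory interface.
// Assumes the channel was created with a writable side, since read-only channels carry no allocator.
static MemoryAllocatorFactory allocatorFactoryFor(final OutputChannel outputChannel)
{
  return new SingleMemoryAllocatorFactory(outputChannel.getFrameMemoryAllocator());
}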

Some files were not shown because too many files have changed in this diff.