diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 6abb2f5c6d3c..7397bf374a84 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -761,6 +761,16 @@ config_namespace! { /// using the provided `target_partitions` level pub repartition_joins: bool, default = true + /// Should DataFusion use spillable partitioned hash joins instead of regular partitioned joins + /// when repartitioning is enabled. This allows handling larger datasets by spilling to disk + /// when memory pressure occurs during join execution. + pub enable_spillable_hash_join: bool, default = true + + /// When set to true, spillable partitioned hash joins will be replaced with the experimental + /// Grace hash join operator which repartitions both inputs to disk before performing the join. + /// This trades additional IO for predictable memory usage on very large joins. + pub enable_grace_hash_join: bool, default = false + /// Should DataFusion allow symmetric hash joins for unbounded data sources even when /// its inputs do not have any ordering or filtering If the flag is not enabled, /// the SymmetricHashJoin operator will be unable to prune its internal buffers, diff --git a/datafusion/core/tests/physical_optimizer/join_selection.rs b/datafusion/core/tests/physical_optimizer/join_selection.rs index 7ae1d6e50dc3..a68e465c8568 100644 --- a/datafusion/core/tests/physical_optimizer/join_selection.rs +++ b/datafusion/core/tests/physical_optimizer/join_selection.rs @@ -40,7 +40,9 @@ use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::displayable; use datafusion_physical_plan::joins::utils::ColumnIndex; use datafusion_physical_plan::joins::utils::JoinFilter; -use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode}; +use datafusion_physical_plan::joins::{ + GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, +}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::ExecutionPlanProperties; use datafusion_physical_plan::{ @@ -266,6 +268,76 @@ async fn test_join_with_swap() { ); } +#[tokio::test] +async fn test_grace_hash_join_enabled() { + let (big, small) = create_big_and_small(); + let join = Arc::new( + HashJoinExec::try_new( + Arc::clone(&small), + Arc::clone(&big), + vec![( + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + )], + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + ) + .unwrap(), + ); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_grace_hash_join = true; + config.optimizer.enable_spillable_hash_join = true; + config.optimizer.hash_join_single_partition_threshold = 1; + config.optimizer.hash_join_single_partition_threshold_rows = 1; + + let optimized = JoinSelection::new().optimize(join, &config).unwrap(); + assert!( + optimized.as_any().is::(), + "expected GraceHashJoinExec when grace hash join is enabled" + ); +} + +#[tokio::test] +async fn test_grace_hash_join_disabled() { + let (big, small) = create_big_and_small(); + let join = Arc::new( + HashJoinExec::try_new( + Arc::clone(&small), + Arc::clone(&big), + vec![( + Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()), + Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()), + )], + None, + &JoinType::Inner, + None, + PartitionMode::Auto, + NullEquality::NullEqualsNothing, + ) + 
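// A minimal sketch of how the two options added above are toggled from user code,
// assuming the `SessionConfig` accessors introduced later in this diff
// (`with_enable_spillable_hash_join` / `with_enable_grace_hash_join`):
//
//     let config = SessionConfig::new()
//         .with_enable_spillable_hash_join(true)
//         .with_enable_grace_hash_join(true);
//     assert!(config.enable_grace_hash_join());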
.unwrap(), + ); + + let mut config = ConfigOptions::new(); + config.optimizer.enable_grace_hash_join = false; + config.optimizer.enable_spillable_hash_join = true; + config.optimizer.hash_join_single_partition_threshold = 1; + config.optimizer.hash_join_single_partition_threshold_rows = 1; + + let optimized = JoinSelection::new().optimize(join, &config).unwrap(); + let hash_join = optimized + .as_any() + .downcast_ref::() + .expect("Grace disabled should keep HashJoinExec"); + assert_eq!( + hash_join.partition_mode(), + &PartitionMode::PartitionedSpillable + ); +} + #[tokio::test] async fn test_left_join_no_swap() { let (big, small) = create_big_and_small(); diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 491b1aca69ea..47f2fc43236e 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -235,6 +235,16 @@ impl SessionConfig { self.options.optimizer.repartition_joins } + /// Are spillable partitioned hash joins enabled? + pub fn enable_spillable_hash_join(&self) -> bool { + self.options.optimizer.enable_spillable_hash_join + } + + /// Should spillable hash joins be executed via the Grace hash join operator? + pub fn enable_grace_hash_join(&self) -> bool { + self.options.optimizer.enable_grace_hash_join + } + /// Are aggregates repartitioned during execution? pub fn repartition_aggregations(&self) -> bool { self.options.optimizer.repartition_aggregations @@ -298,6 +308,18 @@ impl SessionConfig { self } + /// Enables or disables spillable partitioned hash joins for handling larger datasets + pub fn with_enable_spillable_hash_join(mut self, enabled: bool) -> Self { + self.options_mut().optimizer.enable_spillable_hash_join = enabled; + self + } + + /// Enables or disables the Grace hash join operator for spillable hash joins + pub fn with_enable_grace_hash_join(mut self, enabled: bool) -> Self { + self.options_mut().optimizer.enable_grace_hash_join = enabled; + self + } + /// Enables or disables the use of repartitioning for aggregations to improve parallelism pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self { self.options_mut().optimizer.repartition_aggregations = enabled; diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 82f2d75ac1b5..67251088af12 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -26,7 +26,7 @@ use rand::{rng, Rng}; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use tempfile::{Builder, NamedTempFile, TempDir}; +use tempfile::{Builder, NamedTempFile, TempDir, TempPath}; use crate::memory_pool::human_readable_size; @@ -370,6 +370,17 @@ impl RefCountedTempFile { pub fn current_disk_usage(&self) -> u64 { self.current_file_disk_usage } + + pub fn clone_refcounted(&self) -> Result { + let reopened = std::fs::File::open(self.path())?; + let temp_path = TempPath::from_path(self.path()); + Ok(Self { + _parent_temp_dir: Arc::clone(&self._parent_temp_dir), + tempfile: NamedTempFile::from_parts(reopened, temp_path), + current_file_disk_usage: self.current_file_disk_usage, + disk_manager: Arc::clone(&self.disk_manager), + }) + } } /// When the temporary file is dropped, subtract its disk usage from the disk manager's total diff --git a/datafusion/physical-optimizer/src/coalesce_batches.rs b/datafusion/physical-optimizer/src/coalesce_batches.rs index 5cf2c877c61a..481b63b7a134 100644 --- 
a/datafusion/physical-optimizer/src/coalesce_batches.rs +++ b/datafusion/physical-optimizer/src/coalesce_batches.rs @@ -26,8 +26,11 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_physical_expr::Partitioning; use datafusion_physical_plan::{ - coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, - repartition::RepartitionExec, ExecutionPlan, + coalesce_batches::CoalesceBatchesExec, + filter::FilterExec, + joins::{GraceHashJoinExec, HashJoinExec}, + repartition::RepartitionExec, + ExecutionPlan, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -62,6 +65,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { // See https://github.com/apache/datafusion/issues/139 let wrap_in_coalesce = plan_any.downcast_ref::().is_some() || plan_any.downcast_ref::().is_some() + || plan_any.downcast_ref::().is_some() // Don't need to add CoalesceBatchesExec after a round robin RepartitionExec || plan_any .downcast_ref::() diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 898386e2f988..4173afbdc338 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -48,7 +48,7 @@ use datafusion_physical_plan::aggregates::{ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::execution_plan::EmissionType; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, + CrossJoinExec, GraceHashJoinExec, HashJoinExec, PartitionMode, SortMergeJoinExec, }; use datafusion_physical_plan::projection::{ProjectionExec, ProjectionExpr}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -348,7 +348,57 @@ pub fn adjust_input_keys_ordering( // Can not satisfy, clear the current requirements and generate new empty requirements requirements.data.clear(); } + PartitionMode::PartitionedSpillable => { + // For partitioned spillable, use the same logic as regular partitioned + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + new_conditions.0, + filter.clone(), + join_type, + // TODO: although projection is not used in the join here, because projection pushdown is after enforce_distribution. Maybe we need to handle it later. Same as filter. + projection.clone(), + PartitionMode::PartitionedSpillable, + *null_equality, + ) + .map(|e| Arc::new(e) as _) + }; + return reorder_partitioned_join_keys( + requirements, + on, + &[], + &join_constructor, + ) + .map(Transformed::yes); + } } + } else if let Some(grace_join) = plan.as_any().downcast_ref::() { + let join_constructor = |new_conditions: ( + Vec<(PhysicalExprRef, PhysicalExprRef)>, + Vec, + )| { + GraceHashJoinExec::try_new( + Arc::clone(grace_join.left()), + Arc::clone(grace_join.right()), + new_conditions.0, + grace_join.filter().cloned(), + grace_join.join_type(), + grace_join.projection.clone(), + grace_join.null_equality(), + ) + .map(|e| Arc::new(e) as _) + }; + return reorder_partitioned_join_keys( + requirements, + grace_join.on(), + &[], + &join_constructor, + ) + .map(Transformed::yes); } else if let Some(CrossJoinExec { left, .. 
}) = plan.as_any().downcast_ref::() { @@ -656,6 +706,30 @@ pub fn reorder_join_keys_to_inputs( )?)); } } + } else if let Some(grace_join) = plan_any.downcast_ref::() { + let (join_keys, positions) = reorder_current_join_keys( + extract_join_keys(grace_join.on()), + Some(grace_join.left().output_partitioning()), + Some(grace_join.right().output_partitioning()), + grace_join.left().equivalence_properties(), + grace_join.right().equivalence_properties(), + ); + if positions.is_some_and(|idxs| !idxs.is_empty()) { + let JoinKeyPairs { + left_keys, + right_keys, + } = join_keys; + let new_join_on = new_join_conditions(&left_keys, &right_keys); + return Ok(Arc::new(GraceHashJoinExec::try_new( + Arc::clone(grace_join.left()), + Arc::clone(grace_join.right()), + new_join_on, + grace_join.filter().cloned(), + grace_join.join_type(), + grace_join.projection.clone(), + grace_join.null_equality(), + )?)); + } } else if let Some(SortMergeJoinExec { left, right, diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 6e4e78486612..ce1137d296f1 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -40,7 +40,9 @@ use datafusion_physical_plan::filter::FilterExec; use datafusion_physical_plan::joins::utils::{ calculate_join_output_ordering, ColumnIndex, }; -use datafusion_physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; +use datafusion_physical_plan::joins::{ + GraceHashJoinExec, HashJoinExec, SortMergeJoinExec, +}; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::sort::SortExec; @@ -381,7 +383,9 @@ fn pushdown_requirement_to_children( Ok(None) } } else if let Some(hash_join) = plan.as_any().downcast_ref::() { - handle_hash_join(hash_join, parent_required) + handle_hash_like_join(hash_join, parent_required) + } else if let Some(grace_join) = plan.as_any().downcast_ref::() { + handle_hash_like_join(grace_join, parent_required) } else { handle_custom_pushdown(plan, parent_required, maintains_input_order) } @@ -698,10 +702,71 @@ fn handle_custom_pushdown( } } -// For hash join we only maintain the input order for the right child +trait HashJoinLike { + fn maintains_input_order(&self) -> Vec; + fn projection(&self) -> &Option>; + fn children(&self) -> Vec<&Arc>; + fn join_type(&self) -> &JoinType; + fn left(&self) -> &Arc; + fn right(&self) -> &Arc; +} + +impl HashJoinLike for HashJoinExec { + fn maintains_input_order(&self) -> Vec { + ExecutionPlan::maintains_input_order(self) + } + + fn projection(&self) -> &Option> { + &self.projection + } + + fn children(&self) -> Vec<&Arc> { + ExecutionPlan::children(self) + } + + fn join_type(&self) -> &JoinType { + self.join_type() + } + + fn left(&self) -> &Arc { + self.left() + } + + fn right(&self) -> &Arc { + self.right() + } +} + +impl HashJoinLike for GraceHashJoinExec { + fn maintains_input_order(&self) -> Vec { + ExecutionPlan::maintains_input_order(self) + } + + fn projection(&self) -> &Option> { + &self.projection + } + + fn children(&self) -> Vec<&Arc> { + ExecutionPlan::children(self) + } + + fn join_type(&self) -> &JoinType { + self.join_type() + } + + fn left(&self) -> &Arc { + self.left() + } + + fn right(&self) -> &Arc { + self.right() + } +} + +// For hash-based joins we only maintain the input order for the right child // for join type: Inner, 
Right, RightSemi, RightAnti -fn handle_hash_join( - plan: &HashJoinExec, +fn handle_hash_like_join( + plan: &J, parent_required: OrderingRequirements, ) -> Result>>> { // If the plan has no children or does not maintain the right side ordering, @@ -723,7 +788,7 @@ fn handle_hash_join( .collect(); let column_indices = build_join_column_index(plan); - let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + let projected_indices: Vec<_> = if let Some(projection) = plan.projection() { projection.iter().map(|&i| &column_indices[i]).collect() } else { column_indices.iter().collect() @@ -770,9 +835,9 @@ fn handle_hash_join( } } -// this function is used to build the column index for the hash join +// this function is used to build the column index for hash-based joins so we can // push down sort requirements to the right child -fn build_join_column_index(plan: &HashJoinExec) -> Vec { +fn build_join_column_index(plan: &J) -> Vec { let map_fields = |schema: SchemaRef, side: JoinSide| { schema .fields() diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs index c2cfca681f66..a587ec027d89 100644 --- a/datafusion/physical-optimizer/src/join_selection.rs +++ b/datafusion/physical-optimizer/src/join_selection.rs @@ -27,14 +27,14 @@ use crate::PhysicalOptimizerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::error::Result; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, JoinSide, JoinType}; +use datafusion_common::{internal_err, DataFusionError, JoinSide, JoinType}; use datafusion_expr_common::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::execution_plan::EmissionType; -use datafusion_physical_plan::joins::utils::ColumnIndex; +use datafusion_physical_plan::joins::utils::{check_join_is_valid, ColumnIndex}; use datafusion_physical_plan::joins::{ - CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, + CrossJoinExec, GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, SymmetricHashJoinExec, }; use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; @@ -134,12 +134,16 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; + let enable_spillable = config.enable_spillable_hash_join; + let enable_grace = config.enable_grace_hash_join; new_plan .transform_up(|plan| { statistical_join_selection_subrule( plan, collect_threshold_byte_size, collect_threshold_num_rows, + enable_spillable, + enable_grace, ) }) .data() @@ -169,6 +173,13 @@ pub(crate) fn try_collect_left( let left = hash_join.left(); let right = hash_join.right(); + // Skip collect-left rewrite if the join currently has inconsistent schemas (e.g. required + // columns were projected away temporarily). This mirrors the legacy hash join behavior where + // collect-left is only attempted once the join inputs are fully valid. 
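// A sketch of the failure this guard avoids, assuming `check_join_is_valid` only
// verifies that every `on` column exists in the corresponding input schema:
//
//     // yields DataFusionError::Plan("The left or right side of the join does not
//     // have all columns ...") when a join key was temporarily projected away
//     check_join_is_valid(&left.schema(), &right.schema(), hash_join.on())?;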
+ if check_join_is_valid(&left.schema(), &right.schema(), hash_join.on()).is_err() { + return Ok(None); + } + let left_can_collect = ignore_threshold || supports_collect_by_thresholds( &**left, @@ -187,33 +198,23 @@ pub(crate) fn try_collect_left( if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? { - Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?)) + match hash_join.swap_inputs(PartitionMode::CollectLeft) { + Ok(plan) => Ok(Some(plan)), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } } else { - Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))) + build_collect_left_exec(hash_join, left, right) } } - (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new( - Arc::clone(left), - Arc::clone(right), - hash_join.on().to_vec(), - hash_join.filter().cloned(), - hash_join.join_type(), - hash_join.projection.clone(), - PartitionMode::CollectLeft, - hash_join.null_equality(), - )?))), + (true, false) => build_collect_left_exec(hash_join, left, right), (false, true) => { if hash_join.join_type().supports_swap() { - hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some) + match hash_join.swap_inputs(PartitionMode::CollectLeft) { + Ok(plan) => Ok(Some(plan)), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } } else { Ok(None) } @@ -222,6 +223,35 @@ pub(crate) fn try_collect_left( } } +fn is_missing_join_columns(err: &DataFusionError) -> bool { + matches!( + err, + DataFusionError::Plan(msg) + if msg.contains("The left or right side of the join does not have all columns") + ) +} + +fn build_collect_left_exec( + hash_join: &HashJoinExec, + left: &Arc, + right: &Arc, +) -> Result>> { + match HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.projection.clone(), + PartitionMode::CollectLeft, + hash_join.null_equality(), + ) { + Ok(exec) => Ok(Some(Arc::new(exec))), + Err(err) if is_missing_join_columns(&err) => Ok(None), + Err(err) => Err(err), + } +} + /// Creates a partitioned hash join execution plan, swapping inputs if beneficial. /// /// Checks if the join order should be swapped based on the join type and input statistics. @@ -229,12 +259,38 @@ pub(crate) fn try_collect_left( /// creates a standard partitioned hash join. pub(crate) fn partitioned_hash_join( hash_join: &HashJoinExec, + enable_spillable: bool, + enable_grace: bool, ) -> Result> { let left = hash_join.left(); let right = hash_join.right(); - if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)? 
- { - hash_join.swap_inputs(PartitionMode::Partitioned) + let partition_mode = if enable_spillable { + PartitionMode::PartitionedSpillable + } else { + PartitionMode::Partitioned + }; + + let should_swap = hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)?; + if enable_grace && matches!(partition_mode, PartitionMode::PartitionedSpillable) { + let grace = Arc::new(GraceHashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + hash_join.on().to_vec(), + hash_join.filter().cloned(), + hash_join.join_type(), + hash_join.projection.clone(), + hash_join.null_equality(), + )?); + return if should_swap { + grace.swap_inputs(partition_mode) + } else { + Ok(grace) + }; + } + + if should_swap { + hash_join.swap_inputs(partition_mode) } else { Ok(Arc::new(HashJoinExec::try_new( Arc::clone(left), @@ -243,7 +299,7 @@ pub(crate) fn partitioned_hash_join( hash_join.filter().cloned(), hash_join.join_type(), hash_join.projection.clone(), - PartitionMode::Partitioned, + partition_mode, hash_join.null_equality(), )?)) } @@ -255,60 +311,84 @@ fn statistical_join_selection_subrule( plan: Arc, collect_threshold_byte_size: usize, collect_threshold_num_rows: usize, + enable_spillable: bool, + enable_grace: bool, ) -> Result>> { - let transformed = - if let Some(hash_join) = plan.as_any().downcast_ref::() { - match hash_join.partition_mode() { - PartitionMode::Auto => try_collect_left( - hash_join, - false, - collect_threshold_byte_size, - collect_threshold_num_rows, - )? + let transformed = if let Some(hash_join) = + plan.as_any().downcast_ref::() + { + match hash_join.partition_mode() { + PartitionMode::Auto => try_collect_left( + hash_join, + false, + collect_threshold_byte_size, + collect_threshold_num_rows, + )? + .map_or_else( + || { + partitioned_hash_join(hash_join, enable_spillable, enable_grace) + .map(Some) + }, + |v| Ok(Some(v)), + )?, + PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? .map_or_else( - || partitioned_hash_join(hash_join).map(Some), + || { + partitioned_hash_join(hash_join, enable_spillable, enable_grace) + .map(Some) + }, |v| Ok(Some(v)), )?, - PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? - .map_or_else( - || partitioned_hash_join(hash_join).map(Some), - |v| Ok(Some(v)), - )?, - PartitionMode::Partitioned => { - let left = hash_join.left(); - let right = hash_join.right(); - if hash_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - hash_join - .swap_inputs(PartitionMode::Partitioned) - .map(Some)? - } else { - None - } + PartitionMode::Partitioned => { + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::Partitioned) + .map(Some)? + } else { + None } } - } else if let Some(cross_join) = plan.as_any().downcast_ref::() { - let left = cross_join.left(); - let right = cross_join.right(); - if should_swap_join_order(&**left, &**right)? { - cross_join.swap_inputs().map(Some)? - } else { - None - } - } else if let Some(nl_join) = plan.as_any().downcast_ref::() { - let left = nl_join.left(); - let right = nl_join.right(); - if nl_join.join_type().supports_swap() - && should_swap_join_order(&**left, &**right)? - { - nl_join.swap_inputs().map(Some)? 
- } else { - None + PartitionMode::PartitionedSpillable => { + println!("Using PartitionMode::PartitionedSpillable"); + // For partitioned spillable, use the same logic as regular partitioned + let left = hash_join.left(); + let right = hash_join.right(); + if hash_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + hash_join + .swap_inputs(PartitionMode::PartitionedSpillable) + .map(Some)? + } else { + None + } } + } + } else if let Some(cross_join) = plan.as_any().downcast_ref::() { + let left = cross_join.left(); + let right = cross_join.right(); + if should_swap_join_order(&**left, &**right)? { + cross_join.swap_inputs().map(Some)? } else { None - }; + } + } else if let Some(nl_join) = plan.as_any().downcast_ref::() { + let left = nl_join.left(); + let right = nl_join.right(); + if nl_join.join_type().supports_swap() + && should_swap_join_order(&**left, &**right)? + { + nl_join.swap_inputs().map(Some)? + } else { + None + } + } else { + None + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) @@ -522,6 +602,9 @@ pub(crate) fn swap_join_according_to_unboundedness( (PartitionMode::CollectLeft, _) => { hash_join.swap_inputs(PartitionMode::CollectLeft) } + (PartitionMode::PartitionedSpillable, _) => { + hash_join.swap_inputs(PartitionMode::PartitionedSpillable) + } (PartitionMode::Auto, _) => { // Use `PartitionMode::Partitioned` as default if `Auto` is selected. hash_join.swap_inputs(PartitionMode::Partitioned) diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index a21d91c219aa..110f02bf22de 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -38,6 +38,7 @@ workspace = true force_hash_collisions = [] tokio_coop = [] tokio_coop_fallback = [] +hybrid_hash_join_scheduler = [] [lib] name = "datafusion_physical_plan" diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs new file mode 100644 index 000000000000..47530688c0b9 --- /dev/null +++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs @@ -0,0 +1,1166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
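// A reading aid for the new operator defined below: GraceHashJoinExec first drains
// both inputs for its partition, hashes the join keys with the fixed HASH_JOIN_SEED,
// splits every batch into `partition_count` buckets, and spills each bucket through a
// SpillManager (`partition_and_spill_one_side`). GraceHashJoinStream then consumes the
// resulting per-bucket spill indexes, so in the classic Grace scheme only one
// left/right bucket pair has to be resident in memory at a time. Bucket assignment is
// the plain modulo scheme used in the code below:
//
//     let mut hashes = vec![0u64; num_rows];
//     create_hashes(&key_arrays, &random_state, &mut hashes)?;
//     let bucket = (hashes[row] as usize) % partition_count;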
+
+use crate::execution_plan::{boundedness_from_children, EmissionType};
+use crate::filter_pushdown::{
+    ChildPushdownResult, FilterDescription, FilterPushdownPhase,
+    FilterPushdownPropagation,
+};
+use crate::joins::utils::{reorder_output_after_swap, swap_join_projection, OnceFut};
+use crate::joins::{JoinOn, JoinOnRef, PartitionMode};
+use crate::projection::{
+    try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
+    ProjectionExec,
+};
+use crate::{
+    common::can_project,
+    joins::utils::{
+        build_join_schema, check_join_is_valid, estimate_join_statistics,
+        symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex,
+        JoinFilter,
+    },
+    metrics::{ExecutionPlanMetricsSet, MetricsSet},
+    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties,
+    SendableRecordBatchStream, Statistics,
+};
+use crate::{ExecutionPlanProperties, SpillManager};
+use std::fmt;
+use std::fmt::Formatter;
+use std::sync::Arc;
+use std::{any::Any, vec};
+
+use arrow::array::UInt32Array;
+use arrow::compute::{concat_batches, take};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use datafusion_common::config::ConfigOptions;
+use datafusion_common::{
+    internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result,
+};
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::equivalence::{
+    join_equivalence_properties, ProjectionMapping,
+};
+use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr};
+use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
+
+use crate::joins::grace_hash_join::stream::{GraceHashJoinStream, SpillFut};
+use crate::metrics::SpillMetrics;
+use crate::spill::spill_manager::SpillLocation;
+use ahash::RandomState;
+use datafusion_common::hash_utils::create_hashes;
+use datafusion_physical_expr_common::physical_expr::fmt_sql;
+use futures::StreamExt;
+
+/// Hard-coded seed to ensure hash values from the hash join differ from `RepartitionExec`, avoiding collisions.
+const HASH_JOIN_SEED: RandomState =
+    RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64);
+
+pub struct GraceHashJoinExec {
+    /// left (build) side which gets hashed
+    pub left: Arc<dyn ExecutionPlan>,
+    /// right (probe) side which are filtered by the hash table
+    pub right: Arc<dyn ExecutionPlan>,
+    /// Set of equijoin columns from the relations: `(left_col, right_col)`
+    pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
+    /// Filters which are applied while finding matching rows
+    pub filter: Option<JoinFilter>,
+    /// How the join is performed (`OUTER`, `INNER`, etc)
+    pub join_type: JoinType,
+    /// The schema after join. Please be careful when using this schema,
+    /// if there is a projection, the schema isn't the same as the output schema.
+    join_schema: SchemaRef,
+    /// Shared `RandomState` for the hashing algorithm
+    random_state: RandomState,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+    /// The projection indices of the columns in the output schema of join
+    pub projection: Option<Vec<usize>>,
+    /// Information of index and left / right placement of columns
+    column_indices: Vec<ColumnIndex>,
+    /// The equality null-handling behavior of the join algorithm.
+    pub null_equality: NullEquality,
+    /// Cache holding plan properties like equivalences, output partitioning etc.
+    cache: PlanProperties,
+    /// Indicates whether dynamic filter pushdown is enabled for this join.
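+    /// Set to `true` in `handle_child_pushdown_result` once the optimizer has
+    /// successfully pushed a `DynamicFilterPhysicalExpr` down to the probe (right)
+    /// side; `reset_state` returns it to `false`.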
+ dynamic_filter_enabled: bool, +} + +impl fmt::Debug for GraceHashJoinExec { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("HashJoinExec") + .field("left", &self.left) + .field("right", &self.right) + .field("on", &self.on) + .field("filter", &self.filter) + .field("join_type", &self.join_type) + .field("join_schema", &self.join_schema) + .field("random_state", &self.random_state) + .field("metrics", &self.metrics) + .field("projection", &self.projection) + .field("column_indices", &self.column_indices) + .field("null_equality", &self.null_equality) + .field("cache", &self.cache) + // Intentionally omit dynamic_filter_enabled to keep debug output stable + .finish() + } +} + +impl EmbeddedProjection for GraceHashJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + +impl GraceHashJoinExec { + /// Tries to create a new [GraceHashJoinExec]. + /// + /// # Error + /// This function errors when it is not possible to join the left and right sides on keys `on`. + #[allow(clippy::too_many_arguments)] + pub fn try_new( + left: Arc, + right: Arc, + on: JoinOn, + filter: Option, + join_type: &JoinType, + projection: Option>, + null_equality: NullEquality, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + if on.is_empty() { + return plan_err!("On constraints in HashJoinExec should be non-empty"); + } + check_join_is_valid(&left_schema, &right_schema, &on)?; + + let (join_schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + + let random_state = HASH_JOIN_SEED; + + let join_schema = Arc::new(join_schema); + + // check if the projection is valid + can_project(&join_schema, projection.as_ref())?; + + let cache = Self::compute_properties( + &left, + &right, + Arc::clone(&join_schema), + *join_type, + &on, + projection.as_ref(), + )?; + let metrics = ExecutionPlanMetricsSet::new(); + // Initialize both dynamic filter and bounds accumulator to None + // They will be set later if dynamic filtering is enabled + Ok(GraceHashJoinExec { + left, + right, + on, + filter, + join_type: *join_type, + join_schema, + random_state, + metrics, + projection, + column_indices, + null_equality, + cache, + dynamic_filter_enabled: false, + }) + } + + fn create_dynamic_filter(on: &JoinOn) -> Arc { + // Extract the right-side keys (probe side keys) from the `on` clauses + // Dynamic filter will be created from build side values (left side) and applied to probe side (right side) + let right_keys: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); + // Initialize with a placeholder expression (true) that will be updated when the hash table is built + Arc::new(DynamicFilterPhysicalExpr::new(right_keys, lit(true))) + } + + /// left (build) side which gets hashed + pub fn left(&self) -> &Arc { + &self.left + } + + /// right (probe) side which are filtered by the hash table + pub fn right(&self) -> &Arc { + &self.right + } + + /// Set of common columns used to join on + pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] { + &self.on + } + + /// Filters applied before join output + pub fn filter(&self) -> Option<&JoinFilter> { + self.filter.as_ref() + } + + /// How the join is performed + pub fn join_type(&self) -> &JoinType { + &self.join_type + } + + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. 
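+    ///
+    /// For example, with `projection = Some(vec![1])` the plan's `schema()` exposes
+    /// only that single projected column, while `join_schema()` still returns the
+    /// full, unprojected join schema.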
+ pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + + /// Get null_equality + pub fn null_equality(&self) -> NullEquality { + self.null_equality + } + + /// Calculate order preservation flags for this hash join. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi + | JoinType::RightMark + ), + ] + } + + /// Get probe side information for the hash join. + pub fn probe_side() -> JoinSide { + // In current implementation right side is always probe side. + JoinSide::Right + } + + /// Return whether the join contains a projection + pub fn contains_projection(&self) -> bool { + self.projection.is_some() + } + + /// Return new instance of [HashJoinExec] with the given projection. + pub fn with_projection(&self, projection: Option>) -> Result { + // check if the projection is valid + can_project(&self.schema(), projection.as_ref())?; + let projection = match projection { + Some(projection) => match &self.projection { + Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), + None => Some(projection), + }, + None => None, + }; + Self::try_new( + Arc::clone(&self.left), + Arc::clone(&self.right), + self.on.clone(), + self.filter.clone(), + &self.join_type, + projection, + self.null_equality, + ) + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + left: &Arc, + right: &Arc, + schema: SchemaRef, + join_type: JoinType, + on: JoinOnRef, + projection: Option<&Vec>, + ) -> Result { + // Calculate equivalence properties: + let mut eq_properties = join_equivalence_properties( + left.equivalence_properties().clone(), + right.equivalence_properties().clone(), + &join_type, + Arc::clone(&schema), + &Self::maintains_input_order(join_type), + Some(Self::probe_side()), + on, + )?; + + let mut output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type)?; + let emission_type = if left.boundedness().is_unbounded() { + EmissionType::Final + } else if right.pipeline_behavior() == EmissionType::Incremental { + match join_type { + // If we only need to generate matched rows from the probe side, + // we can emit rows incrementally. + JoinType::Inner + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightMark => EmissionType::Incremental, + // If we need to generate unmatched rows from the *build side*, + // we need to emit them at the end. + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::Full => EmissionType::Both, + } + } else { + right.pipeline_behavior() + }; + + // If contains projection, update the PlanProperties. 
+ if let Some(projection) = projection { + // construct a map from the input expressions to the output expression of the Projection + let projection_mapping = + ProjectionMapping::from_indices(projection, &schema)?; + let out_schema = project_schema(&schema, Some(projection))?; + output_partitioning = + output_partitioning.project(&projection_mapping, &eq_properties); + eq_properties = eq_properties.project(&projection_mapping, out_schema); + } + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + emission_type, + boundedness_from_children([left, right]), + )) + } + + /// Returns a new `ExecutionPlan` that computes the same join as this one, + /// with the left and right inputs swapped using the specified + /// `partition_mode`. + /// + /// # Notes: + /// + /// This function is public so other downstream projects can use it to + /// construct `HashJoinExec` with right side as the build side. + /// + /// For using this interface directly, please refer to below: + /// + /// Hash join execution may require specific input partitioning (for example, + /// the left child may have a single partition while the right child has multiple). + /// + /// Calling this function on join nodes whose children have already been repartitioned + /// (e.g., after a `RepartitionExec` has been inserted) may break the partitioning + /// requirements of the hash join. Therefore, ensure you call this function + /// before inserting any repartitioning operators on the join's children. + /// + /// In DataFusion's default SQL interface, this function is used by the `JoinSelection` + /// physical optimizer rule to determine a good join order, which is + /// executed before the `EnforceDistribution` rule (the rule that may + /// insert `RepartitionExec` operators). 
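+    ///
+    /// A minimal usage sketch (note that, unlike `HashJoinExec::swap_inputs`, the
+    /// `partition_mode` argument is currently ignored here and kept only for
+    /// signature parity):
+    ///
+    ///     let swapped = grace_join.swap_inputs(PartitionMode::PartitionedSpillable)?;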
+ pub fn swap_inputs( + &self, + _partition_mode: PartitionMode, + ) -> Result> { + let left = self.left(); + let right = self.right(); + let new_join = GraceHashJoinExec::try_new( + Arc::clone(right), + Arc::clone(left), + self.on() + .iter() + .map(|(l, r)| (Arc::clone(r), Arc::clone(l))) + .collect(), + self.filter().map(JoinFilter::swap), + &self.join_type().swap(), + swap_join_projection( + left.schema().fields().len(), + right.schema().fields().len(), + self.projection.as_ref(), + self.join_type(), + ), + self.null_equality(), + )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again + if matches!( + self.join_type(), + JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftAnti + | JoinType::RightAnti + ) || self.projection.is_some() + { + Ok(Arc::new(new_join)) + } else { + reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema()) + } + } +} + +impl DisplayAs for GraceHashJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + let display_projections = if self.contains_projection() { + format!( + ", projection=[{}]", + self.projection + .as_ref() + .unwrap() + .iter() + .map(|index| format!( + "{}@{}", + self.join_schema.fields().get(*index).unwrap().name(), + index + )) + .collect::>() + .join(", ") + ) + } else { + "".to_string() + }; + let on = self + .on + .iter() + .map(|(c1, c2)| format!("({c1}, {c2})")) + .collect::>() + .join(", "); + write!( + f, + "GraceHashJoinExec: join_type={:?}, on=[{}]{}{}", + self.join_type, on, display_filter, display_projections, + ) + } + DisplayFormatType::TreeRender => { + let on = self + .on + .iter() + .map(|(c1, c2)| { + format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref())) + }) + .collect::>() + .join(", "); + + if *self.join_type() != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + + writeln!(f, "on={on}")?; + + if let Some(filter) = self.filter.as_ref() { + writeln!(f, "filter={filter}")?; + } + + Ok(()) + } + } + } +} + +impl ExecutionPlan for GraceHashJoinExec { + fn name(&self) -> &'static str { + "GraceHashJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn required_input_distribution(&self) -> Vec { + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + vec![ + Distribution::HashPartitioned(left_expr), + Distribution::HashPartitioned(right_expr), + ] + } + + // For [JoinType::Inner] and [JoinType::RightSemi] in hash joins, the probe phase initiates by + // applying the hash function to convert the join key(s) in each row into a hash value from the + // probe side table in the order they're arranged. The hash value is used to look up corresponding + // entries in the hash table that was constructed from the build side table during the build phase. + // + // Because of the immediate generation of result rows once a match is found, + // the output of the join tends to follow the order in which the rows were read from + // the probe side table. This is simply due to the sequence in which the rows were processed. + // Hence, it appears that the hash join is preserving the order of the probe side. 
+ // + // Meanwhile, in the case of a [JoinType::RightAnti] hash join, + // the unmatched rows from the probe side are also kept in order. + // This is because the **`RightAnti`** join is designed to return rows from the right + // (probe side) table that have no match in the left (build side) table. Because the rows + // are processed sequentially in the probe phase, and unmatched rows are directly output + // as results, these results tend to retain the order of the probe side table. + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + /// Creates a new HashJoinExec with different children while preserving configuration. + /// + /// This method is called during query optimization when the optimizer creates new + /// plan nodes. Importantly, it creates a fresh bounds_accumulator via `try_new` + /// rather than cloning the existing one because partitioning may have changed. + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(GraceHashJoinExec { + left: Arc::clone(&children[0]), + right: Arc::clone(&children[1]), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + random_state: self.random_state.clone(), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: Self::compute_properties( + &children[0], + &children[1], + Arc::clone(&self.join_schema), + self.join_type, + &self.on, + self.projection.as_ref(), + )?, + // Preserve dynamic filter enablement; state will be refreshed as needed + dynamic_filter_enabled: self.dynamic_filter_enabled, + })) + } + + fn reset_state(self: Arc) -> Result> { + Ok(Arc::new(GraceHashJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + random_state: self.random_state.clone(), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: self.cache.clone(), + // Reset dynamic filter state to initial configuration + dynamic_filter_enabled: false, + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let left_partitions = self.left.output_partitioning().partition_count(); + let right_partitions = self.right.output_partitioning().partition_count(); + + if left_partitions != right_partitions { + return internal_err!( + "Invalid GraceHashJoinExec, partition count mismatch {left_partitions}!={right_partitions},\ + consider using RepartitionExec" + ); + } + + let enable_dynamic_filter_pushdown = self.dynamic_filter_enabled; + + let join_metrics = Arc::new(BuildProbeJoinMetrics::new(partition, &self.metrics)); + + let left = self.left.execute(partition, Arc::clone(&context))?; + let left_schema = Arc::clone(&self.left.schema()); + let on_left = self + .on + .iter() + .map(|(left_expr, _)| Arc::clone(left_expr)) + .collect::>(); + + let right = self.right.execute(partition, Arc::clone(&context))?; + let right_schema = Arc::clone(&self.right.schema()); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + + let spill_left = Arc::new(SpillManager::new( + 
Arc::clone(&context.runtime_env()), + SpillMetrics::new(&self.metrics, partition), + Arc::clone(&left_schema), + )); + let spill_right = Arc::new(SpillManager::new( + Arc::clone(&context.runtime_env()), + SpillMetrics::new(&self.metrics, partition), + Arc::clone(&right_schema), + )); + + // update column indices to reflect the projection + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + + let random_state = self.random_state.clone(); + let on = self.on.clone(); + let spill_left_clone = Arc::clone(&spill_left); + let spill_right_clone = Arc::clone(&spill_right); + let join_metrics_clone = Arc::clone(&join_metrics); + let spill_fut = OnceFut::new(async move { + let (left_idx, right_idx) = partition_and_spill( + random_state, + on, + left, + right, + join_metrics_clone, + enable_dynamic_filter_pushdown, + left_partitions, + spill_left_clone, + spill_right_clone, + partition, + ) + .await?; + Ok(SpillFut::new(partition, left_idx, right_idx)) + }); + + let left_input_schema = self.left.schema(); + let right_input_schema = self.right.schema(); + + Ok(Box::pin(GraceHashJoinStream::new( + self.schema(), + left_input_schema, + right_input_schema, + spill_fut, + spill_left, + spill_right, + on_left, + on_right, + self.projection.clone(), + self.filter.clone(), + self.join_type, + column_indices_after_projection, + join_metrics, + context, + ))) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + self.partition_statistics(None) + } + + fn partition_statistics(&self, partition: Option) -> Result { + if partition.is_some() { + return Ok(Statistics::new_unknown(&self.schema())); + } + let stats = estimate_join_statistics( + self.left.partition_statistics(None)?, + self.right.partition_statistics(None)?, + self.on.clone(), + &self.join_type, + &self.join_schema, + )?; + // Project statistics if there is a projection + Ok(stats.project(self.projection.as_ref())) + } + + /// Tries to push `projection` down through `hash_join`. If possible, performs the + /// pushdown and returns a new [`HashJoinExec`] as the top plan which has projections + /// as its children. Otherwise, returns `None`. + fn try_swapping_with_projection( + &self, + projection: &ProjectionExec, + ) -> Result>> { + // TODO: currently if there is projection in GraceHashJoinExec, we can't push down projection to left or right input. Maybe we can pushdown the mixed projection later. + if self.contains_projection() { + return Ok(None); + } + + if let Some(JoinData { + projected_left_child, + projected_right_child, + join_filter, + join_on, + }) = try_pushdown_through_join( + projection, + self.left(), + self.right(), + self.on(), + self.schema(), + self.filter(), + )? { + Ok(Some(Arc::new(GraceHashJoinExec::try_new( + Arc::new(projected_left_child), + Arc::new(projected_right_child), + join_on, + join_filter, + self.join_type(), + // Returned early if projection is not None + None, + self.null_equality, + )?))) + } else { + try_embed_projection(projection, self) + } + } + + fn gather_filters_for_pushdown( + &self, + phase: FilterPushdownPhase, + parent_filters: Vec>, + config: &ConfigOptions, + ) -> Result { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. 
+ // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs + // See https://github.com/apache/datafusion/issues/16973 for tracking. + if self.join_type != JoinType::Inner { + return Ok(FilterDescription::all_unsupported( + &parent_filters, + &self.children(), + )); + } + + // Get basic filter descriptions for both children + let left_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.left(), + )?; + let mut right_child = crate::filter_pushdown::ChildFilterDescription::from_child( + &parent_filters, + self.right(), + )?; + + // Add dynamic filters in Post phase if enabled + if matches!(phase, FilterPushdownPhase::Post) + && config.optimizer.enable_dynamic_filter_pushdown + { + // Add actual dynamic filter to right side (probe side) + let dynamic_filter = Self::create_dynamic_filter(&self.on); + right_child = right_child.with_self_filter(dynamic_filter); + } + + Ok(FilterDescription::new() + .with_child(left_child) + .with_child(right_child)) + } + + fn handle_child_pushdown_result( + &self, + _phase: FilterPushdownPhase, + child_pushdown_result: ChildPushdownResult, + _config: &ConfigOptions, + ) -> Result>> { + // Note: this check shouldn't be necessary because we already marked all parent filters as unsupported for + // non-inner joins in `gather_filters_for_pushdown`. + // However it's a cheap check and serves to inform future devs touching this function that they need to be really + // careful pushing down filters through non-inner joins. + if self.join_type != JoinType::Inner { + // Other types of joins can support *some* filters, but restrictions are complex and error prone. + // For now we don't support them. + // See the logical optimizer rules for more details: datafusion/optimizer/src/push_down_filter.rs + return Ok(FilterPushdownPropagation::all_unsupported( + child_pushdown_result, + )); + } + + let mut result = FilterPushdownPropagation::if_any(child_pushdown_result.clone()); + assert_eq!(child_pushdown_result.self_filters.len(), 2); // Should always be 2, we have 2 children + let right_child_self_filters = &child_pushdown_result.self_filters[1]; // We only push down filters to the right child + // We expect 0 or 1 self filters + if let Some(filter) = right_child_self_filters.first() { + // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said + // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating + let predicate = Arc::clone(&filter.predicate); + if Arc::downcast::(predicate).is_ok() { + // We successfully pushed down our self filter - we need to make a new node with the dynamic filter + let new_node = Arc::new(GraceHashJoinExec { + left: Arc::clone(&self.left), + right: Arc::clone(&self.right), + on: self.on.clone(), + filter: self.filter.clone(), + join_type: self.join_type, + join_schema: Arc::clone(&self.join_schema), + random_state: self.random_state.clone(), + metrics: ExecutionPlanMetricsSet::new(), + projection: self.projection.clone(), + column_indices: self.column_indices.clone(), + null_equality: self.null_equality, + cache: self.cache.clone(), + dynamic_filter_enabled: true, + }); + result = result.with_updated_node(new_node as Arc); + } + } + Ok(result) + } +} + +#[allow(clippy::too_many_arguments)] +pub async fn partition_and_spill( + random_state: RandomState, + on: Vec<(PhysicalExprRef, PhysicalExprRef)>, + mut left_stream: SendableRecordBatchStream, + mut right_stream: 
SendableRecordBatchStream, + join_metrics: Arc, + enable_dynamic_filter_pushdown: bool, + partition_count: usize, + spill_left: Arc, + spill_right: Arc, + partition: usize, +) -> Result<(Vec, Vec)> { + let on_left: Vec<_> = on.iter().map(|(l, _)| Arc::clone(l)).collect(); + let on_right: Vec<_> = on.iter().map(|(_, r)| Arc::clone(r)).collect(); + + // LEFT side partitioning + let left_index = partition_and_spill_one_side( + &mut left_stream, + &on_left, + &random_state, + partition_count, + spill_left, + &format!("left_{partition}"), + &join_metrics, + enable_dynamic_filter_pushdown, + ) + .await?; + + // RIGHT side partitioning + let right_index = partition_and_spill_one_side( + &mut right_stream, + &on_right, + &random_state, + partition_count, + spill_right, + &format!("right_{partition}"), + &join_metrics, + enable_dynamic_filter_pushdown, + ) + .await?; + Ok((left_index, right_index)) +} + +#[allow(clippy::too_many_arguments)] +async fn partition_and_spill_one_side( + input: &mut SendableRecordBatchStream, + on_exprs: &[PhysicalExprRef], + random_state: &RandomState, + partition_count: usize, + spill_manager: Arc, + spilling_request_msg: &str, + join_metrics: &BuildProbeJoinMetrics, + _enable_dynamic_filter_pushdown: bool, +) -> Result> { + let mut partitions: Vec = (0..partition_count) + .map(|_| PartitionWriter::new(Arc::clone(&spill_manager))) + .collect(); + + let mut buffered_batches = Vec::new(); + + let schema = input.schema(); + while let Some(batch) = input.next().await { + let batch = batch?; + if batch.num_rows() == 0 { + continue; + } + join_metrics.build_input_batches.add(1); + join_metrics.build_input_rows.add(batch.num_rows()); + buffered_batches.push(batch); + } + if buffered_batches.is_empty() { + return Ok(Vec::new()); + } + // Create single batch to reduce number of spilled files + let single_batch = concat_batches(&schema, &buffered_batches)?; + let num_rows = single_batch.num_rows(); + if num_rows == 0 { + return Ok(Vec::new()); + } + + // Calculate hashes + let keys = on_exprs + .iter() + .map(|c| c.evaluate(&single_batch)?.into_array(num_rows)) + .collect::>>()?; + + let mut hashes = vec![0u64; num_rows]; + create_hashes(&keys, random_state, &mut hashes)?; + + // Spread to partitions + let mut indices: Vec> = vec![Vec::new(); partition_count]; + for (row, h) in hashes.iter().enumerate() { + let bucket = (*h as usize) % partition_count; + indices[bucket].push(row as u32); + } + + // Collect and spill + for (i, idxs) in indices.into_iter().enumerate() { + if idxs.is_empty() { + continue; + } + + let idx_array = UInt32Array::from(idxs); + let taken = single_batch + .columns() + .iter() + .map(|c| take(c.as_ref(), &idx_array, None)) + .collect::>>()?; + + let part_batch = RecordBatch::try_new(single_batch.schema(), taken)?; + // We need unique name for spilling + let request_msg = format!("grace_partition_{spilling_request_msg}_{i}"); + partitions[i].spill_batch_auto(&part_batch, &request_msg)?; + } + + // Prepare indexes + let mut result = Vec::with_capacity(partitions.len()); + for writer in partitions.into_iter() { + result.push(writer.finish()?); + } + // println!("spill_manager {:?}", spill_manager.metrics); + Ok(result) +} + +#[derive(Debug)] +pub struct PartitionWriter { + spill_manager: Arc, + chunks: Vec, +} + +impl PartitionWriter { + pub fn new(spill_manager: Arc) -> Self { + Self { + spill_manager, + chunks: vec![], + } + } + + pub fn spill_batch_auto( + &mut self, + batch: &RecordBatch, + request_msg: &str, + ) -> Result<()> { + let loc = 
self.spill_manager.spill_batch_auto(batch, request_msg)?; + self.chunks.push(loc); + Ok(()) + } + + pub fn finish(self) -> Result { + Ok(PartitionIndex { + chunks: self.chunks, + }) + } +} + +/// Describes a single partition of spilled data (used in GraceHashJoin). +/// +/// Each partition can consist of one or multiple chunks (batches) +/// that were spilled either to memory or to disk. +/// These chunks are later reloaded during the join phase. +/// +/// Example: +/// Partition 3 -> [ spill_chunk_3_0.arrow, spill_chunk_3_1.arrow ] +#[derive(Debug, Clone)] +pub struct PartitionIndex { + /// Collection of spill locations (each corresponds to one batch written + /// by [`PartitionWriter::spill_batch_auto`]) + pub chunks: Vec, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::TestMemoryExec; + use crate::{common, expressions::Column, repartition::RepartitionExec}; + + use crate::joins::HashJoinExec; + use arrow::array::{ArrayRef, Int32Array}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_physical_expr::Partitioning; + use futures::future; + + fn build_large_table( + a_name: &str, + b_name: &str, + c_name: &str, + n: usize, + ) -> Arc { + let a: ArrayRef = Arc::new(Int32Array::from_iter_values(1..=n as i32)); + let b: ArrayRef = + Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 2))); + let c: ArrayRef = + Arc::new(Int32Array::from_iter_values((1..=n as i32).map(|x| x * 10))); + + let schema = Arc::new(arrow::datatypes::Schema::new(vec![ + arrow::datatypes::Field::new( + a_name, + arrow::datatypes::DataType::Int32, + false, + ), + arrow::datatypes::Field::new( + b_name, + arrow::datatypes::DataType::Int32, + false, + ), + arrow::datatypes::Field::new( + c_name, + arrow::datatypes::DataType::Int32, + false, + ), + ])); + + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap(); + Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + #[tokio::test] + async fn simple_grace_hash_join() -> Result<()> { + let left = build_large_table("a1", "b1", "c1", 2000000); + let right = build_large_table("a2", "b2", "c2", 5000000); + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let (left_expr, right_expr): ( + Vec>, + Vec>, + ) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + let left_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(left_expr, 32))?, + ); + let right_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(right_expr, 32))?, + ); + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500_000_000, 1.0) + .build_arc()?; + let task_ctx = TaskContext::default().with_runtime(runtime); + let task_ctx = Arc::new(task_ctx); + + let join = GraceHashJoinExec::try_new( + Arc::clone(&left_repartitioned), + Arc::clone(&right_repartitioned), + on.clone(), + None, + &JoinType::Inner, + None, + NullEquality::NullEqualsNothing, + )?; + + let partition_count = right_repartitioned.output_partitioning().partition_count(); + let tasks: Vec<_> = (0..partition_count) + .map(|i| { + let ctx = Arc::clone(&task_ctx); + let s = join.execute(i, ctx).unwrap(); + async move { common::collect(s).await } + }) + .collect(); + + let results = future::join_all(tasks).await; + let mut batches = Vec::new(); + for r in results { + let mut v = r?; + v.retain(|b| b.num_rows() > 0); + batches.extend(v); + } + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 1_000_000); + + // print_batches(&*batches).unwrap(); + // Asserting that operator-level reservation attempting to overallocate + // assert_contains!( + // err.to_string(), + // "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:\n HashJoinInput" + // ); + // + // assert_contains!( + // err.to_string(), + // "Failed to allocate additional 120.0 B for HashJoinInput" + // ); + Ok(()) + } + + #[tokio::test] + async fn simple_hash_join() -> Result<()> { + let left = build_large_table("a1", "b1", "c1", 2000000); + let right = build_large_table("a2", "b2", "c2", 5000000); + let on = vec![( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let (left_expr, right_expr): ( + Vec>, + Vec>, + ) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); + let left_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(left, Partitioning::Hash(left_expr, 32))?, + ); + let right_repartitioned: Arc = Arc::new( + RepartitionExec::try_new(right, Partitioning::Hash(right_expr, 32))?, + ); + let partition_count = left_repartitioned.output_partitioning().partition_count(); + + let join = HashJoinExec::try_new( + left_repartitioned, + right_repartitioned, + on, + None, + &JoinType::Inner, + None, + PartitionMode::Partitioned, + NullEquality::NullEqualsNothing, + )?; + + let task_ctx = Arc::new(TaskContext::default()); + let mut batches = vec![]; + for i in 0..partition_count { + let stream = join.execute(i, Arc::clone(&task_ctx))?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 1_000_000); + + // print_batches(&*batches).unwrap(); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs b/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs new file mode 100644 index 000000000000..55d7e2035e6c --- /dev/null +++ b/datafusion/physical-plan/src/joins/grace_hash_join/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`GraceHashJoinExec`] Partitioned Hash Join Operator + +pub use exec::GraceHashJoinExec; + +mod exec; +mod stream; diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs new file mode 100644 index 000000000000..da0de10a5188 --- /dev/null +++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Stream implementation for Hash Join +//! +//! 
This module implements [`HashJoinStream`], the streaming engine for +//! [`super::HashJoinExec`]. See comments in [`HashJoinStream`] for more details. + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::joins::utils::OnceFut; +use crate::{ + joins::utils::{BuildProbeJoinMetrics, ColumnIndex, JoinFilter}, + ExecutionPlan, RecordBatchStream, SendableRecordBatchStream, SpillManager, +}; + +use crate::empty::EmptyExec; +use crate::joins::grace_hash_join::exec::PartitionIndex; +use crate::joins::{HashJoinExec, PartitionMode}; +use crate::test::TestMemoryExec; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::{JoinType, NullEquality, Result}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::PhysicalExprRef; +use futures::{ready, Stream, StreamExt}; + +enum GraceJoinState { + /// Waiting for the partitioning phase (Phase 1) to finish + WaitPartitioning, + + /// Currently joining partition `current` + JoinPartition { + current: usize, + all_parts: Arc>, + current_stream: Option, + left_fut: Option>>, + right_fut: Option>>, + }, + + Done, +} + +pub struct GraceHashJoinStream { + schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, + spill_fut: OnceFut, + spill_left: Arc, + spill_right: Arc, + on_left: Vec, + on_right: Vec, + projection: Option>, + filter: Option, + join_type: JoinType, + column_indices: Vec, + join_metrics: Arc, + context: Arc, + state: GraceJoinState, +} + +#[derive(Debug, Clone)] +pub struct SpillFut { + left: Vec, + right: Vec, +} +impl SpillFut { + pub(crate) fn new( + _partition: usize, + left: Vec, + right: Vec, + ) -> Self { + SpillFut { left, right } + } +} + +impl RecordBatchStream for GraceHashJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl GraceHashJoinStream { + pub fn new( + schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, + spill_fut: OnceFut, + spill_left: Arc, + spill_right: Arc, + on_left: Vec, + on_right: Vec, + projection: Option>, + filter: Option, + join_type: JoinType, + column_indices: Vec, + join_metrics: Arc, + context: Arc, + ) -> Self { + Self { + schema, + left_input_schema, + right_input_schema, + spill_fut, + spill_left, + spill_right, + on_left, + on_right, + projection, + filter, + join_type, + column_indices, + join_metrics, + context, + state: GraceJoinState::WaitPartitioning, + } + } + + /// Core state machine logic (poll implementation) + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match &mut self.state { + GraceJoinState::WaitPartitioning => { + let shared = ready!(self.spill_fut.get_shared(cx))?; + let parts = Arc::new(vec![(*shared).clone()]); + self.state = GraceJoinState::JoinPartition { + current: 0, + all_parts: parts, + current_stream: None, + left_fut: None, + right_fut: None, + }; + continue; + } + GraceJoinState::JoinPartition { + current, + all_parts, + current_stream, + left_fut, + right_fut, + } => { + if *current >= all_parts.len() { + self.state = GraceJoinState::Done; + continue; + } + + // If we don't have a stream yet, create one for the current partition pair + if current_stream.is_none() { + if left_fut.is_none() && right_fut.is_none() { + let spill_fut = &all_parts[*current]; + *left_fut = Some(load_partition_async( + Arc::clone(&self.spill_left), + spill_fut.left.clone(), + )); + *right_fut = Some(load_partition_async( + Arc::clone(&self.spill_right), + 
spill_fut.right.clone(), + )); + } + + let left_batches = + (*ready!(left_fut.as_mut().unwrap().get_shared(cx))?).clone(); + let right_batches = + (*ready!(right_fut.as_mut().unwrap().get_shared(cx))?) + .clone(); + + let stream = build_in_memory_join_stream( + Arc::clone(&self.schema), + Arc::clone(&self.left_input_schema), + Arc::clone(&self.right_input_schema), + left_batches, + right_batches, + &self.on_left, + &self.on_right, + self.projection.clone(), + self.filter.clone(), + self.join_type, + &self.column_indices, + &self.join_metrics, + &self.context, + )?; + + *current_stream = Some(stream); + *left_fut = None; + *right_fut = None; + } + + // Drive current stream forward + if let Some(stream) = current_stream { + match ready!(stream.poll_next_unpin(cx)) { + Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => { + *current += 1; + *current_stream = None; + continue; + } + } + } + } + GraceJoinState::Done => return Poll::Ready(None), + } + } + } +} + +fn load_partition_async( + spill_manager: Arc, + partitions: Vec, +) -> OnceFut> { + OnceFut::new(async move { + let mut all_batches = Vec::new(); + + for p in partitions { + for chunk in p.chunks { + let mut reader = spill_manager.load_spilled_batch(&chunk)?; + while let Some(batch_result) = reader.next().await { + let batch = batch_result?; + all_batches.push(batch); + } + } + } + Ok(all_batches) + }) +} + +/// Build an in-memory HashJoinExec for one pair of spilled partitions +fn build_in_memory_join_stream( + output_schema: SchemaRef, + left_input_schema: SchemaRef, + right_input_schema: SchemaRef, + left_batches: Vec, + right_batches: Vec, + on_left: &[PhysicalExprRef], + on_right: &[PhysicalExprRef], + projection: Option>, + filter: Option, + join_type: JoinType, + _column_indices: &[ColumnIndex], + _join_metrics: &BuildProbeJoinMetrics, + context: &Arc, +) -> Result { + if left_batches.is_empty() && right_batches.is_empty() { + return EmptyExec::new(output_schema).execute(0, Arc::clone(context)); + } + + // Build memory execution nodes for each side + let left_plan: Arc = Arc::new(TestMemoryExec::try_new( + &[left_batches], + left_input_schema, + None, + )?); + let right_plan: Arc = Arc::new(TestMemoryExec::try_new( + &[right_batches], + right_input_schema, + None, + )?); + + // Combine join expressions into pairs + let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = on_left + .iter() + .cloned() + .zip(on_right.iter().cloned()) + .collect(); + + // For one partition pair: always CollectLeft (build left, stream right) + let join_exec = HashJoinExec::try_new( + left_plan, + right_plan, + on, + filter, + &join_type, + projection, + PartitionMode::CollectLeft, + NullEquality::NullEqualsNothing, + )?; + + // Each join executes locally with the same context + join_exec.execute(0, Arc::clone(context)) +} + +impl Stream for GraceHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index cb697d460995..c6eaa3489cba 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -21,6 +21,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; +use crate::coalesce_partitions::CoalescePartitionsExec; use 
crate::execution_plan::{boundedness_from_children, EmissionType}; use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, @@ -49,7 +50,7 @@ use crate::{ need_produce_result_in_final, symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMapType, }, - metrics::{ExecutionPlanMetricsSet, MetricsSet}, + metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics}, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, }; @@ -84,12 +85,25 @@ use parking_lot::Mutex; const HASH_JOIN_SEED: RandomState = RandomState::with_seeds('J' as u64, 'O' as u64, 'I' as u64, 'N' as u64); +/// Maximum number of partitions allowed when recursively repartitioning during hybrid hash join. +const HYBRID_HASH_MAX_PARTITIONS: usize = 1 << 16; +/// Upper bound multiplier applied to the initial partition fanout when searching for additional partitions. +const HYBRID_HASH_PARTITION_GROWTH_FACTOR: usize = 16; +/// Approximate number of probe batches worth of rows we target per partition when statistics are available. +const HYBRID_HASH_ROWS_PER_PARTITION_BATCH_MULTIPLIER: usize = 8; +/// Minimum number of bytes we aim to keep in memory per partition when deriving the initial fanout. +const HYBRID_HASH_MIN_BYTES_PER_PARTITION: usize = 8 * 1024 * 1024; +/// Minimum number of rows per partition when statistics are available to avoid extreme fan-out. +const HYBRID_HASH_MIN_ROWS_PER_PARTITION: usize = 1_024; + /// HashTable and input data for the left (build side) of a join pub(super) struct JoinLeftData { /// The hash table with indices into `batch` pub(super) hash_map: Box, /// The input rows for the build side batch: RecordBatch, + /// Original build-side batches before concatenation + original_batches: Arc>, /// The build side on expressions values values: Vec, /// Shared bitmap builder for visited left indices @@ -111,6 +125,7 @@ impl JoinLeftData { pub(super) fn new( hash_map: Box, batch: RecordBatch, + original_batches: Arc>, values: Vec, visited_indices_bitmap: SharedBitmapBuilder, probe_threads_counter: AtomicUsize, @@ -120,6 +135,7 @@ impl JoinLeftData { Self { hash_map, batch, + original_batches, values, visited_indices_bitmap, probe_threads_counter, @@ -138,6 +154,10 @@ impl JoinLeftData { &self.batch } + pub(super) fn original_batches(&self) -> &[RecordBatch] { + &self.original_batches + } + /// returns a reference to the build side expressions values pub(super) fn values(&self) -> &[ArrayRef] { &self.values @@ -571,13 +591,19 @@ impl HashJoinExec { mode: PartitionMode, projection: Option<&Vec>, ) -> Result { - // Calculate equivalence properties: + // Calculate equivalence properties. For the spillable path, do not claim + // any input order preservation to avoid incorrect planner assumptions + // (e.g., SortPreservingMerge) when the operator may perturb order. + let maintains = match mode { + PartitionMode::PartitionedSpillable => vec![false, false], + _ => Self::maintains_input_order(join_type), + }; let mut eq_properties = join_equivalence_properties( left.equivalence_properties().clone(), right.equivalence_properties().clone(), &join_type, Arc::clone(&schema), - &Self::maintains_input_order(join_type), + &maintains, Some(Self::probe_side()), on, )?; @@ -592,6 +618,11 @@ impl HashJoinExec { PartitionMode::Partitioned => { symmetric_join_output_partitioning(left, right, &join_type)? 
} + PartitionMode::PartitionedSpillable => { + // Report output partitions consistent with the right side to enable + // proper upstream planning (e.g., repartitioning and aggregations) + Partitioning::UnknownPartitioning(1) + } }; let emission_type = if left.boundedness().is_unbounded() { @@ -797,6 +828,11 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, Distribution::UnspecifiedDistribution, ], + PartitionMode::PartitionedSpillable => vec![ + // While stabilizing, do not require specific input distributions + Distribution::UnspecifiedDistribution, + Distribution::UnspecifiedDistribution, + ], } } @@ -888,6 +924,7 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { + //println!("Executing HashJoinExec"); let on_left = self .on .iter() @@ -956,6 +993,319 @@ impl ExecutionPlan for HashJoinExec { PartitionMode::Auto ); } + PartitionMode::PartitionedSpillable => { + let enable_spillable = context + .session_config() + .options() + .optimizer + .enable_spillable_hash_join; + + if !enable_spillable { + // Legacy fallback: behave like Partitioned + let left_stream = + self.left.execute(partition, Arc::clone(&context))?; + let reservation = + MemoryConsumer::new(format!("HashJoinInput[{partition}]")) + .register(context.memory_pool()); + OnceFut::new(collect_left_input( + self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + 1, + enable_dynamic_filter_pushdown, + )) + } else { + // Spillable enabled: coalesce left to a single stream + let left_plan: Arc = + if self.left.output_partitioning().partition_count() == 1 { + Arc::clone(&self.left) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.left))) + }; + let build_schema = left_plan.schema(); + let left_stream = left_plan.execute(0, Arc::clone(&context))?; + let reservation = MemoryConsumer::new("HashJoinInput") + .register(context.memory_pool()); + let left_fut = self.left_fut.try_once(|| { + Ok(collect_left_input( + self.random_state.clone(), + left_stream, + on_left.clone(), + join_metrics.clone(), + reservation, + need_produce_result_in_final(self.join_type), + self.right().output_partitioning().partition_count(), + enable_dynamic_filter_pushdown, + )) + })?; + + let make_bounds_accumulator = + |right_plan: &Arc| { + if enable_dynamic_filter_pushdown { + self.dynamic_filter.as_ref().map(|df| { + let filter = Arc::clone(&df.filter); + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + Arc::clone(df.bounds_accumulator.get_or_init(|| { + Arc::new( + SharedBoundsAccumulator::new_from_partition_mode( + self.mode, + left_plan.as_ref(), + right_plan.as_ref(), + filter, + on_right, + ), + ) + })) + }) + } else { + None + } + }; + + // For Right-side oriented joins, fall back to standard HashJoinStream for correctness + if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { + let right_plan: Arc = + if self.right.output_partitioning().partition_count() == 1 { + Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone( + &self.right, + ))) + }; + let right_stream = right_plan.execute(0, Arc::clone(&context))?; + let shared_bounds_accumulator = + make_bounds_accumulator(&right_plan); + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + 
None => self.column_indices.clone(), + }; + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + context.session_config().batch_size(), + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + use crate::joins::hash_join::partitioned::PartitionedHashJoinStream; + let right_plan: Arc = + if self.right.output_partitioning().partition_count() == 1 { + Arc::clone(&self.right) + } else { + Arc::new(CoalescePartitionsExec::new(Arc::clone(&self.right))) + }; + let right_stream = right_plan.execute(0, Arc::clone(&context))?; + let shared_bounds_accumulator = make_bounds_accumulator(&right_plan); + let probe_schema = right_plan.schema(); + let column_indices_after_projection = match &self.projection { + Some(projection) => projection + .iter() + .map(|i| self.column_indices[*i].clone()) + .collect(), + None => self.column_indices.clone(), + }; + let on_right = self + .on + .iter() + .map(|(_, right_expr)| Arc::clone(right_expr)) + .collect::>(); + let session_config = context.session_config(); + let batch_size = session_config.batch_size(); + let memory_threshold = { + let bytes = session_config + .options() + .execution + .sort_spill_reservation_bytes; + if bytes == 0 { + 1024 * 1024 * 1024 + } else { + bytes + } + }; + let existing_partitions = std::cmp::max( + 1, + right_plan.output_partitioning().partition_count(), + ); + let target_partitions = std::cmp::max( + existing_partitions, + session_config.target_partitions(), + ); + let mut num_partitions = target_partitions; + let mut bytes_per_partition = memory_threshold / 2; + bytes_per_partition = bytes_per_partition + .max(HYBRID_HASH_MIN_BYTES_PER_PARTITION) + .min(memory_threshold.max(HYBRID_HASH_MIN_BYTES_PER_PARTITION)); + + let initial_cap = target_partitions + .saturating_mul(HYBRID_HASH_PARTITION_GROWTH_FACTOR) + .min(HYBRID_HASH_MAX_PARTITIONS); + + let mut build_size_bytes: Option = None; + if let Ok(left_stats) = self.left.partition_statistics(None) { + if let Some(total_bytes) = left_stats.total_byte_size.get_value() + { + let total_bytes = *total_bytes; + build_size_bytes = Some(total_bytes); + if total_bytes <= memory_threshold { + num_partitions = 1; + } else if bytes_per_partition > 0 { + let required = total_bytes + .saturating_add(bytes_per_partition - 1) + / bytes_per_partition; + if required > 0 { + num_partitions = + std::cmp::max(num_partitions, required); + } + } + } + + if let Some(num_rows) = left_stats.num_rows.get_value() { + let num_rows = *num_rows; + let min_rows = session_config + .batch_size() + .saturating_mul( + HYBRID_HASH_ROWS_PER_PARTITION_BATCH_MULTIPLIER, + ) + .max(HYBRID_HASH_MIN_ROWS_PER_PARTITION); + if num_partitions > 1 && min_rows > 0 && num_rows > min_rows { + let required = + num_rows.saturating_add(min_rows - 1) / min_rows; + if required > 0 { + num_partitions = + std::cmp::max(num_partitions, required); + } + } + } + } + + if build_size_bytes + .map(|b| b <= memory_threshold) + .unwrap_or(false) + { + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + 
join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + batch_size, + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + num_partitions = num_partitions + .min(initial_cap) + .clamp(1, HYBRID_HASH_MAX_PARTITIONS); + if num_partitions > 1 && !num_partitions.is_power_of_two() { + num_partitions = num_partitions + .checked_next_power_of_two() + .unwrap_or(num_partitions) + .max(1); + } + + if num_partitions == 1 { + return Ok(Box::pin(HashJoinStream::new( + partition, + self.schema(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + self.random_state.clone(), + join_metrics, + column_indices_after_projection, + self.null_equality, + HashJoinStreamState::WaitBuildSide, + BuildSide::Initial(BuildSideInitialState { left_fut }), + batch_size, + vec![], + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + ))); + } + + let mut max_partition_count = if num_partitions == 1 { + 1 + } else { + num_partitions + .saturating_mul(HYBRID_HASH_PARTITION_GROWTH_FACTOR) + .min(HYBRID_HASH_MAX_PARTITIONS) + }; + max_partition_count = + max_partition_count.max(initial_cap).max(num_partitions); + let partitioned_reservation = + MemoryConsumer::new("PartitionedHashJoin") + .register(context.memory_pool()); + let probe_spill_metrics = SpillMetrics::new(&self.metrics, partition); + let build_spill_metrics = SpillMetrics::new(&self.metrics, partition); + let partitioned_stream = PartitionedHashJoinStream::new( + partition, + self.schema(), + on_left.clone(), + on_right, + self.filter.clone(), + self.join_type, + right_stream, + left_fut, + self.random_state.clone(), + join_metrics, + probe_spill_metrics, + build_spill_metrics, + column_indices_after_projection, + self.null_equality, + batch_size, + num_partitions, + max_partition_count, + memory_threshold, + partitioned_reservation, + context.runtime_env(), + build_schema, + probe_schema, + self.right.output_ordering().is_some(), + shared_bounds_accumulator, + )?; + return Ok(Box::pin(partitioned_stream)); + } + } }; let batch_size = context.session_config().batch_size(); @@ -986,7 +1336,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. - let right_stream = self.right.execute(partition, context)?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; // update column indices to reflect the projection let column_indices_after_projection = match &self.projection { @@ -1391,6 +1741,7 @@ async fn collect_left_input( mut reservation, bounds_accumulators, } = state; + let batches_arc = Arc::new(batches); // Estimation of memory size, required for hashtable, prior to allocation. 
// Final result can be verified using `RawTable.allocation_info()` @@ -1417,7 +1768,7 @@ async fn collect_left_input( let mut offset = 0; // Updating hashmap starting from the last batch - let batches_iter = batches.iter().rev(); + let batches_iter = batches_arc.iter().rev(); for batch in batches_iter.clone() { hashes_buffer.clear(); hashes_buffer.resize(batch.num_rows(), 0); @@ -1472,6 +1823,7 @@ async fn collect_left_input( let data = JoinLeftData::new( hashmap, single_batch, + Arc::clone(&batches_arc), left_values.clone(), Mutex::new(visited_indices_bitmap), AtomicUsize::new(probe_threads_count), @@ -1638,6 +1990,10 @@ mod tests { PartitionMode::Auto => { return internal_err!("Unexpected PartitionMode::Auto in join tests") } + PartitionMode::PartitionedSpillable => Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(left_expr, partition_count), + )?), }; let right_repartitioned: Arc = match partition_mode { @@ -1659,6 +2015,10 @@ mod tests { PartitionMode::Auto => { return internal_err!("Unexpected PartitionMode::Auto in join tests") } + PartitionMode::PartitionedSpillable => Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(right_expr, partition_count), + )?), }; let join = HashJoinExec::try_new( @@ -1788,6 +2148,63 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_spillable_join_inner_one(batch_size: usize) -> Result<()> { + // Configure tiny spill reservation to force spill and 4 partitions + let session_config = SessionConfig::default() + .with_batch_size(batch_size) + .with_target_partitions(4) + .with_sort_spill_reservation_bytes(1) + .with_spill_compression( + datafusion_common::config::SpillCompression::Uncompressed, + ); + let task_ctx = + Arc::new(TaskContext::default().with_session_config(session_config)); + + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + let (columns, batches, metrics) = join_collect_with_partition_mode( + Arc::clone(&left), + Arc::clone(&right), + on, + &JoinType::Inner, + PartitionMode::PartitionedSpillable, + NullEquality::NullEqualsNothing, + task_ctx, + ) + .await?; + + assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); + + let expected = [ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + assert_join_metrics!(metrics, 3); + Ok(()) + } + #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); @@ -4500,4 +4917,77 @@ mod tests { fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() } + + #[tokio::test] + async fn partitioned_spillable_spills_to_disk() -> Result<()> { + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + // Force spilling with very low reservation; single partition correctness path + let session_config = SessionConfig::default() + .with_batch_size(1024) + .with_target_partitions(1) + .with_sort_spill_reservation_bytes(1) + .with_spill_compression( + datafusion_common::config::SpillCompression::Uncompressed, + ); + let runtime = RuntimeEnvBuilder::new().build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + // Build left/right to ensure build side has more than 1 row to trigger spill partitioning + let left = build_table( + ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8]), + ("b1", &vec![1, 1, 1, 1, 1, 1, 1, 1]), + ("c1", &vec![0, 0, 0, 0, 0, 0, 0, 0]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![1, 1, 1, 2]), + ("c2", &vec![0, 0, 0, 0]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _, + )]; + + // Execute with PartitionedSpillable + let join = HashJoinExec::try_new( + Arc::clone(&left), + Arc::clone(&right), + on, + None, + &JoinType::Inner, + None, + PartitionMode::PartitionedSpillable, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, Arc::clone(&task_ctx))?; + // Collect all batches to drive execution and spill + let _ = common::collect(stream).await?; + + // Assert that spilling occurred by inspecting metrics on the operator + let metrics = join.metrics().unwrap(); + // Find any spill metrics in the tree and ensure spilled_rows > 0 + let mut spilled_any = false; + for m in metrics.iter() { + let name = m.value().name(); + let v = m.value().as_usize(); + if (name == "spilled_rows" + || name == "spilled_bytes" + || name == "spill_count") + && v > 0 + { + spilled_any = true; + break; + } + } + assert!( + spilled_any, + "expected spilling to occur in PartitionedSpillable mode" + ); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs index 7f1e5cae13a3..9d70ca3e1ac6 100644 --- a/datafusion/physical-plan/src/joins/hash_join/mod.rs +++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs @@ -20,5 +20,8 @@ pub use exec::HashJoinExec; mod exec; -mod shared_bounds; +mod partitioned; +#[cfg(feature = "hybrid_hash_join_scheduler")] +mod scheduler; +pub mod shared_bounds; mod stream; diff --git a/datafusion/physical-plan/src/joins/hash_join/partitioned.rs b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs new file mode 100644 index 000000000000..a4187c9d355b --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/partitioned.rs @@ -0,0 +1,3647 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Partitioned Hash Join implementation +//! +//! This module implements a partitioned hash join that can handle large datasets +//! by partitioning both build and probe sides into multiple partitions and +//! processing them sequentially. This approach is similar to sort-merge join +//! but uses hash-based partitioning instead of sorting. +//! +//! # State Machine Overview +//! +//! The partitioned hash join follows this state machine pattern: +//! +//! ```text +//! PartitionBuildSide → ProcessPartitions(i) → Done +//! ``` +//! +//! ## PartitionBuildSide State +//! - Partitions build-side data into multiple partitions based on hash values +//! - Keeps one partition resident in memory (partition 0) +//! - Spills other partitions to disk when memory pressure occurs +//! - Uses consistent hashing to ensure same keys go to same partition +//! +//! ## ProcessPartitions State +//! - Processes each partition sequentially +//! - Loads build-side hash map for current partition (from memory or disk) +//! 
- Probes all probe batches for this partition against the hash map +//! - Generates join results and handles unmatched rows for outer joins +//! - Tracks matched rows for proper outer join semantics + +#[cfg(feature = "hybrid_hash_join_scheduler")] +use super::scheduler::{ + HybridTaskScheduler, ProbeDataPoll, ProbePartitionState, ProbeStageTask, + SchedulerTask, TaskPoll, +}; +use crate::joins::hash_join::exec::JoinLeftData; +use crate::joins::join_hash_map::{JoinHashMapType, JoinHashMapU32, JoinHashMapU64}; +use crate::joins::utils::{ + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + equal_rows_arr, get_final_indices_from_bit_map, need_produce_result_in_final, + uint32_to_uint64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceFut, + StatefulStreamResult, +}; +use crate::metrics::SpillMetrics; +use crate::spill::in_progress_spill_file::InProgressSpillFile; +use crate::spill::spill_manager::SpillManager; +use crate::{RecordBatchStream, SendableRecordBatchStream}; +use std::collections::VecDeque; +use std::mem::{self, size_of}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{Array, ArrayRef, BooleanBufferBuilder, UInt32Array, UInt64Array}; +use arrow::compute::{concat_batches, take}; +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use datafusion_common::utils::memory::estimate_memory_size; +use datafusion_common::{ + hash_utils::create_hashes, internal_datafusion_err, internal_err, DataFusionError, + JoinSide, JoinType, NullEquality, Result, +}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_physical_expr::PhysicalExprRef; + +use ahash::RandomState; +use futures::{executor::block_on, ready, Stream, StreamExt}; + +const HYBRID_HASH_MAX_REPARTITION_DEPTH: usize = 6; +const HYBRID_HASH_MIN_FANOUT: usize = 2; +const HYBRID_HASH_MIN_PARTITION_BYTES: usize = 8 * 1024 * 1024; +const HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER: usize = 8; +const HYBRID_HASH_ROWS_PER_PARTITION_MIN: usize = 32 * 1024; + +fn highest_power_of_two_leq(n: usize) -> usize { + if n <= 1 { + 1 + } else { + let mut power = 1usize; + while (power << 1) <= n { + power <<= 1; + } + power + } +} + +fn max_partitions_allowed_for_memory(memory_threshold: usize) -> usize { + let mut slots = memory_threshold + .checked_div(HYBRID_HASH_MIN_PARTITION_BYTES) + .unwrap_or(usize::MAX); + if slots == 0 { + slots = 1; + } + highest_power_of_two_leq(slots) +} + +fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usize { + let partitions = partitions.max(1); + let mut budget = memory_threshold + .checked_div(partitions) + .unwrap_or(memory_threshold); + if budget == 0 { + budget = HYBRID_HASH_MIN_PARTITION_BYTES; + } + budget.max(HYBRID_HASH_MIN_PARTITION_BYTES) +} + +/// State of the partitioned hash join stream +#[derive(Debug, Clone)] +pub(super) enum PartitionedHashJoinState { + /// Initial state - partitioning build side + PartitionBuildSide, + /// Processing a specific partition + ProcessPartition(ProcessPartitionState), + /// Waiting for partitions that are throttled on probe IO to resume + #[cfg(feature = "hybrid_hash_join_scheduler")] + WaitingForProbe, + /// All partitions processed, handling unmatched rows for outer joins + HandleUnmatchedRows, + /// Join completed + Completed, +} + +/// State for processing a specific partition 
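+/// (currently just the [`PartitionDescriptor`] identifying which partition's build side is loaded and which probe batches are drained)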
+#[derive(Debug, Clone)] +pub(super) struct ProcessPartitionState { + /// Descriptor for the partition currently being processed + descriptor: PartitionDescriptor, +} + +/// Represents a partition of build-side data +pub(super) enum BuildPartition { + /// Partition data in memory + InMemory { + /// Hash map for this partition + hash_map: Box, + /// Build-side batch data + batch: RecordBatch, + /// Join key values + values: Vec, + /// Memory reservation for this partition + reservation: MemoryReservation, + }, + /// Partition data spilled to disk + Spilled { + /// Spill file containing the partition data (taken on load) + spill_file: Option, + /// Memory reservation (released when spilled) + reservation: MemoryReservation, + /// Total bytes written for this spill partition + spilled_bytes: usize, + /// Total rows written for this spill partition + spilled_rows: usize, + }, + /// Partition resources released and not available + Released { + /// Placeholder reservation + reservation: MemoryReservation, + }, + /// Empty partition (no rows) + Empty, +} + +/// Represents a partition of probe-side data +#[derive(Debug)] +pub(super) struct ProbePartition { + /// Batches in this partition + pub batches: Vec, + /// Join key values for each batch + pub values: Vec>, + /// Hash values for each batch + pub hashes: Vec>, +} + +impl ProbePartition { + pub(super) fn new() -> Self { + Self { + batches: Vec::new(), + values: Vec::new(), + hashes: Vec::new(), + } + } +} + +/// Runtime state tracked per probe partition. +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] +pub(super) struct ProbePartitionState { + buffered: ProbePartition, + batch_position: usize, + buffered_rows: usize, + spilled_rows: usize, + consumed_rows: usize, + spill_in_progress: Option, + spill_files: VecDeque, + pending_stream: Option, + active_batch: Option, + active_values: Vec, + active_hashes: Vec, + active_offset: crate::joins::join_hash_map::JoinHashMapOffset, + joined_probe_idx: Option, +} + +#[cfg(not(feature = "hybrid_hash_join_scheduler"))] +impl ProbePartitionState { + fn new() -> Self { + Self { + buffered: ProbePartition::new(), + batch_position: 0, + buffered_rows: 0, + spilled_rows: 0, + consumed_rows: 0, + spill_in_progress: None, + spill_files: VecDeque::new(), + pending_stream: None, + active_batch: None, + active_values: Vec::new(), + active_hashes: Vec::new(), + active_offset: (0, None), + joined_probe_idx: None, + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn prepare_probe_values( + &self, + batch: &RecordBatch, + ) -> Result<(Vec, Vec)> { + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + keys_values.push(c.evaluate(batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + Ok((keys_values, hashes)) + } + + fn reset(&mut self) { + *self = Self::new(); + } +} + +enum PartitionBuildStatus { + Ready(StatefulStreamResult>), + NeedMorePartitions { next_count: usize }, +} + +struct PartitionAccumulator { + buffered_batches: Vec, + buffered_bytes: usize, + total_rows: usize, + spill_writer: Option, + spilled_bytes: usize, +} + +impl PartitionAccumulator { + fn new() -> Self { + Self { + buffered_batches: Vec::new(), + buffered_bytes: 0, + total_rows: 0, + spill_writer: None, + spilled_bytes: 0, + } + } +} + +impl Default for PartitionAccumulator { + fn default() -> Self { + Self::new() + } +} + 
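A minimal sketch of the radix partitioning arithmetic used below, assuming power-of-two partition counts (as `partition_for_hash` and the repartition mask rely on); the helper names here are illustrative only. The initial pass keys on the low `radix_bits` of the hash, and a recursive repartition of one oversized partition consumes the next block of bits, so build and probe rows keep landing in matching sub-partitions across generations.

/// Initial pass: the low bits of the hash select one of `num_partitions` buckets.
fn initial_partition(hash: u64, num_partitions: usize) -> usize {
    debug_assert!(num_partitions.is_power_of_two());
    (hash as usize) & (num_partitions - 1)
}

/// Recursive pass: the next `log2(fanout)` bits above the parent's `radix_bits`
/// select the sub-partition, mirroring `(hash >> shift_bits) & mask` below.
fn refined_sub_partition(hash: u64, parent_radix_bits: usize, fanout: usize) -> usize {
    debug_assert!(fanout.is_power_of_two());
    ((hash >> parent_radix_bits) as usize) & (fanout - 1)
}

// Example: with 8 initial partitions (3 radix bits) and a fanout of 4 for one
// oversized partition, hash 0b10_110 lands in initial partition 0b110 = 6; if
// partition 6 is refined, its sub-partition is (0b10_110 >> 3) & 0b11 = 2, and
// the new descriptor's hash_prefix becomes (6 << 2) | 2.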
+#[derive(Debug, Clone)] +pub(super) struct PartitionDescriptor { + /// Index into build/probe storage vectors + pub(super) build_index: usize, + /// Index of the original (generation 0) partition + root_index: usize, + /// Number of refinement passes applied so far + generation: usize, + /// Total number of radix bits used to identify this partition + radix_bits: usize, + /// Hash prefix (lower `radix_bits`) identifying this partition + hash_prefix: u64, + /// Latest spilled byte estimate for this partition + spilled_bytes: usize, + /// Latest spilled row estimate for this partition + spilled_rows: usize, +} + +// Use RefCountedTempFile from datafusion_execution::disk_manager + +/// Partitioned Hash Join stream that can handle large datasets by partitioning +/// both build and probe sides and processing them sequentially. +pub(super) struct PartitionedHashJoinStream { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== + /// Partition identifier for debugging and determinism + pub partition: usize, + /// Output schema + pub schema: SchemaRef, + /// Join key columns from the right (probe side) + pub on_right: Vec, + /// Join key columns from the left (build side) + pub on_left: Vec, + /// Optional join filter + pub filter: Option, + /// Type of the join (left, right, semi, etc) + pub join_type: JoinType, + /// Right (probe) input stream + pub right: SendableRecordBatchStream, + /// Future that yields the collected build-side data + pub left_fut: OnceFut, + /// Random state used for hashing initialization + pub random_state: RandomState, + /// Metrics + pub join_metrics: BuildProbeJoinMetrics, + /// Information of index and left / right placement of columns + pub column_indices: Vec, + /// Defines the null equality for the join + pub null_equality: NullEquality, + /// Maximum output batch size + pub batch_size: usize, + /// Number of partitions to use + pub num_partitions: usize, + /// Maximum partition fanout allowed when recursively repartitioning + pub max_partition_count: usize, + /// Memory threshold for spilling (in bytes) + pub memory_threshold: usize, + + // ======================================================================== + // STATE: + // These fields track the execution state and are updated during execution. 
+ // ======================================================================== + /// Current state of the stream + pub state: PartitionedHashJoinState, + /// Build-side partitions + pub build_partitions: Vec, + /// Probe-side partitions + pub probe_states: Vec, + /// Scheduler used to coordinate probe tasks + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_task_scheduler: HybridTaskScheduler, + /// Whether a scheduler task is currently in-flight per partition + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_inflight: Vec, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_waiting_for_stream: VecDeque, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_active_streams: usize, + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub probe_scheduler_max_streams: usize, + /// Current partition being processed + pub current_partition: Option, + /// Queue of pending partitions to process (supports recursive fan-out) + pub pending_partitions: VecDeque, + /// Spill manager for probe-side (right) batches + pub probe_spill_manager: SpillManager, + /// Spill manager for build-side (left) batches + pub build_spill_manager: SpillManager, + /// Memory reservation for the entire operation + pub memory_reservation: MemoryReservation, + /// Tracks how many repartition passes have been attempted + pub partition_pass: usize, + /// Indicates whether the current pass has already prepared partitions for output + pub partition_pass_output_started: bool, + /// Runtime environment + pub runtime_env: Arc, + /// Scratch space for computing hashes + pub hashes_buffer: Vec, + /// Whether the right side has an ordering to potentially preserve + pub right_side_ordered: bool, + /// Whether this stream has emitted a placeholder batch for downstream scheduling + pub placeholder_emitted: bool, + /// Running alignment start for right indices across probe batches (for semi/anti/mark) + pub right_alignment_start: usize, + /// Shared bounds accumulator for coordinating dynamic filter updates (optional) + pub bounds_accumulator: + Option>, + /// Future used to synchronize dynamic filter updates across partitions + pub bounds_waiter: Option>, + /// Cached build-side schema + pub build_schema: SchemaRef, + /// Cached probe-side schema + pub probe_schema: SchemaRef, + /// Bitmaps to track matched build-side rows for outer joins (one per partition) + pub matched_build_rows_per_partition: Vec, + /// Current partition being processed for unmatched rows + pub unmatched_partition: usize, + /// Cached unmatched build/probe indices for current partition (chunked emission) + pub unmatched_left_indices_cache: Option, + pub unmatched_right_indices_cache: Option, + pub unmatched_offset: usize, + /// Whether the probe stream has reached EOF + pub probe_stream_finished: bool, + /// Metrics: total matches after equality per partition + pub matched_rows_per_part: Vec, + /// Metrics: total rows emitted per partition + pub emitted_rows_per_part: Vec, + /// Metrics: total candidate pairs before equality per partition + pub candidate_pairs_per_part: Vec, + /// Pending async spill reload stream for build partitions + pub pending_reload_stream: Option, + /// Accumulated batches for pending reload + pub pending_reload_batches: Vec, + /// Target partition id for pending reload + pub pending_reload_partition: Option, + /// Whether a partition is currently queued for processing + pub partition_pending: Vec, + /// Latest descriptor metadata per partition + pub partition_descriptors: Vec>, 
+} + +#[cfg(feature = "hybrid_hash_join_scheduler")] +#[derive(Debug)] +enum ProbeTaskStatus { + Ready, + Pending, + WaitingForStream, + Finished, +} + +impl PartitionedHashJoinStream { + /// Compute partition id for a given hash using radix mask when possible + #[inline] + fn partition_for_hash(&self, hash: u64) -> usize { + if self.num_partitions.is_power_of_two() { + (hash as usize) & (self.num_partitions - 1) + } else { + // Fallback when num_partitions is not a power of two + (hash as usize) % self.num_partitions + } + } + + fn resize_partition_vectors(&mut self) { + let n = self.num_partitions; + self.probe_states = (0..n).map(|_| ProbePartitionState::new()).collect(); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_scheduler_inflight = vec![false; n]; + self.probe_scheduler_waiting_for_stream = VecDeque::new(); + self.probe_scheduler_active_streams = 0; + self.probe_scheduler_max_streams = std::cmp::max(1, std::cmp::min(4, n)); + self.probe_task_scheduler = HybridTaskScheduler::new(); + } + self.matched_rows_per_part = vec![0; n]; + self.emitted_rows_per_part = vec![0; n]; + self.candidate_pairs_per_part = vec![0; n]; + self.partition_pending = vec![false; n]; + self.partition_descriptors = (0..n).map(|_| None).collect(); + } + + fn probe_state(&self, idx: usize) -> Result<&ProbePartitionState> { + self.probe_states + .get(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition")) + } + + fn probe_state_mut(&mut self, idx: usize) -> Result<&mut ProbePartitionState> { + self.probe_states + .get_mut(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition")) + } + + fn allocate_partition_slot(&mut self) -> usize { + let idx = self.build_partitions.len(); + self.build_partitions.push(BuildPartition::Empty); + self.matched_build_rows_per_partition + .push(BooleanBufferBuilder::new(0)); + self.probe_states.push(ProbePartitionState::new()); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_scheduler_inflight.push(false); + } + self.matched_rows_per_part.push(0); + self.emitted_rows_per_part.push(0); + self.candidate_pairs_per_part.push(0); + self.partition_pending.push(false); + self.partition_descriptors.push(None); + idx + } + + fn schedule_partition(&mut self, part_id: usize) -> Result<()> { + if part_id >= self.partition_pending.len() { + let new_len = part_id + 1; + self.partition_pending.resize(new_len, false); + self.partition_descriptors.resize_with(new_len, || None); + } + + if self.current_partition == Some(part_id) { + return Ok(()); + } + + if self.partition_pending[part_id] { + return Ok(()); + } + + if let Some(desc) = self + .partition_descriptors + .get(part_id) + .and_then(|d| d.clone()) + { + self.pending_partitions.push_back(desc.clone()); + self.partition_pending[part_id] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); + } + + Ok(()) + } + + fn flush_probe_writer( + &mut self, + part_id: usize, + ) -> Result> { + if let Some(state) = self.probe_states.get_mut(part_id) { + if let Some(mut writer) = state.spill_in_progress.take() { + return writer.finish(); + } + } + Ok(None) + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn ensure_probe_scheduler_capacity(&mut self, part_id: usize) { + if self.probe_scheduler_inflight.len() <= part_id { + self.probe_scheduler_inflight.resize(part_id + 1, false); + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn schedule_probe_task(&mut self, descriptor: &PartitionDescriptor) { + let part_id = 
descriptor.build_index; + self.ensure_probe_scheduler_capacity(part_id); + if self.probe_scheduler_inflight[part_id] { + return; + } + let task = SchedulerTask::Probe(ProbeStageTask::new(descriptor.clone())); + self.probe_task_scheduler.push_task(task); + self.probe_scheduler_inflight[part_id] = true; + } + + fn finalize_spilled_partition(&mut self, part_id: usize) -> Result { + if part_id >= self.probe_states.len() { + return Ok(false); + } + if let Some(file) = self.flush_probe_writer(part_id)? { + if let Some(state) = self.probe_states.get_mut(part_id) { + state.spill_files.push_back(file); + } + self.schedule_partition(part_id)?; + return Ok(true); + } + Ok(false) + } + + fn compute_recursive_fanout( + &self, + descriptor: &PartitionDescriptor, + ) -> Option<(usize, usize)> { + if descriptor.generation >= HYBRID_HASH_MAX_REPARTITION_DEPTH { + return None; + } + if self.max_partition_count == 0 { + return None; + } + let current_total = self.build_partitions.len(); + if current_total == 0 { + return None; + } + + let max_fanout_allowed = self + .max_partition_count + .saturating_sub(current_total.saturating_sub(1)); + if max_fanout_allowed < HYBRID_HASH_MIN_FANOUT { + return None; + } + + let per_partition_budget = + per_partition_budget_bytes(self.memory_threshold, self.num_partitions); + + let rows_budget = self + .batch_size + .saturating_mul(HYBRID_HASH_ROWS_PER_PARTITION_TARGET_MULTIPLIER) + .max(HYBRID_HASH_ROWS_PER_PARTITION_MIN); + + let should_repartition_bytes = descriptor.spilled_bytes > per_partition_budget; + let should_repartition_rows = descriptor.spilled_rows > rows_budget; + + if !should_repartition_bytes && !should_repartition_rows { + return None; + } + + let mut required = HYBRID_HASH_MIN_FANOUT; + + if should_repartition_bytes { + let budget = per_partition_budget.max(1); + let needed = descriptor.spilled_bytes.saturating_add(budget - 1) / budget; + required = required.max(needed); + } + + if should_repartition_rows { + let budget = rows_budget.max(1); + let needed = descriptor.spilled_rows.saturating_add(budget - 1) / budget; + required = required.max(needed); + } + + let mut fanout = required.next_power_of_two(); + if fanout == 0 { + fanout = HYBRID_HASH_MIN_FANOUT; + } + if fanout > max_fanout_allowed { + fanout = highest_power_of_two_leq(max_fanout_allowed); + } + if fanout < HYBRID_HASH_MIN_FANOUT { + return None; + } + + let additional_bits = fanout.trailing_zeros() as usize; + if additional_bits == 0 { + return None; + } + Some((additional_bits, fanout)) + } + + fn repartition_spilled_partition( + &mut self, + descriptor: &PartitionDescriptor, + additional_bits: usize, + fanout: usize, + ) -> Result> { + let build_index = descriptor.build_index; + if build_index >= self.build_partitions.len() { + return Ok(vec![]); + } + + let placeholder_reservation = + MemoryConsumer::new("partition_repartition_placeholder") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + let old_partition = mem::replace( + &mut self.build_partitions[build_index], + BuildPartition::Released { + reservation: placeholder_reservation, + }, + ); + + let (spill_file, _spilled_bytes, _spilled_rows) = match old_partition { + BuildPartition::Spilled { + spill_file, + spilled_bytes, + spilled_rows, + .. 
+ } => ( + spill_file.ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for partition {}", + build_index + ) + })?, + spilled_bytes, + spilled_rows, + ), + other => { + self.build_partitions[build_index] = other; + return Ok(vec![]); + } + }; + + // Collect spilled build batches + let mut build_batches = block_on(async { + let mut stream = self.build_spill_manager.read_spill_as_stream(spill_file)?; + let mut batches = Vec::new(); + while let Some(batch) = stream.next().await { + batches.push(batch?); + } + Result::>::Ok(batches) + })?; + + if build_batches.is_empty() { + // Nothing to repartition; keep placeholder as empty partition + let mut new_descriptor = descriptor.clone(); + new_descriptor.spilled_bytes = 0; + new_descriptor.spilled_rows = 0; + self.matched_build_rows_per_partition[build_index] = + BooleanBufferBuilder::new(0); + self.build_partitions[build_index] = BuildPartition::Empty; + return Ok(vec![new_descriptor]); + } + + let shift_bits = descriptor.radix_bits; + let mask = (fanout - 1) as u64; + let mut sub_accumulators = (0..fanout) + .map(|_| PartitionAccumulator::new()) + .collect::>(); + + self.join_metrics.recursive_repartition_events.add(1); + self.join_metrics.recursive_partitions_created.add(fanout); + self.join_metrics + .recursive_partition_depth + .set_max(descriptor.generation.saturating_add(1)); + self.join_metrics + .recursive_repartition_fanout + .set_max(fanout); + + for batch in build_batches.drain(..) { + let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); + for expr in &self.on_left { + keys_values.push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let idx_array = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &idx_array, None).map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + let batch_size = filtered_batch.get_array_memory_size(); + + let accum = &mut sub_accumulators[sub_idx]; + accum.total_rows += filtered_batch.num_rows(); + + match self.memory_reservation.try_grow(batch_size) { + Ok(_) => { + accum.buffered_bytes += batch_size; + accum.buffered_batches.push(filtered_batch); + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + if self.memory_reservation.size() > self.memory_threshold { + self.spill_partition(sub_idx, accum)?; + } + } + Err(_) => { + self.spill_partition(sub_idx, accum)?; + self.append_spilled_batch(accum, filtered_batch)?; + } + } + } + } + + // Finalize sub partitions + let new_radix_bits = descriptor.radix_bits + additional_bits; + let mut new_descriptors = Vec::with_capacity(fanout); + let mut partition_indices = Vec::with_capacity(fanout); + + for sub_idx in 0..fanout { + let accum = &mut sub_accumulators[sub_idx]; + let mut matched_bitmap = BooleanBufferBuilder::new(accum.total_rows); + matched_bitmap.append_n(accum.total_rows, false); + + let 
new_index = if sub_idx == 0 { + build_index + } else { + self.allocate_partition_slot() + }; + partition_indices.push(new_index); + + self.matched_build_rows_per_partition[new_index] = matched_bitmap; + + if accum.spill_writer.is_some() || !accum.buffered_batches.is_empty() { + if accum.spill_writer.is_some() { + if !accum.buffered_batches.is_empty() { + self.spill_partition(sub_idx, accum)?; + } + let mut writer = accum.spill_writer.take().ok_or_else(|| { + internal_datafusion_err!("missing spill writer") + })?; + let spill_file = writer.finish()?.ok_or_else(|| { + internal_datafusion_err!("expected spill file after repartition") + })?; + let reservation = MemoryConsumer::new("partition_spilled") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + self.build_partitions[new_index] = BuildPartition::Spilled { + spill_file: Some(spill_file), + reservation, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, + }; + } else { + let mut buffered_batches = mem::take(&mut accum.buffered_batches); + let partition_batch = if buffered_batches.len() == 1 { + buffered_batches.pop().unwrap() + } else { + let batch_refs: Vec<_> = buffered_batches.iter().collect(); + concat_batches(&self.build_schema, batch_refs)? + }; + let num_rows = partition_batch.num_rows(); + let partition_values = self + .on_left + .iter() + .map(|expr| expr.evaluate(&partition_batch)?.into_array(num_rows)) + .collect::>>()?; + + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hash_map: Box = if num_rows + > u32::MAX as usize + { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(num_rows, 0); + create_hashes( + &partition_values, + &self.random_state, + &mut self.hashes_buffer, + )?; + hash_map.extend_zero(num_rows); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(idx, hash)| (idx, hash)); + hash_map.update_from_iter(Box::new(iter), 0); + + let reservation = MemoryConsumer::new("partition_memory") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + self.build_partitions[new_index] = BuildPartition::InMemory { + hash_map, + batch: partition_batch, + values: partition_values, + reservation, + }; + accum.spilled_bytes = 0; + } + } else { + self.build_partitions[new_index] = BuildPartition::Empty; + } + + let hash_prefix = + (descriptor.hash_prefix << additional_bits) | (sub_idx as u64); + new_descriptors.push(PartitionDescriptor { + build_index: new_index, + root_index: descriptor.root_index, + generation: descriptor.generation + 1, + radix_bits: new_radix_bits, + hash_prefix, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, + }); + } + + self.repartition_probe_partition(descriptor, fanout, &partition_indices)?; + + Ok(new_descriptors) + } + + fn repartition_probe_partition( + &mut self, + descriptor: &PartitionDescriptor, + fanout: usize, + partition_indices: &[usize], + ) -> Result<()> { + let 
parent_index = descriptor.build_index; + if parent_index >= self.probe_states.len() { + return Ok(()); + } + + let shift_bits = descriptor.radix_bits; + let mask = (fanout - 1) as u64; + + let spill_file = { + let state = self + .probe_states + .get_mut(parent_index) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.batch_position = 0; + state.buffered_rows = 0; + state.spilled_rows = 0; + state.consumed_rows = 0; + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; + state.pending_stream = None; + state.spill_files.pop_front() + }; + + if let Some(file) = spill_file { + let mut writers = Vec::with_capacity(fanout); + for _ in 0..fanout { + let writer = self + .probe_spill_manager + .create_in_progress_file("hash_join_probe_repartition")?; + writers.push(writer); + } + + let mut file_opt = Some(file); + block_on(async { + let mut stream = self + .probe_spill_manager + .read_spill_as_stream(file_opt.take().unwrap())?; + while let Some(batch) = stream.next().await { + let batch = batch?; + let mut key_arrays: Vec = + Vec::with_capacity(self.on_right.len()); + for expr in &self.on_right { + key_arrays + .push(expr.evaluate(&batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&key_arrays, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = + (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let indices_arr = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + let writer = writers + .get_mut(sub_idx) + .ok_or_else(|| internal_datafusion_err!("missing writer"))?; + writer.append_batch(&filtered_batch)?; + self.join_metrics + .probe_spilled_rows + .add(filtered_batch.num_rows()); + self.join_metrics + .probe_spilled_bytes + .add(filtered_batch.get_array_memory_size()); + } + } + Result::<()>::Ok(()) + })?; + + for (sub_idx, mut writer) in writers.into_iter().enumerate() { + let file = writer.finish()?.ok_or_else(|| { + internal_datafusion_err!("expected probe spill file") + })?; + let partitions_idx = partition_indices[sub_idx]; + let state = self + .probe_states + .get_mut(partitions_idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.spill_files.push_back(file); + state.spilled_rows = 0; + state.buffered_rows = 0; + state.consumed_rows = 0; + state.batch_position = 0; + state.pending_stream = None; + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; + } + return Ok(()); + } + + // In-memory probe data + let parent_partition = { + let state = self + .probe_states + .get_mut(parent_index) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + mem::replace(&mut state.buffered, ProbePartition::new()) + }; + for idx in 0..parent_partition.batches.len() { + let 
batch = &parent_partition.batches[idx]; + let values = &parent_partition.values[idx]; + let hashes = &parent_partition.hashes[idx]; + let mut indices_per_part: Vec> = vec![Vec::new(); fanout]; + for (row_idx, hash) in hashes.iter().enumerate() { + let sub_idx = (((*hash >> shift_bits) as usize) & mask as usize) % fanout; + indices_per_part[sub_idx].push(row_idx as u32); + } + + for (sub_idx, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + let indices_arr = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None).map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + let mut filtered_values: Vec = Vec::with_capacity(values.len()); + for arr in values.iter() { + filtered_values.push( + take(arr, &indices_arr, None).map_err(DataFusionError::from)?, + ); + } + + let mut filtered_hashes: Vec = Vec::with_capacity(indices_arr.len()); + for i in indices_arr.values().iter() { + filtered_hashes.push(hashes[*i as usize]); + } + + let idx = partition_indices[sub_idx]; + let state = self + .probe_states + .get_mut(idx) + .ok_or_else(|| internal_datafusion_err!("missing probe partition"))?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_values); + state.buffered.hashes.push(filtered_hashes); + let buffered = state + .buffered + .batches + .last() + .map(|b| b.num_rows()) + .unwrap_or_default(); + state.buffered_rows = state.buffered_rows.saturating_add(buffered); + } + } + + Ok(()) + } + + fn buffer_probe_side(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.probe_states.len() != self.num_partitions { + self.resize_partition_vectors(); + } + + loop { + match self.right.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + let mut keys_values: Vec = + Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c.evaluate(&batch)?.into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = + vec![Vec::new(); self.num_partitions]; + for (row_idx, &hash) in hashes.iter().enumerate() { + let pid = self.partition_for_hash(hash) as usize; + indices_per_part[pid].push(row_idx as u32); + } + + for part_id in 0..self.num_partitions { + let part_indices = &indices_per_part[part_id]; + if part_indices.is_empty() { + continue; + } + + let indices_arr: UInt32Array = part_indices.clone().into(); + + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + + let mut filtered_on_values: Vec = + Vec::with_capacity(self.on_right.len()); + for arr in &keys_values { + filtered_on_values.push( + take(arr, &indices_arr, None) + .map_err(DataFusionError::from)?, + ); + } + + let mut filtered_hashes: Vec = + Vec::with_capacity(part_indices.len()); + for &i in part_indices.iter() { + filtered_hashes.push(hashes[i as usize]); + } + + if matches!( + self.build_partitions.get(part_id), + Some(BuildPartition::Spilled { .. 
}) + ) { + let (queue_ready, stream_active) = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| { + internal_datafusion_err!( + "missing probe partition" + ) + })?; + if state.spill_in_progress.is_none() { + let ipf = self + .probe_spill_manager + .create_in_progress_file( + "hash_join_probe_partition", + )?; + state.spill_in_progress = Some(ipf); + self.join_metrics.probe_spill_count.add(1); + } + if let Some(ref mut ipf) = state.spill_in_progress { + ipf.append_batch(&filtered_batch)?; + self.join_metrics + .probe_spilled_rows + .add(filtered_batch.num_rows()); + self.join_metrics + .probe_spilled_bytes + .add(filtered_batch.get_array_memory_size()); + } + state.spilled_rows = state + .spilled_rows + .saturating_add(filtered_batch.num_rows()); + ( + !state.spill_files.is_empty(), + state.pending_stream.is_some(), + ) + }; + if !queue_ready && !stream_active { + self.finalize_spilled_partition(part_id)?; + } + } else { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing probe partition") + })?; + state.buffered.batches.push(filtered_batch); + state.buffered.values.push(filtered_on_values); + state.buffered.hashes.push(filtered_hashes); + if let Some(last) = state.buffered.batches.last() { + state.buffered_rows = + state.buffered_rows.saturating_add(last.num_rows()); + } + } + } + + return Poll::Ready(Ok(())); + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + self.probe_stream_finished = true; + for part_id in 0..self.num_partitions { + self.finalize_spilled_partition(part_id)?; + } + return Poll::Ready(Ok(())); + } + Poll::Pending => { + return Poll::Pending; + } + } + } + } + + fn maybe_recursive_repartition( + &mut self, + descriptor: &PartitionDescriptor, + ) -> Result { + if descriptor.build_index >= self.build_partitions.len() { + return Ok(false); + } + match self.build_partitions.get(descriptor.build_index) { + Some(BuildPartition::Spilled { .. 
}) => {} + _ => return Ok(false), + } + let Some((additional_bits, fanout)) = self.compute_recursive_fanout(descriptor) + else { + return Ok(false); + }; + let new_descriptors = + self.repartition_spilled_partition(descriptor, additional_bits, fanout)?; + if new_descriptors.is_empty() { + return Ok(false); + } + // Enqueue new descriptors in order + for desc in new_descriptors.into_iter().rev() { + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); + self.pending_partitions.push_front(desc); + } + Ok(true) + } + + fn ensure_build_spill_writer<'a>( + &self, + accum: &'a mut PartitionAccumulator, + ) -> Result<&'a mut InProgressSpillFile> { + if accum.spill_writer.is_none() { + accum.spill_writer = Some( + self.build_spill_manager + .create_in_progress_file("hash_join_build_partition")?, + ); + } + Ok(accum.spill_writer.as_mut().unwrap()) + } + + fn spill_partition( + &mut self, + _build_index: usize, + accum: &mut PartitionAccumulator, + ) -> Result<()> { + let buffered_batches = mem::take(&mut accum.buffered_batches); + if buffered_batches.is_empty() { + return Ok(()); + } + + let created_writer = accum.spill_writer.is_none(); + let mut total_spilled_bytes = 0usize; + { + let writer = self.ensure_build_spill_writer(accum)?; + if created_writer { + self.join_metrics.build_spill_count.add(1); + } + for batch in buffered_batches { + let batch_size = batch.get_array_memory_size(); + total_spilled_bytes = total_spilled_bytes.saturating_add(batch_size); + self.join_metrics.build_spilled_rows.add(batch.num_rows()); + self.join_metrics.build_spilled_bytes.add(batch_size); + writer.append_batch(&batch)?; + } + } + accum.spilled_bytes = accum.spilled_bytes.saturating_add(total_spilled_bytes); + if accum.buffered_bytes > 0 { + let _ = self.memory_reservation.try_shrink(accum.buffered_bytes); + accum.buffered_bytes = 0; + } + Ok(()) + } + + fn append_spilled_batch( + &self, + accum: &mut PartitionAccumulator, + batch: RecordBatch, + ) -> Result<()> { + let batch_size = batch.get_array_memory_size(); + self.join_metrics.build_spilled_rows.add(batch.num_rows()); + self.join_metrics.build_spilled_bytes.add(batch_size); + { + let writer = self.ensure_build_spill_writer(accum)?; + writer.append_batch(&batch)?; + } + accum.spilled_bytes = accum.spilled_bytes.saturating_add(batch_size); + Ok(()) + } + + fn reset_partition_state(&mut self) { + for state in self.probe_states.iter_mut() { + if let Some(mut writer) = state.spill_in_progress.take() { + let _ = writer.finish(); + } + state.reset(); + } + self.probe_states.clear(); + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + self.probe_task_scheduler = HybridTaskScheduler::new(); + self.probe_scheduler_inflight.clear(); + self.probe_scheduler_waiting_for_stream.clear(); + self.probe_scheduler_active_streams = 0; + } + + for partition in self.build_partitions.iter_mut() { + if let BuildPartition::Spilled { + spill_file, + reservation, + .. 
+ } = partition + { + if let Some(file) = spill_file.take() { + drop(file); + } + let placeholder = MemoryConsumer::new("released_build_partition") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let _ = mem::replace(reservation, placeholder); + } + } + + self.build_partitions.clear(); + self.matched_build_rows_per_partition.clear(); + self.current_partition = None; + self.pending_partitions.clear(); + self.placeholder_emitted = false; + self.right_alignment_start = 0; + self.unmatched_partition = 0; + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + self.probe_stream_finished = false; + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + self.partition_pending.clear(); + self.partition_descriptors.clear(); + self.bounds_waiter = None; + + self.resize_partition_vectors(); + + let reserved = self.memory_reservation.size(); + if reserved > 0 { + let _ = self.memory_reservation.try_shrink(reserved); + } + + self.state = PartitionedHashJoinState::PartitionBuildSide; + } + + fn next_partition_count(&self) -> Option { + if self.num_partitions >= self.max_partition_count { + return None; + } + + let mut next = self.num_partitions.saturating_mul(2); + if next <= self.num_partitions { + next = self.num_partitions.saturating_add(1); + } + if next > self.max_partition_count { + next = self.max_partition_count; + } + if next > self.num_partitions { + Some(next) + } else { + None + } + } + + fn repartition_worthwhile(&self, max_spilled_bytes: usize) -> bool { + let partitions = self.num_partitions.max(1); + let per_partition_budget = self.memory_threshold / partitions; + if per_partition_budget == 0 { + return false; + } + let cutoff = + std::cmp::max(per_partition_budget / 2, HYBRID_HASH_MIN_PARTITION_BYTES); + max_spilled_bytes > cutoff + } + + fn prepare_partition_queue(&mut self) { + self.pending_partitions.clear(); + let radix_bits = + self.num_partitions.next_power_of_two().trailing_zeros() as usize; + for part_id in 0..self.build_partitions.len() { + let (spilled_bytes, spilled_rows) = match &self.build_partitions[part_id] { + BuildPartition::Spilled { + spilled_bytes, + spilled_rows, + .. 
+ } => (*spilled_bytes, *spilled_rows), + _ => (0, 0), + }; + if self.partition_descriptors.len() <= part_id { + self.partition_descriptors.resize_with(part_id + 1, || None); + } + if self.partition_pending.len() <= part_id { + self.partition_pending.resize(part_id + 1, false); + } + self.pending_partitions.push_back(PartitionDescriptor { + build_index: part_id, + root_index: part_id, + generation: self.partition_pass, + radix_bits, + hash_prefix: part_id as u64, + spilled_bytes, + spilled_rows, + }); + if let Some(desc) = self.pending_partitions.back().cloned() { + self.partition_descriptors[part_id] = Some(desc.clone()); + self.partition_pending[part_id] = true; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&desc); + } + } + } + + fn transition_to_next_partition(&mut self) { + if let Some(descriptor) = self.pending_partitions.pop_front() { + let build_index = descriptor.build_index; + if self.partition_descriptors.len() <= build_index { + self.partition_descriptors + .resize_with(build_index + 1, || None); + } + if self.partition_pending.len() <= build_index { + self.partition_pending.resize(build_index + 1, false); + } + self.partition_descriptors[build_index] = Some(descriptor.clone()); + self.partition_pending[build_index] = false; + self.current_partition = Some(build_index); + self.state = + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor, + }); + } else { + self.current_partition = None; + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + if !self.probe_scheduler_waiting_for_stream.is_empty() { + self.state = PartitionedHashJoinState::WaitingForProbe; + return; + } + } + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + } + } + + fn advance_to_next_partition(&mut self) { + self.current_partition = None; + self.transition_to_next_partition(); + } + + /// Report build-side bounds to the shared accumulator when dynamic filtering is enabled + fn poll_bounds_update( + &mut self, + cx: &mut Context<'_>, + build_data: &Arc, + ) -> Poll> { + if let Some(ref accumulator) = self.bounds_accumulator { + if self.bounds_waiter.is_none() { + // "[spill-join] partition={} reporting build bounds (rows={})", + // self.partition, + // build_data.batch().num_rows() + // ); + let accumulator = Arc::clone(accumulator); + let partition = self.partition; + let bounds = build_data.bounds.clone(); + self.bounds_waiter = Some(OnceFut::new(async move { + accumulator.report_partition_bounds(partition, bounds).await + })); + } + + if let Some(waiter) = self.bounds_waiter.as_mut() { + match waiter.get(cx) { + Poll::Ready(Ok(_)) => { + // "[spill-join] partition={} build bounds reported", + // self.partition + // ); + self.bounds_waiter = None; + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + // "[spill-join] partition={} waiting on shared bounds barrier", + // self.partition + // ); + return Poll::Pending; + } + } + } + } + + Poll::Ready(Ok(())) + } + + /// Ensure the build partition is loaded in-memory (reload if spilled) + fn ensure_build_partition_loaded( + &mut self, + cx: &mut Context<'_>, + part_id: usize, + ) -> Poll> { + let needs_reload = matches!( + self.build_partitions.get(part_id), + Some(BuildPartition::Spilled { .. }) + ); + if !needs_reload { + return Poll::Ready(Ok(())); + } + + // Kick off reload if needed + if self.pending_reload_partition.is_none() { + if let Some(BuildPartition::Spilled { spill_file, .. 
}) = + self.build_partitions.get_mut(part_id) + { + let spill_file = spill_file.take().ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for this partition" + ) + })?; + let stream = self.build_spill_manager.read_spill_as_stream(spill_file)?; + self.pending_reload_stream = Some(stream); + self.pending_reload_batches.clear(); + self.pending_reload_partition = Some(part_id); + } + } + + // Drive stream forward + if self.pending_reload_partition == Some(part_id) { + if let Some(stream) = self.pending_reload_stream.as_mut() { + loop { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.pending_reload_batches.push(batch); + // Continue draining ready batches without yielding. + continue; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + // Concatenate + let first_schema = self + .pending_reload_batches + .get(0) + .ok_or_else(|| { + internal_datafusion_err!("empty spilled partition") + })? + .schema(); + let concatenated = concat_batches( + &first_schema, + self.pending_reload_batches.as_slice(), + ) + .map_err(DataFusionError::from)?; + // "Reloaded spilled build partition {} for probing (rows={})", + // part_id, + // concatenated.num_rows() + // ); + + // Grow global reservation conservatively by concatenated batch size + let concat_size = concatenated.get_array_memory_size(); + let _ = self.memory_reservation.try_grow(concat_size); + + // Recompute values and hashmap + let mut values: Vec = + Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push( + c.evaluate(&concatenated)? + .into_array(concatenated.num_rows())?, + ); + } + + let mut hash_map: Box = Box::new( + JoinHashMapU32::with_capacity(concatenated.num_rows()), + ); + self.hashes_buffer.clear(); + self.hashes_buffer.resize(concatenated.num_rows(), 0); + // Build HT for reloaded partition from precomputed key arrays (no re-eval) + create_hashes( + &values, + &self.random_state, + &mut self.hashes_buffer, + )?; + hash_map.extend_zero(concatenated.num_rows()); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(i, h)| (i, h)); + hash_map.update_from_iter(Box::new(iter), 0); + + let new_reservation = MemoryConsumer::new("partition_reload") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + self.build_partitions[part_id] = BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; + + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + // Shrink global reservation now that partition is resident with per-partition reservation + let _ = self.memory_reservation.try_shrink(concat_size); + return Poll::Ready(Ok(())); + } + Poll::Pending => { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + } + } + + Poll::Pending + } + /// Create a new partitioned hash join stream + pub fn new( + partition: usize, + schema: SchemaRef, + on_left: Vec, + on_right: Vec, + filter: Option, + join_type: JoinType, + right: SendableRecordBatchStream, + left_fut: OnceFut, + random_state: RandomState, + join_metrics: BuildProbeJoinMetrics, + probe_spill_metrics: SpillMetrics, + build_spill_metrics: SpillMetrics, + column_indices: Vec, + null_equality: NullEquality, + batch_size: usize, + mut num_partitions: usize, + mut max_partition_count: usize, + memory_threshold: usize, + memory_reservation: MemoryReservation, + runtime_env: Arc, + build_schema: SchemaRef, + probe_schema: SchemaRef, 
+ right_side_ordered: bool, + bounds_accumulator: Option< + Arc, + >, + ) -> Result { + let probe_spill_manager = SpillManager::new( + runtime_env.clone(), + probe_spill_metrics, + Arc::clone(&probe_schema), + ); + + let build_spill_manager = SpillManager::new( + runtime_env.clone(), + build_spill_metrics, + Arc::clone(&build_schema), + ); + + let mem_limit = max_partitions_allowed_for_memory(memory_threshold) + .max(HYBRID_HASH_MIN_FANOUT); + max_partition_count = max_partition_count + .max(HYBRID_HASH_MIN_FANOUT) + .min(mem_limit); + num_partitions = num_partitions + .max(HYBRID_HASH_MIN_FANOUT) + .min(max_partition_count); + + #[cfg(feature = "hybrid_hash_join_scheduler")] + let scheduler_max_probe_streams = + std::cmp::max(1, std::cmp::min(4, num_partitions)); + + Ok(Self { + partition, + schema, + on_left, + on_right, + filter, + join_type, + right, + left_fut, + random_state, + join_metrics, + column_indices, + null_equality, + batch_size, + num_partitions, + max_partition_count, + memory_threshold, + state: PartitionedHashJoinState::PartitionBuildSide, + build_partitions: Vec::new(), + probe_states: (0..num_partitions) + .map(|_| ProbePartitionState::new()) + .collect(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_task_scheduler: HybridTaskScheduler::new(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_inflight: vec![false; num_partitions], + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_waiting_for_stream: VecDeque::new(), + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_active_streams: 0, + #[cfg(feature = "hybrid_hash_join_scheduler")] + probe_scheduler_max_streams: scheduler_max_probe_streams, + current_partition: None, + pending_partitions: VecDeque::new(), + probe_spill_manager, + build_spill_manager, + memory_reservation, + partition_pass: 0, + partition_pass_output_started: false, + runtime_env, + hashes_buffer: Vec::new(), + right_side_ordered, + placeholder_emitted: false, + right_alignment_start: 0, + bounds_accumulator, + bounds_waiter: None, + build_schema, + probe_schema, + matched_build_rows_per_partition: Vec::new(), + unmatched_partition: 0, + unmatched_left_indices_cache: None, + unmatched_right_indices_cache: None, + unmatched_offset: 0, + probe_stream_finished: false, + pending_reload_stream: None, + pending_reload_batches: Vec::new(), + pending_reload_partition: None, + matched_rows_per_part: vec![0; num_partitions], + emitted_rows_per_part: vec![0; num_partitions], + candidate_pairs_per_part: vec![0; num_partitions], + partition_pending: vec![false; num_partitions], + partition_descriptors: (0..num_partitions).map(|_| None).collect(), + }) + } + + /// Partition build-side data into multiple partitions + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn partition_build_side( + &mut self, + build_data: Arc, + ) -> Result>> { + HybridTaskScheduler::with_build_task(build_data).run_until_build_finished(self) + } + + /// Partition build-side data into multiple partitions (legacy serial path) + #[cfg(not(feature = "hybrid_hash_join_scheduler"))] + fn partition_build_side( + &mut self, + build_data: Arc, + ) -> Result>> { + self.partition_build_side_serial(build_data) + } + + /// Legacy build partitioning logic shared with the experimental scheduler. 
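+    ///
+    /// Each build batch is hashed on the join keys and routed into one of
+    /// `num_partitions` buckets via `partition_for_hash`; once the memory
+    /// reservation exceeds `memory_threshold`, a bucket is spilled to disk.
+    /// If a single bucket is large enough that more partitions would help
+    /// (`repartition_worthwhile`), a larger partition count is requested via
+    /// `PartitionBuildStatus::NeedMorePartitions` and the pass restarts with
+    /// the partition state rebuilt from scratch.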
+ pub(super) fn partition_build_side_serial( + &mut self, + build_data: Arc, + ) -> Result>> { + if self.partition_pass == 0 { + self.join_metrics.build_input_batches.add(1); + let total_rows: usize = build_data + .original_batches() + .iter() + .map(|b| b.num_rows()) + .sum(); + self.join_metrics.build_input_rows.add(total_rows); + } + + let build_total_size: usize = build_data + .original_batches() + .iter() + .map(|batch| batch.get_array_memory_size()) + .sum(); + if build_total_size <= self.memory_threshold { + self.num_partitions = 1; + self.max_partition_count = 1; + } + + let mut allow_repartition = !self.partition_pass_output_started; + loop { + self.reset_partition_state(); + + match self.try_partition_build_side(&build_data, allow_repartition)? { + PartitionBuildStatus::Ready(result) => { + return Ok(result); + } + PartitionBuildStatus::NeedMorePartitions { next_count } => { + if next_count <= self.num_partitions + || next_count == 0 + || next_count > self.max_partition_count + { + allow_repartition = false; + continue; + } + + self.num_partitions = next_count; + self.partition_pass += 1; + self.partition_pass_output_started = false; + allow_repartition = true; + } + } + } + } + + fn try_partition_build_side( + &mut self, + build_data: &Arc, + allow_repartition: bool, + ) -> Result { + self.build_partitions = Vec::with_capacity(self.num_partitions); + self.matched_build_rows_per_partition = Vec::with_capacity(self.num_partitions); + + let mut partition_accumulators = (0..self.num_partitions) + .map(|_| PartitionAccumulator::new()) + .collect::>(); + let mut repartition_request: Option = None; + let mut max_spilled_bytes: usize = 0; + let mut any_spilled = false; + + for batch in build_data.original_batches() { + let mut keys_values: Vec = Vec::with_capacity(self.on_left.len()); + for expr in &self.on_left { + keys_values.push(expr.evaluate(batch)?.into_array(batch.num_rows())?); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + + let mut indices_per_part: Vec> = + vec![Vec::new(); self.num_partitions]; + for (row_idx, hash) in hashes.iter().enumerate() { + let build_index = self.partition_for_hash(*hash); + indices_per_part[build_index].push(row_idx as u32); + } + + for (build_index, indices) in indices_per_part.into_iter().enumerate() { + if indices.is_empty() { + continue; + } + + let idx_array = UInt32Array::from(indices); + let mut filtered_columns: Vec = + Vec::with_capacity(batch.num_columns()); + for col in batch.columns() { + filtered_columns.push( + take(col, &idx_array, None).map_err(DataFusionError::from)?, + ); + } + let filtered_batch = + RecordBatch::try_new(batch.schema(), filtered_columns) + .map_err(DataFusionError::from)?; + let batch_size = filtered_batch.get_array_memory_size(); + let accum = &mut partition_accumulators[build_index]; + accum.total_rows += filtered_batch.num_rows(); + + if accum.spill_writer.is_some() { + self.append_spilled_batch(accum, filtered_batch)?; + continue; + } + + match self.memory_reservation.try_grow(batch_size) { + Ok(_) => { + accum.buffered_bytes += batch_size; + accum.buffered_batches.push(filtered_batch); + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + if self.memory_reservation.size() > self.memory_threshold { + if allow_repartition { + let partition_estimate = accum.buffered_bytes; + if self.repartition_worthwhile(partition_estimate) { + if let Some(next_count) = self.next_partition_count() + { + repartition_request = 
Some(next_count); + break; + } + } + } + if !self.runtime_env.disk_manager.tmp_files_enabled() { + return Err(internal_datafusion_err!( + "Insufficient memory for build partitioning and spilling is disabled" + )); + } + self.spill_partition(build_index, accum)?; + } + } + Err(_) => { + if allow_repartition { + let partition_estimate = + accum.buffered_bytes.saturating_add(batch_size); + if self.repartition_worthwhile(partition_estimate) { + if let Some(next_count) = self.next_partition_count() { + repartition_request = Some(next_count); + break; + } + } + } + if !self.runtime_env.disk_manager.tmp_files_enabled() { + return Err(internal_datafusion_err!( + "Unable to allocate memory for build partition" + )); + } + self.spill_partition(build_index, accum)?; + self.append_spilled_batch(accum, filtered_batch)?; + } + } + + if repartition_request.is_some() { + break; + } + } + + if repartition_request.is_some() { + break; + } + } + + if let Some(next_count) = repartition_request { + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } + + self.build_partitions.reserve(self.num_partitions); + self.matched_build_rows_per_partition + .reserve(self.num_partitions); + + for part_id in 0..self.num_partitions { + let mut accum = mem::take(&mut partition_accumulators[part_id]); + max_spilled_bytes = max_spilled_bytes.max(accum.spilled_bytes); + if accum.spill_writer.is_some() { + if !accum.buffered_batches.is_empty() { + self.spill_partition(part_id, &mut accum)?; + } + if let Some(mut writer) = accum.spill_writer.take() { + let spill_file = writer + .finish()? + .ok_or_else(|| internal_datafusion_err!("expected spill file"))?; + let mut matched_bitmap = BooleanBufferBuilder::new(accum.total_rows); + matched_bitmap.append_n(accum.total_rows, false); + self.matched_build_rows_per_partition.push(matched_bitmap); + let reservation = MemoryConsumer::new("partition_spilled") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + any_spilled = true; + self.build_partitions.push(BuildPartition::Spilled { + spill_file: Some(spill_file), + reservation, + spilled_bytes: accum.spilled_bytes, + spilled_rows: accum.total_rows, + }); + } + continue; + } + + if accum.buffered_batches.is_empty() { + self.matched_build_rows_per_partition + .push(BooleanBufferBuilder::new(0)); + self.build_partitions.push(BuildPartition::Empty); + continue; + } + + let mut buffered_batches = accum.buffered_batches; + let partition_batch = if buffered_batches.len() == 1 { + buffered_batches.pop().unwrap() + } else { + let batch_refs: Vec<_> = buffered_batches.iter().collect(); + concat_batches(&self.build_schema, batch_refs)? 
+ }; + let num_rows = partition_batch.num_rows(); + let partition_values = self + .on_left + .iter() + .map(|expr| expr.evaluate(&partition_batch)?.into_array(num_rows)) + .collect::>>()?; + let fixed_size_u32 = size_of::(); + let fixed_size_u64 = size_of::(); + let mut hash_map: Box = if num_rows > u32::MAX as usize { + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size_u64)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU64::with_capacity(num_rows)) + } else { + let estimated_hashtable_size = + estimate_memory_size::<(u32, u64)>(num_rows, fixed_size_u32)?; + self.memory_reservation.try_grow(estimated_hashtable_size)?; + self.join_metrics + .build_mem_used + .set_max(self.memory_reservation.size()); + Box::new(JoinHashMapU32::with_capacity(num_rows)) + }; + + self.hashes_buffer.clear(); + self.hashes_buffer.resize(num_rows, 0); + create_hashes( + &partition_values, + &self.random_state, + &mut self.hashes_buffer, + )?; + hash_map.extend_zero(num_rows); + let iter = self + .hashes_buffer + .iter() + .enumerate() + .map(|(idx, hash)| (idx, hash)); + hash_map.update_from_iter(Box::new(iter), 0); + + let mut matched_bitmap = BooleanBufferBuilder::new(num_rows); + matched_bitmap.append_n(num_rows, false); + self.matched_build_rows_per_partition.push(matched_bitmap); + + let reservation = MemoryConsumer::new("partition_memory") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + + let approx_partition_size = partition_batch.get_array_memory_size() + + partition_values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::(); + self.join_metrics.build_mem_used.set_max( + self.memory_reservation + .size() + .saturating_add(approx_partition_size), + ); + + self.build_partitions.push(BuildPartition::InMemory { + hash_map, + batch: partition_batch, + values: partition_values, + reservation, + }); + } + + if allow_repartition + && (max_spilled_bytes > self.memory_threshold || any_spilled) + && self.repartition_worthwhile(max_spilled_bytes) + { + if let Some(next_count) = self.next_partition_count() { + return Ok(PartitionBuildStatus::NeedMorePartitions { next_count }); + } + } + + self.prepare_partition_queue(); + self.partition_pass_output_started = true; + self.transition_to_next_partition(); + + Ok(PartitionBuildStatus::Ready(StatefulStreamResult::Continue)) + } + /// Release resources associated with a finished partition when safe to do so. + /// Only releases memory eagerly when we don't need unmatched rows in the final phase. + fn release_partition_resources(&mut self, build_index: usize) { + if need_produce_result_in_final(self.join_type) { + return; + } + + if build_index >= self.build_partitions.len() { + return; + } + + // Take ownership of the old partition to drop heavy resources + let placeholder_reservation = + MemoryConsumer::new("partition_released_placeholder") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let old_partition = mem::replace( + &mut self.build_partitions[build_index], + BuildPartition::Released { + reservation: placeholder_reservation, + }, + ); + + match old_partition { + BuildPartition::InMemory { + batch, + values, + reservation, + .. 
+ } => { + // Estimate memory held by this partition and shrink global reservation + let mut estimated_size = batch.get_array_memory_size(); + estimated_size += values + .iter() + .map(|a| a.get_array_memory_size()) + .sum::(); + let _ = self.memory_reservation.try_shrink(estimated_size); + + // Replace with an empty in-memory partition to keep indexing stable + let empty_batch = RecordBatch::new_empty(batch.schema()); + let empty_values: Vec = self + .on_left + .iter() + .filter_map(|expr| expr.evaluate(&empty_batch).ok()) + .filter_map(|v| v.into_array(empty_batch.num_rows()).ok()) + .collect(); + let empty_hash_map: Box = + Box::new(JoinHashMapU32::with_capacity(0)); + + self.build_partitions[build_index] = BuildPartition::InMemory { + hash_map: empty_hash_map, + batch: empty_batch, + values: empty_values, + reservation, + }; + } + BuildPartition::Spilled { reservation, .. } => { + // Transition to Released; no files remain + self.build_partitions[build_index] = + BuildPartition::Released { reservation }; + } + BuildPartition::Released { reservation } => { + self.build_partitions[build_index] = + BuildPartition::Released { reservation }; + } + BuildPartition::Empty => { + // no-op + } + } + } + + fn partition_has_pending_probe(&self, part_id: usize) -> bool { + if let Some(state) = self.probe_states.get(part_id) { + if state.batch_position < state.buffered.batches.len() { + return true; + } + if state.active_batch.is_some() { + return true; + } + if !state.spill_files.is_empty() { + return true; + } + if state.pending_stream.is_some() { + return true; + } + if state.spill_in_progress.is_some() { + return true; + } + } + false + } + + /// Attempts to load the next buffered probe batch for `part_id`. + pub(super) fn take_buffered_probe_batch( + &mut self, + part_id: usize, + ) -> Result> { + if let Some(state) = self.probe_states.get_mut(part_id) { + if state.batch_position < state.buffered.batches.len() { + let pos = state.batch_position; + let batch = state.buffered.batches[pos].clone(); + let values = state.buffered.values[pos].clone(); + let hashes = state.buffered.hashes[pos].clone(); + state.batch_position = state.batch_position.saturating_add(1); + state.active_batch = Some(batch.clone()); + state.active_values = values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if state.batch_position >= state.buffered.batches.len() { + state.buffered = ProbePartition::new(); + state.batch_position = 0; + state.buffered_rows = 0; + } + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); + } + return Ok(Some(batch)); + } + } + Ok(None) + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn try_acquire_probe_stream_slot(&mut self) -> bool { + if self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { + self.probe_scheduler_active_streams += 1; + true + } else { + false + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn release_probe_stream_slot(&mut self) { + if self.probe_scheduler_active_streams > 0 { + self.probe_scheduler_active_streams -= 1; + } + self.wake_stream_waiter(); + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn enqueue_stream_waiter(&mut self, part_id: usize) { + if part_id >= self.partition_pending.len() { + return; + } + if self + .probe_scheduler_waiting_for_stream + .iter() + .any(|&v| v == part_id) + { + return; + } + self.probe_scheduler_waiting_for_stream.push_back(part_id); + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn 
wake_stream_waiter(&mut self) { + while self.probe_scheduler_active_streams < self.probe_scheduler_max_streams { + if let Some(next_part) = self.probe_scheduler_waiting_for_stream.pop_front() { + if next_part >= self.partition_pending.len() { + continue; + } + if self.partition_pending[next_part] { + continue; + } + if let Some(Some(desc)) = + self.partition_descriptors.get(next_part).map(|d| d.clone()) + { + self.partition_pending[next_part] = true; + let waiting_for_probe = + matches!(self.state, PartitionedHashJoinState::WaitingForProbe); + self.pending_partitions.push_back(desc); + if waiting_for_probe { + self.transition_to_next_partition(); + } + break; + } + } else { + break; + } + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn poll_probe_stage_task( + &mut self, + cx: &mut Context<'_>, + descriptor: &PartitionDescriptor, + ) -> Result { + let part_id = descriptor.build_index; + self.schedule_probe_task(descriptor); + + let mut iterations = self.probe_task_scheduler.len(); + while iterations > 0 { + iterations -= 1; + let Some(task) = self.probe_task_scheduler.pop_task() else { + break; + }; + match task { + SchedulerTask::Probe(probe_task) => { + match SchedulerTask::Probe(probe_task).poll(self, Some(cx))? { + TaskPoll::ProbeReady(desc) => { + let ready_part = desc.build_index; + if ready_part >= self.probe_scheduler_inflight.len() { + self.probe_scheduler_inflight + .resize(ready_part + 1, false); + } + self.probe_scheduler_inflight[ready_part] = false; + if ready_part == part_id { + return Ok(ProbeTaskStatus::Ready); + } else { + if ready_part >= self.partition_pending.len() { + self.partition_pending.resize(ready_part + 1, false); + } + if !self.partition_pending[ready_part] { + self.pending_partitions.push_back(desc.clone()); + self.partition_pending[ready_part] = true; + } + } + } + TaskPoll::Pending(next_task) => { + self.probe_task_scheduler.push_task(next_task); + } + TaskPoll::YieldProbe { + task: next_task, + descriptor: desc, + } => { + let wait_part = desc.build_index; + if wait_part == part_id { + self.probe_task_scheduler.push_task(next_task); + return Ok(ProbeTaskStatus::WaitingForStream); + } else { + self.probe_task_scheduler.push_task(next_task); + self.enqueue_stream_waiter(wait_part); + } + } + TaskPoll::ProbeFinished(desc) => { + let finished_part = desc.build_index; + if finished_part >= self.probe_scheduler_inflight.len() { + self.probe_scheduler_inflight + .resize(finished_part + 1, false); + } + self.probe_scheduler_inflight[finished_part] = false; + if finished_part == part_id { + return Ok(ProbeTaskStatus::Finished); + } else { + if finished_part >= self.partition_pending.len() { + self.partition_pending + .resize(finished_part + 1, false); + } + if !self.partition_pending[finished_part] { + self.pending_partitions.push_back(desc.clone()); + self.partition_pending[finished_part] = true; + } + } + } + TaskPoll::BuildFinished(_) => {} + } + } + other_task => { + // Unexpected task type for probe scheduling; push back to preserve semantics. 
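+                    // Re-queued tasks are not re-polled within this call: the loop is
+                    // bounded by the queue length captured in `iterations` above.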
+ self.probe_task_scheduler.push_task(other_task); + } + } + } + + let queue_len = self.probe_task_scheduler.len(); + if queue_len > 0 { + cx.waker().wake_by_ref(); + } + Ok(ProbeTaskStatus::Pending) + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + pub(super) fn poll_probe_data_for_partition( + &mut self, + part_id: usize, + cx: &mut Context<'_>, + ) -> Result { + if self.take_buffered_probe_batch(part_id)?.is_some() { + return Ok(ProbeDataPoll::Ready); + } + + let has_spilled_probe = { + let state = self.probe_state(part_id)?; + state.spill_in_progress.is_some() + || !state.spill_files.is_empty() + || state.pending_stream.is_some() + }; + + if !has_spilled_probe { + return Ok(ProbeDataPoll::Finished); + } + + loop { + let needs_stream = { + let state = self.probe_state(part_id)?; + state.pending_stream.is_none() + }; + if needs_stream { + let mut next_file = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.spill_files.pop_front() + }; + if next_file.is_none() && self.finalize_spilled_partition(part_id)? { + next_file = { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.spill_files.pop_front() + }; + } + if let Some(file) = next_file { + if !self.try_acquire_probe_stream_slot() { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.spill_files.push_front(file); + return Ok(ProbeDataPoll::NeedStream); + } + let stream = self.probe_spill_manager.read_spill_as_stream(file)?; + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.pending_stream = Some(stream); + } else { + let writer_open = { + let state = self.probe_state(part_id)?; + state.spill_in_progress.is_some() + }; + if self.probe_stream_finished && !writer_open { + return Ok(ProbeDataPoll::Finished); + } else { + return Ok(ProbeDataPoll::Pending); + } + } + } + + let poll_result = { + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state + .pending_stream + .as_mut() + .map(|stream| stream.poll_next_unpin(cx)) + }; + + match poll_result { + Some(Poll::Ready(Some(Ok(batch)))) => { + let (values, hashes) = self.prepare_probe_values(&batch)?; + let state = self + .probe_states + .get_mut(part_id) + .ok_or_else(|| internal_datafusion_err!("missing partition"))?; + state.active_batch = Some(batch); + state.active_values = values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); + } + return Ok(ProbeDataPoll::Ready); + } + Some(Poll::Ready(Some(Err(e)))) => return Err(e), + Some(Poll::Ready(None)) => { + { + let state = + self.probe_states.get_mut(part_id).ok_or_else(|| { + internal_datafusion_err!("missing partition") + })?; + state.pending_stream = None; + } + self.release_probe_stream_slot(); + continue; + } + Some(Poll::Pending) | None => return Ok(ProbeDataPoll::Pending), + } + } + } + + #[cfg(feature = "hybrid_hash_join_scheduler")] + fn prepare_probe_values( + &self, + batch: &RecordBatch, + ) -> Result<(Vec, Vec)> { + let mut keys_values: Vec = Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + keys_values.push(c.evaluate(batch)?.into_array(batch.num_rows())?); + } + let mut hashes = 
vec![0u64; batch.num_rows()]; + create_hashes(&keys_values, &self.random_state, &mut hashes)?; + Ok((keys_values, hashes)) + } + + /// Process a specific partition + fn process_partition( + &mut self, + cx: &mut Context<'_>, + partition_state: &ProcessPartitionState, + ) -> Poll>>> { + let build_index = partition_state.descriptor.build_index; + + // Guard against invalid partition ids (off-by-one protection) + if build_index >= self.build_partitions.len() { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if self.maybe_recursive_repartition(&partition_state.descriptor)? { + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + if self.current_partition != Some(build_index) { + self.current_partition = Some(build_index); + } + + // Do not buffer probe side here; selection happens below depending on num_partitions + + // (Spill reload handled by ensure_build_partition_loaded earlier if needed) + + // (Build partition will be immutably borrowed later within a narrower scope) + + // Ensure the build partition is ready (reload if spilled) BEFORE any immutable borrows + match self.ensure_build_partition_loaded(cx, build_index) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + + // Ensure probe side is fully buffered into per-partition containers + if !self.probe_stream_finished { + match self.buffer_probe_side(cx) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => { + let no_current_data = !self.partition_has_pending_probe(build_index); + let no_other_pending = self.pending_partitions.is_empty(); + if no_current_data && no_other_pending { + return Poll::Pending; + } + } + } + } + + // Select next probe batch for current partition + let mut has_active_batch = match self.probe_state(build_index) { + Ok(state) => state.active_batch.is_some(), + Err(e) => return Poll::Ready(Err(e)), + }; + + #[cfg(feature = "hybrid_hash_join_scheduler")] + { + if !has_active_batch { + match self.poll_probe_stage_task(cx, &partition_state.descriptor)? 
{ + ProbeTaskStatus::Ready => { + has_active_batch = true; + } + ProbeTaskStatus::Pending => { + return Poll::Pending; + } + ProbeTaskStatus::WaitingForStream => { + self.enqueue_stream_waiter(build_index); + self.current_partition = None; + self.transition_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + ProbeTaskStatus::Finished => { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + } + } + + #[cfg(not(feature = "hybrid_hash_join_scheduler"))] + { + if !has_active_batch { + if self.take_buffered_probe_batch(build_index)?.is_some() { + has_active_batch = true; + } + } + + if !has_active_batch { + let has_spilled_probe = match self.probe_state(build_index) { + Ok(state) => { + state.spill_in_progress.is_some() + || !state.spill_files.is_empty() + || state.pending_stream.is_some() + } + Err(e) => return Poll::Ready(Err(e)), + }; + + if has_spilled_probe { + loop { + let needs_stream = match self.probe_state(build_index) { + Ok(state) => state.pending_stream.is_none(), + Err(e) => return Poll::Ready(Err(e)), + }; + + if needs_stream { + let mut next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + if next_file.is_none() + && self.finalize_spilled_partition(build_index)? + { + next_file = match self.probe_state_mut(build_index) { + Ok(state) => state.spill_files.pop_front(), + Err(e) => return Poll::Ready(Err(e)), + }; + } + if let Some(file) = next_file { + let stream = self + .probe_spill_manager + .read_spill_as_stream(file)?; + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = Some(stream), + Err(e) => return Poll::Ready(Err(e)), + } + } else { + let should_release = match self.probe_state(build_index) { + Ok(state) => { + self.probe_stream_finished + && state.spill_in_progress.is_none() + && state.pending_stream.is_none() + } + Err(e) => return Poll::Ready(Err(e)), + }; + if should_release { + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = None, + Err(e) => return Poll::Ready(Err(e)), + } + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok( + StatefulStreamResult::Continue, + )); + } else { + return Poll::Pending; + } + } + } + + let poll_result = { + let state = match self.probe_state_mut(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + if let Some(stream) = state.pending_stream.as_mut() { + stream.poll_next_unpin(cx) + } else { + return Poll::Pending; + } + }; + + match poll_result { + Poll::Ready(Some(Ok(batch))) => { + let mut keys_values: Vec = + Vec::with_capacity(self.on_right.len()); + for c in &self.on_right { + let v = c + .evaluate(&batch)? 
+ .into_array(batch.num_rows())?; + keys_values.push(v); + } + let mut hashes = vec![0u64; batch.num_rows()]; + create_hashes( + &keys_values, + &self.random_state, + &mut hashes, + )?; + + let state = match self.probe_state_mut(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + state.active_batch = Some(batch); + state.active_values = keys_values; + state.active_hashes = hashes; + state.active_offset = (0, None); + if let Some(b) = state.active_batch.as_ref() { + state.consumed_rows = + state.consumed_rows.saturating_add(b.num_rows()); + } + has_active_batch = true; + break; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + match self.probe_state_mut(build_index) { + Ok(state) => state.pending_stream = None, + Err(e) => return Poll::Ready(Err(e)), + } + continue; + } + Poll::Pending => return Poll::Pending, + } + } + } else { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + } + + if !has_active_batch { + self.release_partition_resources(build_index); + self.advance_to_next_partition(); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + // At this point we have a current probe batch for this partition + let (result, build_ids_to_mark, next_offset, next_joined_idx) = { + let ( + probe_batch, + probe_values, + probe_hashes, + current_offset, + prev_joined_idx, + ) = { + let state = match self.probe_state(build_index) { + Ok(state) => state, + Err(e) => return Poll::Ready(Err(e)), + }; + let batch = state + .active_batch + .as_ref() + .ok_or_else(|| internal_datafusion_err!("expected probe batch"))? + .clone(); + let values = state.active_values.clone(); + let hashes = state.active_hashes.clone(); + ( + batch, + values, + hashes, + state.active_offset, + state.joined_probe_idx, + ) + }; + + let (build_hashmap, build_batch, build_values) = + match self.build_partitions.get(build_index) { + Some(BuildPartition::InMemory { + hash_map, + batch, + values, + .. + }) => (&**hash_map, batch, values as &Vec), + Some(BuildPartition::Spilled { .. }) => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Some(BuildPartition::Released { .. 
}) + | Some(BuildPartition::Empty) + | None => { + return Poll::Ready(internal_err!( + "Missing or invalid build partition" + )); + } + }; + // Debug: log ON expressions and output mapping once we have both sides + // Lookup against hash map with limit + let (probe_indices, build_indices, next_offset) = build_hashmap + .get_matched_indices_with_limit_offset( + &probe_hashes, + self.batch_size, + current_offset, + ); + + let build_indices: UInt64Array = build_indices.into(); + let probe_indices: UInt32Array = probe_indices.into(); + + // Track candidate pairs before equality + self.candidate_pairs_per_part[build_index] = self.candidate_pairs_per_part + [build_index] + .saturating_add(build_indices.len()); + // "[spill-join] Candidates before equality: build_ids={}, probe_ids={}, build_rows={}, probe_rows={}", + // build_indices.len(), + // probe_indices.len(), + // build_batch.num_rows(), + // probe_batch.num_rows() + // ); + + // Resolve hash collisions + let (build_indices, probe_indices) = equal_rows_arr( + &build_indices, + &probe_indices, + build_values, + &probe_values, + self.null_equality, + )?; + + // Apply residual join filter if present + let mut build_indices = build_indices; + let mut probe_indices = probe_indices; + if let Some(filter) = &self.filter { + let (filtered_build_indices, filtered_probe_indices) = + apply_join_filter_to_indices( + build_batch, + &probe_batch, + build_indices, + probe_indices, + filter, + JoinSide::Left, + None, + )?; + + build_indices = filtered_build_indices; + probe_indices = filtered_probe_indices; + } + + // Capture matched build indices prior to alignment so we can mark bitmaps even if + // the join type drops them (e.g. LeftAnti emits matches only in the final phase). + let build_indices_for_marking = + if need_produce_result_in_final(self.join_type) { + Some(build_indices.clone()) + } else { + None + }; + + // Accumulate matched rows per partition + self.matched_rows_per_part[build_index] = self.matched_rows_per_part + [build_index] + .saturating_add(build_indices.len()); + + // Compute alignment window (used by adjust_indices for all join types) + let last_joined_right_idx = match probe_indices.len() { + 0 => None, + n => Some(probe_indices.value(n - 1) as usize), + }; + let probe_num_rows = probe_batch.num_rows(); + let mut index_alignment_range_start = prev_joined_idx.map_or(0, |v| v + 1); + let mut index_alignment_range_end = if next_offset.is_none() { + probe_num_rows + } else { + last_joined_right_idx.map_or(index_alignment_range_start, |v| v + 1) + }; + + if index_alignment_range_start > probe_num_rows { + index_alignment_range_start = probe_num_rows; + } + if index_alignment_range_end > probe_num_rows { + index_alignment_range_end = probe_num_rows; + } + if index_alignment_range_end < index_alignment_range_start { + index_alignment_range_end = index_alignment_range_start; + } + + let (build_indices, probe_indices) = adjust_indices_by_join_type( + build_indices, + probe_indices, + index_alignment_range_start..index_alignment_range_end, + self.join_type, + self.right_side_ordered, + )?; + + // Only right-oriented joins need to preserve alignment state across batches + let needs_alignment = matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ); + + // Debug counter: after alignment (or effective no-op for other join types) + + // Prepare ids for marking after we release borrows. Prefer the pre-alignment + // matches (for join types like LeftAnti) so bitmap tracking remains accurate. 
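+            // The marked bits are consumed later by `handle_unmatched_rows`, which walks
+            // `matched_build_rows_per_partition` to emit build rows that never matched.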
+ let build_ids_to_mark: Vec = + if let Some(indices) = build_indices_for_marking { + indices.values().to_vec() + } else { + build_indices.values().to_vec() + }; + // Track last joined probe row only for right-oriented joins; otherwise clear it + let next_joined_idx = if needs_alignment && next_offset.is_some() { + last_joined_right_idx + } else { + None + }; + + // Build output batch depending on join side semantics + let result = if matches!( + self.join_type, + JoinType::RightMark | JoinType::RightSemi | JoinType::RightAnti + ) { + if matches!(self.join_type, JoinType::RightMark) { + } else { + // "[spill-join] Building output with JoinSide::Right ({:?})", + // self.join_type + // ); + } + let right_indices_u64 = uint32_to_uint64_indices(&probe_indices); + build_batch_from_indices( + &self.schema, + &probe_batch, + build_batch, + &right_indices_u64, + &probe_indices, + &self.column_indices, + JoinSide::Right, + )? + } else { + build_batch_from_indices( + &self.schema, + build_batch, + &probe_batch, + &build_indices, + &probe_indices, + &self.column_indices, + JoinSide::Left, + )? + }; + + let emitted_rows = result.num_rows(); + self.emitted_rows_per_part[build_index] = + self.emitted_rows_per_part[build_index].saturating_add(emitted_rows); + (result, build_ids_to_mark, next_offset, next_joined_idx) + }; + + // Mark matched build-side rows for outer joins (use current partition's bitmap) + if let Some(bitmap) = self.matched_build_rows_per_partition.get_mut(build_index) { + for build_idx in build_ids_to_mark { + bitmap.set_bit(build_idx as usize, true); + } + } + + // Update offset or fetch a new probe batch + match self.probe_state_mut(build_index) { + Ok(state) => { + if let Some(offset) = next_offset { + state.active_offset = offset; + state.joined_probe_idx = next_joined_idx; + } else { + state.active_batch = None; + state.active_values.clear(); + state.active_hashes.clear(); + state.active_offset = (0, None); + state.joined_probe_idx = None; + #[cfg(feature = "hybrid_hash_join_scheduler")] + self.schedule_probe_task(&partition_state.descriptor); + } + } + Err(e) => return Poll::Ready(Err(e)), + } + + if result.num_rows() == 0 { + // "[spill-join] Skipping empty batch emission (partition={})", + // build_index + // ); + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.join_metrics.output_batches.add(1); + self.join_metrics.baseline.record_output(result.num_rows()); + // "[spill-join] Emitting batch: rows={} (partition={})", + // result.num_rows(), + // build_index + // ); + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))) + } + + /// Handle unmatched rows for outer joins (poll-based, non-blocking spill reload) + fn handle_unmatched_rows( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + if !need_produce_result_in_final(self.join_type) { + self.state = PartitionedHashJoinState::Completed; + return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); + } + + // If we have cached unmatched indices for current partition, emit them chunk-by-chunk + if let (Some(left_all), Some(right_all)) = ( + self.unmatched_left_indices_cache.as_ref(), + self.unmatched_right_indices_cache.as_ref(), + ) { + let total = left_all.len(); + if self.unmatched_offset < total { + let remaining = total - self.unmatched_offset; + let to_emit = remaining.min(self.batch_size); + + let left_chunk_ref = left_all.slice(self.unmatched_offset, to_emit); + let right_chunk_ref = right_all.slice(self.unmatched_offset, to_emit); + let left_chunk = left_chunk_ref + .as_any() + 
.downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("failed to downcast left indices chunk") + })?; + let right_chunk = right_chunk_ref + .as_any() + .downcast_ref::() + .ok_or_else(|| { + internal_datafusion_err!("failed to downcast right indices chunk") + })?; + + // Use current partition's build batch + let partition = self + .build_partitions + .get(self.unmatched_partition) + .ok_or_else(|| { + internal_datafusion_err!( + "missing build partition during unmatched cached emission" + ) + })?; + let build_batch = match partition { + BuildPartition::InMemory { batch, .. } => batch, + BuildPartition::Spilled { .. } => { + // Should not happen because we only cache after loading InMemory indices + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + BuildPartition::Released { .. } => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + BuildPartition::Empty => { + return Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + }; + + let empty_right_batch = + RecordBatch::new_empty(Arc::clone(&self.probe_schema)); + // "Emitting unmatched rows chunk: partition={}, offset={}, size={} (total={})", + // self.unmatched_partition, + // self.unmatched_offset, + // to_emit, + // total + // ); + + let result = build_batch_from_indices( + &self.schema, + build_batch, + &empty_right_batch, + left_chunk, + right_chunk, + &self.column_indices, + JoinSide::Left, + )?; + + self.unmatched_offset += to_emit; + if self.unmatched_offset >= total { + // finished this partition's unmatched rows + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + // "Finished emitting unmatched rows for partition {}", + // self.unmatched_partition + // ); + self.unmatched_partition += 1; + } + + return Poll::Ready(Ok(StatefulStreamResult::Ready(Some(result)))); + } else { + // Safety: should not reach here; reset caches + self.unmatched_left_indices_cache = None; + self.unmatched_right_indices_cache = None; + self.unmatched_offset = 0; + } + } + + // Process unmatched rows for the current partition + if self.unmatched_partition < self.build_partitions.len() { + let partition = self + .build_partitions + .get_mut(self.unmatched_partition) + .ok_or_else(|| { + internal_datafusion_err!( + "missing build partition during unmatched processing" + ) + })?; + + match partition { + BuildPartition::InMemory { batch: _batch, .. } => { + // Get unmatched indices for this partition using its bitmap + let (left_indices, right_indices) = if let Some(bitmap) = self + .matched_build_rows_per_partition + .get(self.unmatched_partition) + { + get_final_indices_from_bit_map(bitmap, self.join_type) + } else { + // If no bitmap, skip this partition + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + }; + // "Unmatched calculation for partition {} -> {} rows", + // self.unmatched_partition, + // left_indices.len() + // ); + + if left_indices.len() > 0 { + // Cache the full indices and emit first chunk via cached path next call + self.unmatched_left_indices_cache = Some(left_indices.clone()); + self.unmatched_right_indices_cache = Some(right_indices.clone()); + self.unmatched_offset = 0; + // Fall-through into cached emission on next invocation + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } else { + // No unmatched rows in this partition, move to next + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + BuildPartition::Spilled { spill_file, .. 
} => { + // Non-blocking reload of spilled partition for unmatched rows + if self.pending_reload_partition.is_none() { + let taken = spill_file.take().ok_or_else(|| { + internal_datafusion_err!( + "spill file already consumed for unmatched" + ) + })?; + let stream = + self.build_spill_manager.read_spill_as_stream(taken)?; + self.pending_reload_stream = Some(stream); + self.pending_reload_batches.clear(); + self.pending_reload_partition = Some(self.unmatched_partition); + } + + if self.pending_reload_partition == Some(self.unmatched_partition) { + if let Some(stream) = self.pending_reload_stream.as_mut() { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + // "Reload stream yielded batch for build partition {} (rows={})", + // self.unmatched_partition, + // batch.num_rows() + // ); + self.pending_reload_batches.push(batch); + return Poll::Pending; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => { + let first_schema = self + .pending_reload_batches + .get(0) + .ok_or_else(|| { + internal_datafusion_err!( + "empty spilled partition for unmatched" + ) + })? + .schema(); + let concatenated = concat_batches( + &first_schema, + self.pending_reload_batches.as_slice(), + ) + .map_err(DataFusionError::from)?; + // "Reloaded spilled build partition {} for unmatched rows (rows={})", + // self.unmatched_partition, + // concatenated.num_rows() + // ); + + let new_reservation = + MemoryConsumer::new("partition_reload_unmatched") + .with_can_spill(true) + .register(&self.runtime_env.memory_pool); + let mut values: Vec = + Vec::with_capacity(self.on_left.len()); + for c in &self.on_left { + values.push( + c.evaluate(&concatenated)? + .into_array(concatenated.num_rows())?, + ); + } + let hash_map: Box = + Box::new(JoinHashMapU32::with_capacity( + concatenated.num_rows(), + )); + self.build_partitions[self.unmatched_partition] = + BuildPartition::InMemory { + hash_map, + batch: concatenated, + values, + reservation: new_reservation, + }; + // "Prepared spilled partition {} as InMemory for unmatched emission", + // self.unmatched_partition + // ); + + // Clear pending + self.pending_reload_stream = None; + self.pending_reload_batches.clear(); + self.pending_reload_partition = None; + + // Continue; next iteration will handle InMemory branch + return Poll::Ready(Ok( + StatefulStreamResult::Continue, + )); + } + Poll::Pending => { + // Yield until more data is available from reload stream + // "Reload stream pending for build partition {} (accumulated_batches={})", + // self.unmatched_partition, + // self.pending_reload_batches.len() + // ); + return Poll::Pending; + } + } + } + } + Poll::Pending + } + BuildPartition::Released { .. 
} => { + // Nothing to emit; advance + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + BuildPartition::Empty => { + self.unmatched_partition += 1; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + } + } else { + // All partitions processed + self.state = PartitionedHashJoinState::Completed; + return Poll::Ready(Ok(StatefulStreamResult::Ready(None))); + } + } +} + +impl RecordBatchStream for PartitionedHashJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +impl Stream for PartitionedHashJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + loop { + match self.state.clone() { + PartitionedHashJoinState::PartitionBuildSide => { + // Collect build side and partition it + let left_data = { + let fut = &mut self.left_fut; + ready!(fut.get_shared(cx))? + }; + match self.poll_bounds_update(cx, &left_data) { + Poll::Ready(Ok(())) => {} + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, + } + match self.partition_build_side(left_data) { + Ok(StatefulStreamResult::Continue) => continue, + Ok(StatefulStreamResult::Ready(Some(batch))) => { + // "[spill-join] poll_next yielding initial batch: rows={}", + // batch.num_rows() + // ); + return Poll::Ready(Some(Ok(batch))); + } + Ok(StatefulStreamResult::Ready(None)) => { + return Poll::Ready(None) + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } + PartitionedHashJoinState::ProcessPartition(partition_state) => { + // Emit a zero-row placeholder once in multi-output mode to satisfy downstream schedulers + if self.num_partitions > 1 && !self.placeholder_emitted { + self.placeholder_emitted = true; + let empty = RecordBatch::new_empty(self.schema.clone()); + // "[spill-join] Emitting placeholder empty batch for partition {}", + // build_index + // ); + return Poll::Ready(Some(Ok(empty))); + } + match self.process_partition(cx, &partition_state) { + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { + // "[spill-join] poll_next yielding process batch: rows={} (state partition={})", + // batch.num_rows(), build_index + // ); + return Poll::Ready(Some(Ok(batch))); + } + Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { + return Poll::Ready(None); + } + Poll::Ready(Ok(StatefulStreamResult::Continue)) => { + continue; + } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, + } + } + #[cfg(feature = "hybrid_hash_join_scheduler")] + PartitionedHashJoinState::WaitingForProbe => { + if self.pending_partitions.is_empty() { + if self.probe_scheduler_waiting_for_stream.is_empty() { + self.state = PartitionedHashJoinState::HandleUnmatchedRows; + continue; + } + return Poll::Pending; + } else { + self.transition_to_next_partition(); + continue; + } + } + PartitionedHashJoinState::HandleUnmatchedRows => { + match self.handle_unmatched_rows(cx) { + Poll::Ready(Ok(StatefulStreamResult::Ready(Some(batch)))) => { + // "[spill-join] poll_next yielding unmatched batch: rows={}", + // batch.num_rows() + // ); + return Poll::Ready(Some(Ok(batch))); + } + Poll::Ready(Ok(StatefulStreamResult::Ready(None))) => { + return Poll::Ready(None); + } + Poll::Ready(Ok(StatefulStreamResult::Continue)) => { + continue; + } + Poll::Ready(Err(e)) => return Poll::Ready(Some(Err(e))), + Poll::Pending => return Poll::Pending, + } + } + PartitionedHashJoinState::Completed => return Poll::Ready(None), + } + } + 
} +} + +#[cfg(all(test, feature = "hybrid_hash_join_scheduler"))] +mod scheduler_tests { + use super::*; + use crate::metrics::ExecutionPlanMetricsSet; + use crate::stream::RecordBatchStreamAdapter; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_execution::memory_pool::MemoryConsumer; + use datafusion_execution::runtime_env::RuntimeEnv; + use futures::{stream, task::noop_waker}; + use parking_lot::Mutex; + use std::sync::atomic::AtomicUsize; + use std::sync::Arc; + use std::task::Context as StdContext; + + fn test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, true)])) + } + + fn test_batch(schema: &SchemaRef, values: &[i32]) -> RecordBatch { + let array: ArrayRef = Arc::new(Int32Array::from(values.to_vec())); + RecordBatch::try_new(schema.clone(), vec![array]).unwrap() + } + + fn build_join_left_data( + batch: RecordBatch, + runtime_env: &Arc, + ) -> JoinLeftData { + let hash_map: Box = + Box::new(JoinHashMapU32::with_capacity(0)); + let reservation = MemoryConsumer::new("left") + .with_can_spill(true) + .register(&runtime_env.memory_pool); + JoinLeftData::new( + hash_map, + batch.clone(), + Arc::new(vec![batch]), + vec![], + Mutex::new(BooleanBufferBuilder::new(0)), + AtomicUsize::new(0), + reservation, + None, + ) + } + + fn make_test_stream( + num_partitions: usize, + max_streams: usize, + ) -> PartitionedHashJoinStream { + let runtime_env = Arc::new(RuntimeEnv::default()); + let schema = test_schema(); + let build_batch = RecordBatch::new_empty(schema.clone()); + let left_data = build_join_left_data(build_batch, &runtime_env); + let left_fut = OnceFut::new(async move { Ok(left_data) }); + let metrics = ExecutionPlanMetricsSet::new(); + let join_metrics = BuildProbeJoinMetrics::new(0, &metrics); + let probe_spill_metrics = SpillMetrics::new(&metrics, 0); + let build_spill_metrics = SpillMetrics::new(&metrics, 0); + let right_stream: SendableRecordBatchStream = Box::pin( + RecordBatchStreamAdapter::new(schema.clone(), stream::empty()), + ); + let memory_reservation = MemoryConsumer::new("top") + .with_can_spill(true) + .register(&runtime_env.memory_pool); + + let mut stream = PartitionedHashJoinStream::new( + 0, + schema.clone(), + vec![], + vec![], + None, + JoinType::Inner, + right_stream, + left_fut, + RandomState::with_seeds(0, 0, 0, 0), + join_metrics, + probe_spill_metrics, + build_spill_metrics, + vec![], + NullEquality::NullEqualsNothing, + 1024, + num_partitions, + num_partitions, + 1024, + memory_reservation, + runtime_env, + schema.clone(), + schema, + false, + None, + ) + .unwrap(); + stream.probe_scheduler_max_streams = max_streams; + stream.pending_partitions.clear(); + for pending in stream.partition_pending.iter_mut() { + *pending = false; + } + stream + } + + fn add_spill_file( + stream: &mut PartitionedHashJoinStream, + part_id: usize, + batch: &RecordBatch, + ) -> Result<()> { + let mut writer = stream + .probe_spill_manager + .create_in_progress_file("test_spill")?; + writer.append_batch(batch)?; + let file = writer.finish()?.expect("spill file"); + stream.probe_states[part_id].spill_files.push_back(file); + Ok(()) + } + + fn descriptor_for(partition: usize) -> PartitionDescriptor { + PartitionDescriptor { + build_index: partition, + root_index: partition, + generation: 0, + radix_bits: 0, + hash_prefix: partition as u64, + spilled_bytes: 0, + spilled_rows: 0, + } + } + + async fn poll_task_status( + stream: &mut 
PartitionedHashJoinStream, + desc: &PartitionDescriptor, + ) -> ProbeTaskStatus { + let waker = noop_waker(); + for _ in 0..4096 { + let mut cx = StdContext::from_waker(&waker); + let status = stream + .poll_probe_stage_task(&mut cx, desc) + .expect("poll should succeed"); + if matches!(status, ProbeTaskStatus::Pending) { + tokio::task::yield_now().await; + continue; + } + return status; + } + panic!("probe task stuck in pending state"); + } + + async fn poll_probe_data_until_ready( + stream: &mut PartitionedHashJoinStream, + part_id: usize, + ) -> ProbeDataPoll { + let waker = noop_waker(); + for _ in 0..4096 { + let mut cx = StdContext::from_waker(&waker); + let status = stream + .poll_probe_data_for_partition(part_id, &mut cx) + .expect("poll probe data"); + if matches!(status, ProbeDataPoll::Pending) { + tokio::task::yield_now().await; + continue; + } + return status; + } + panic!("probe data did not become ready"); + } + + #[tokio::test] + async fn probe_tasks_wait_for_stream_slots() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[1]); + add_spill_file(&mut stream, 0, &batch)?; + add_spill_file(&mut stream, 1, &batch)?; + + let desc1 = descriptor_for(1); + stream.partition_descriptors[0] = Some(descriptor_for(0)); + stream.partition_descriptors[1] = Some(desc1.clone()); + stream.partition_pending[0] = false; + stream.partition_pending[1] = false; + stream.current_partition = Some(1); + + // Simulate another partition already holding the single stream slot. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + + let status = poll_task_status(&mut stream, &desc1).await; + assert!(matches!(status, ProbeTaskStatus::WaitingForStream)); + stream.enqueue_stream_waiter(desc1.build_index); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 1); + + stream.probe_states[0].pending_stream = None; + stream.release_probe_stream_slot(); + assert_eq!(stream.probe_scheduler_active_streams, 0); + assert!(stream.probe_scheduler_waiting_for_stream.is_empty()); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 1); + stream.partition_pending[desc.build_index] = false; + Ok(()) + } + + #[tokio::test] + async fn probe_task_resumes_after_slot_available() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[10, 20]); + add_spill_file(&mut stream, 1, &batch)?; + + let desc1 = descriptor_for(1); + stream.partition_descriptors[1] = Some(desc1.clone()); + stream.partition_pending[1] = false; + stream.current_partition = Some(1); + + // Ensure there's no active stream yet. + assert_eq!(stream.probe_scheduler_active_streams, 0); + + let status = poll_task_status(&mut stream, &desc1).await; + assert!(matches!(status, ProbeTaskStatus::Ready)); + assert!(stream.probe_states[1].active_batch.is_some()); + assert_eq!(stream.probe_scheduler_active_streams, 1); + + // Mark the active batch as consumed and continue polling to drain the spill stream. 
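+        // The helpers above drive the poll-based API by hand; stripped of the
+        // join-specific types, the pattern is roughly this (illustrative sketch,
+        // `poll_step` is a stand-in for any `Poll`-returning call):
+        //
+        //     use futures::task::noop_waker;
+        //     use std::task::{Context, Poll};
+        //
+        //     fn drain(mut poll_step: impl FnMut(&mut Context<'_>) -> Poll<u32>) -> u32 {
+        //         let waker = noop_waker();
+        //         loop {
+        //             let mut cx = Context::from_waker(&waker);
+        //             match poll_step(&mut cx) {
+        //                 Poll::Ready(v) => return v,
+        //                 // a real test yields to the runtime here, as the helpers do
+        //                 Poll::Pending => continue,
+        //             }
+        //         }
+        //     }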
+ stream.probe_states[1].active_batch = None; + let mut status = poll_probe_data_until_ready(&mut stream, 1).await; + if matches!(status, ProbeDataPoll::Ready) { + stream.probe_states[1].active_batch = None; + status = poll_probe_data_until_ready(&mut stream, 1).await; + } + assert!(matches!(status, ProbeDataPoll::Finished)); + assert_eq!(stream.probe_scheduler_active_streams, 0); + Ok(()) + } + + #[tokio::test] + async fn probe_tasks_wait_queue_multiple() -> Result<()> { + let mut stream = make_test_stream(3, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[5]); + for part in 0..3 { + add_spill_file(&mut stream, part, &batch)?; + let desc = descriptor_for(part); + stream.partition_descriptors[part] = Some(desc); + stream.partition_pending[part] = false; + } + + // Partition 0 currently holds the only stream slot. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + + // Partitions 1 and 2 must wait for a stream slot. + for part in [1, 2] { + stream.enqueue_stream_waiter(part); + } + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 2); + + // Releasing the stream should enqueue partition 1 for processing. + stream.release_probe_stream_slot(); + assert_eq!(stream.probe_scheduler_active_streams, 0); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 1); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 1); + stream.partition_pending[desc.build_index] = false; + + // Simulate partition 1 holding the stream slot and then finishing. + stream.probe_scheduler_active_streams = stream.probe_scheduler_max_streams; + stream.probe_states[1].pending_stream = None; + stream.release_probe_stream_slot(); + let desc = stream.pending_partitions.pop_front().unwrap(); + assert_eq!(desc.build_index, 2); + stream.partition_pending[desc.build_index] = false; + Ok(()) + } + + #[tokio::test] + async fn wait_queue_blocks_state_progression() -> Result<()> { + let mut stream = make_test_stream(2, 1); + let schema = stream.probe_schema.clone(); + let batch = test_batch(&schema, &[7]); + for part in 0..2 { + add_spill_file(&mut stream, part, &batch)?; + let desc = descriptor_for(part); + stream.partition_descriptors[part] = Some(desc.clone()); + stream.pending_partitions.push_back(desc); + stream.partition_pending[part] = true; + } + + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 0 + )); + + // Both partitions end up waiting on a limited stream slot. + stream.enqueue_stream_waiter(0); + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 1 + )); + + stream.enqueue_stream_waiter(1); + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::WaitingForProbe + )); + assert!(stream.pending_partitions.is_empty()); + assert_eq!(stream.probe_scheduler_waiting_for_stream.len(), 2); + + // Releasing a stream slot wakes the earliest waiter and resumes partition 0. 
+ stream.probe_scheduler_active_streams = 0; + stream.wake_stream_waiter(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 0 + )); + + // Simulate finishing partition 0, which should put the stream back into waiting mode + // because partition 1 is still throttled. + stream.current_partition = None; + stream.transition_to_next_partition(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::WaitingForProbe + )); + + // Another wake picks up the remaining partition. + stream.wake_stream_waiter(); + assert!(matches!( + stream.state, + PartitionedHashJoinState::ProcessPartition(ProcessPartitionState { + descriptor: ref desc + }) if desc.build_index == 1 + )); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/hash_join/scheduler.rs b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs new file mode 100644 index 000000000000..79217e759f75 --- /dev/null +++ b/datafusion/physical-plan/src/joins/hash_join/scheduler.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Experimental Hybrid Hash Join scheduler abstractions. + +#![cfg(feature = "hybrid_hash_join_scheduler")] + +use std::collections::VecDeque; +use std::sync::Arc; +use std::task::Context; + +use arrow::{array::ArrayRef, record_batch::RecordBatch}; + +use crate::joins::hash_join::exec::JoinLeftData; +use crate::joins::hash_join::partitioned::{ + PartitionDescriptor, PartitionedHashJoinStream, ProbePartition, +}; +use crate::joins::utils::StatefulStreamResult; +use crate::SendableRecordBatchStream; + +use datafusion_common::{internal_datafusion_err, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; + +use crate::joins::join_hash_map::JoinHashMapOffset; +use crate::spill::in_progress_spill_file::InProgressSpillFile; + +/// Minimal scheduler capable of running build / probe / finalize tasks. +pub(super) struct HybridTaskScheduler { + ready_queue: VecDeque, +} + +impl HybridTaskScheduler { + pub fn new() -> Self { + Self { + ready_queue: VecDeque::new(), + } + } + + pub fn push_task(&mut self, task: SchedulerTask) { + self.ready_queue.push_back(task); + } + + pub fn pop_task(&mut self) -> Option { + self.ready_queue.pop_front() + } + + pub fn len(&self) -> usize { + self.ready_queue.len() + } + + pub fn with_build_task(build_data: Arc) -> Self { + let mut scheduler = Self::new(); + scheduler + .ready_queue + .push_back(SchedulerTask::Build(BuildStageTask::new(build_data))); + scheduler + } + + pub fn run_until_build_finished( + &mut self, + stream: &mut PartitionedHashJoinStream, + ) -> Result>> { + while let Some(task) = self.ready_queue.pop_front() { + match task.poll(stream, None)? 
{ + TaskPoll::ProbeReady(_) => continue, + TaskPoll::Pending(task) => self.ready_queue.push_back(task), + TaskPoll::BuildFinished(result) => return Ok(result), + TaskPoll::YieldProbe { task, .. } => self.ready_queue.push_back(task), + TaskPoll::ProbeFinished(_) => continue, + } + } + Err(internal_datafusion_err!( + "scheduler queue exhausted without producing build output" + )) + } +} + +pub(super) enum SchedulerTask { + Build(BuildStageTask), + Probe(ProbeStageTask), +} + +pub(super) enum TaskPoll { + ProbeReady(PartitionDescriptor), + Pending(SchedulerTask), + BuildFinished(StatefulStreamResult>), + /// Probe task yielded without producing output (e.g. waiting on IO). + YieldProbe { + task: SchedulerTask, + descriptor: PartitionDescriptor, + }, + ProbeFinished(PartitionDescriptor), +} + +impl SchedulerTask { + pub(super) fn poll( + self, + stream: &mut PartitionedHashJoinStream, + cx: Option<&mut Context<'_>>, + ) -> Result { + match self { + SchedulerTask::Build(task) => match task.poll(stream)? { + BuildTaskEvent::Pending(next_state) => { + Ok(TaskPoll::Pending(SchedulerTask::Build(next_state))) + } + BuildTaskEvent::Finished(result) => Ok(TaskPoll::BuildFinished(result)), + }, + SchedulerTask::Probe(task) => { + let cx = cx.expect("probe task requires runtime context"); + let descriptor = task.descriptor().clone(); + match task.poll(stream, cx)? { + ProbeTaskEvent::Pending(next_task) => { + Ok(TaskPoll::Pending(SchedulerTask::Probe(next_task))) + } + ProbeTaskEvent::Ready => Ok(TaskPoll::ProbeReady(descriptor)), + ProbeTaskEvent::NeedStream(next_task) => { + let wait_descriptor = next_task.descriptor().clone(); + Ok(TaskPoll::YieldProbe { + task: SchedulerTask::Probe(next_task), + descriptor: wait_descriptor, + }) + } + ProbeTaskEvent::Finished => Ok(TaskPoll::ProbeFinished(descriptor)), + } + } + } + } +} + +/// Build stage broken into multiple cooperative steps so the scheduler can interleave it. 
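+///
+/// The intended driving loop is sketched (not normative) below: seed the scheduler
+/// with a build task, then pop and poll tasks, re-queueing anything that reports
+/// `Pending`, until the build output is produced. Here `stream` is assumed to be the
+/// owning `PartitionedHashJoinStream`:
+///
+/// ```ignore
+/// let mut scheduler = HybridTaskScheduler::with_build_task(build_data);
+/// let build_result = scheduler.run_until_build_finished(&mut stream)?;
+/// ```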
+struct BuildStageTask { + build_data: Option>, + step: BuildTaskStep, + warmup_remaining: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum BuildTaskStep { + Init, + Partitioning, + Finished, +} + +impl BuildStageTask { + fn new(build_data: Arc) -> Self { + Self { + build_data: Some(build_data), + step: BuildTaskStep::Init, + warmup_remaining: 2, // allow a couple of yields before heavy work + } + } + + fn poll(mut self, stream: &mut PartitionedHashJoinStream) -> Result { + match self.step { + BuildTaskStep::Init => { + if self.warmup_remaining > 0 { + self.warmup_remaining -= 1; + return Ok(BuildTaskEvent::Pending(self)); + } + self.step = BuildTaskStep::Partitioning; + Ok(BuildTaskEvent::Pending(self)) + } + BuildTaskStep::Partitioning => { + let build_data = self.build_data.take().ok_or_else(|| { + internal_datafusion_err!("build task missing input data") + })?; + let result = stream.partition_build_side_serial(build_data)?; + self.step = BuildTaskStep::Finished; + Ok(BuildTaskEvent::Finished(result)) + } + BuildTaskStep::Finished => { + Err(internal_datafusion_err!("build task already finished")) + } + } + } +} + +enum BuildTaskEvent { + Pending(BuildStageTask), + Finished(StatefulStreamResult>), +} + +pub(super) struct ProbePartitionState { + pub buffered: ProbePartition, + pub batch_position: usize, + pub buffered_rows: usize, + pub spilled_rows: usize, + pub consumed_rows: usize, + pub spill_in_progress: Option, + pub spill_files: VecDeque, + pub pending_stream: Option, + pub active_batch: Option, + pub active_values: Vec, + pub active_hashes: Vec, + pub active_offset: JoinHashMapOffset, + pub joined_probe_idx: Option, +} + +impl ProbePartitionState { + pub fn new() -> Self { + Self { + buffered: ProbePartition::new(), + batch_position: 0, + buffered_rows: 0, + spilled_rows: 0, + consumed_rows: 0, + spill_in_progress: None, + spill_files: VecDeque::new(), + pending_stream: None, + active_batch: None, + active_values: Vec::new(), + active_hashes: Vec::new(), + active_offset: (0, None), + joined_probe_idx: None, + } + } + + pub fn reset(&mut self) { + *self = Self::new(); + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum ProbeDataPoll { + Ready, + Pending, + NeedStream, + Finished, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ProbeTaskState { + Init, + Ready, + Finished, +} + +pub(super) struct ProbeStageTask { + descriptor: PartitionDescriptor, + state: ProbeTaskState, +} + +impl ProbeStageTask { + pub fn new(descriptor: PartitionDescriptor) -> Self { + Self { + descriptor, + state: ProbeTaskState::Init, + } + } + + pub fn descriptor(&self) -> &PartitionDescriptor { + &self.descriptor + } + + fn poll( + mut self, + stream: &mut PartitionedHashJoinStream, + cx: &mut Context<'_>, + ) -> Result { + match self.state { + ProbeTaskState::Init => { + self.state = ProbeTaskState::Ready; + Ok(ProbeTaskEvent::Pending(self)) + } + ProbeTaskState::Ready => { + match stream + .poll_probe_data_for_partition(self.descriptor.build_index, cx)? 
+ { + ProbeDataPoll::Ready => Ok(ProbeTaskEvent::Ready), + ProbeDataPoll::Pending => Ok(ProbeTaskEvent::Pending(self)), + ProbeDataPoll::NeedStream => Ok(ProbeTaskEvent::NeedStream(self)), + ProbeDataPoll::Finished => { + self.state = ProbeTaskState::Finished; + Ok(ProbeTaskEvent::Finished) + } + } + } + ProbeTaskState::Finished => Ok(ProbeTaskEvent::Finished), + } + } +} + +enum ProbeTaskEvent { + Pending(ProbeStageTask), + Ready, + NeedStream(ProbeStageTask), + Finished, +} diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs index 40dc4ac2e5d1..335ead1dedf0 100644 --- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs +++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs @@ -160,6 +160,10 @@ impl SharedBoundsAccumulator { PartitionMode::Partitioned => { left_child.output_partitioning().partition_count() } + // For partitioned spillable, use the same logic as regular partitioned + PartitionMode::PartitionedSpillable => { + left_child.output_partitioning().partition_count() + } // Default value, will be resolved during optimization (does not exist once `execute()` is called; will be replaced by one of the other two) PartitionMode::Auto => unreachable!("PartitionMode::Auto should not be present at execution time. This is a bug in DataFusion, please report it!"), }; diff --git a/datafusion/physical-plan/src/joins/hash_join/stream.rs b/datafusion/physical-plan/src/joins/hash_join/stream.rs index 4484eeabd326..d840f31fcb44 100644 --- a/datafusion/physical-plan/src/joins/hash_join/stream.rs +++ b/datafusion/physical-plan/src/joins/hash_join/stream.rs @@ -35,8 +35,8 @@ use crate::{ joins::utils::{ adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_empty_build_side, build_batch_from_indices, - need_produce_result_in_final, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, - JoinHashMapType, StatefulStreamResult, + need_produce_result_in_final, uint32_to_uint64_indices, BuildProbeJoinMetrics, + ColumnIndex, JoinFilter, JoinHashMapType, StatefulStreamResult, }, RecordBatchStream, SendableRecordBatchStream, }; @@ -549,13 +549,24 @@ impl HashJoinStream { // Calculate range and perform alignment. // In case probe batch has been processed -- align all remaining rows. 
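+        // Worked example of the clamping introduced below (illustrative numbers): with a
+        // 100-row probe batch whose last row was already joined (`joined_probe_idx ==
+        // Some(99)`) and a chunk that produced no matches (`last_joined_right_idx == None`),
+        // the raw bounds would be start = 100 and end = 0; clamping both to the batch
+        // length and forcing `end >= start` turns this into the empty range 100..100
+        // instead of an inverted range being handed to `adjust_indices_by_join_type`.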
- let index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); - let index_alignment_range_end = if next_offset.is_none() { - state.batch.num_rows() + let batch_num_rows = state.batch.num_rows(); + let mut index_alignment_range_start = state.joined_probe_idx.map_or(0, |v| v + 1); + let mut index_alignment_range_end = if next_offset.is_none() { + batch_num_rows } else { last_joined_right_idx.map_or(0, |v| v + 1) }; + if index_alignment_range_start > batch_num_rows { + index_alignment_range_start = batch_num_rows; + } + if index_alignment_range_end > batch_num_rows { + index_alignment_range_end = batch_num_rows; + } + if index_alignment_range_end < index_alignment_range_start { + index_alignment_range_end = index_alignment_range_start; + } + let (left_indices, right_indices) = adjust_indices_by_join_type( left_indices, right_indices, @@ -564,12 +575,16 @@ impl HashJoinStream { self.right_side_ordered, )?; - let result = if self.join_type == JoinType::RightMark { + let result = if matches!( + self.join_type, + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark + ) { + let right_indices_u64 = uint32_to_uint64_indices(&right_indices); build_batch_from_indices( &self.schema, &state.batch, build_side.left_data.batch(), - &left_indices, + &right_indices_u64, &right_indices, &self.column_indices, JoinSide::Right, diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434..52d084bad4d2 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -20,6 +20,7 @@ use arrow::array::BooleanBufferBuilder; pub use cross_join::CrossJoinExec; use datafusion_physical_expr::PhysicalExprRef; +pub use grace_hash_join::GraceHashJoinExec; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; @@ -27,7 +28,9 @@ use parking_lot::Mutex; pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; +mod grace_hash_join; mod hash_join; + mod nested_loop_join; mod sort_merge_join; mod stream_join_utils; @@ -56,6 +59,8 @@ pub enum PartitionMode { /// mode(Partitioned/CollectLeft) is optimal based on statistics. 
It will /// also consider swapping the left and right inputs for the Join Auto, + /// Partitioned hash join that can spill to disk for large datasets + PartitionedSpillable, } /// Partitioning mode to use for symmetric hash join diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d392650f88dd..c12e6ab00121 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -1230,6 +1230,22 @@ where ) } +pub(crate) fn uint32_to_uint64_indices(indices: &UInt32Array) -> UInt64Array { + if indices.null_count() == 0 { + UInt64Array::from_iter_values(indices.values().iter().map(|v| *v as u64)) + } else { + let mut builder = UInt64Builder::with_capacity(indices.len()); + for i in 0..indices.len() { + if indices.is_null(i) { + builder.append_null(); + } else { + builder.append_value(indices.value(i) as u64); + } + } + builder.finish() + } +} + fn build_range_bitmap( range: &Range, input: &PrimitiveArray, @@ -1313,6 +1329,26 @@ pub(crate) struct BuildProbeJoinMetrics { pub(crate) build_input_rows: metrics::Count, /// Memory used by build-side in bytes pub(crate) build_mem_used: metrics::Gauge, + /// Number of spill files produced for the build side + pub(crate) build_spill_count: metrics::Count, + /// Total build-side bytes written to spill + pub(crate) build_spilled_bytes: metrics::Count, + /// Total build-side rows written to spill + pub(crate) build_spilled_rows: metrics::Count, + /// Number of spill files produced for the probe side + pub(crate) probe_spill_count: metrics::Count, + /// Total probe-side bytes written to spill + pub(crate) probe_spilled_bytes: metrics::Count, + /// Total probe-side rows written to spill + pub(crate) probe_spilled_rows: metrics::Count, + /// Number of times recursive repartitioning was triggered + pub(crate) recursive_repartition_events: metrics::Count, + /// Total number of child partitions materialized by recursion + pub(crate) recursive_partitions_created: metrics::Count, + /// Maximum recursion depth reached + pub(crate) recursive_partition_depth: metrics::Gauge, + /// Maximum fan-out applied during recursive repartitioning + pub(crate) recursive_repartition_fanout: metrics::Gauge, /// Total time for joining probe-side batches to the build-side batches pub(crate) join_time: metrics::Time, /// Number of batches consumed by probe-side of this operator @@ -1358,6 +1394,26 @@ impl BuildProbeJoinMetrics { let build_mem_used = MetricBuilder::new(metrics).gauge("build_mem_used", partition); + let build_spill_count = + MetricBuilder::new(metrics).counter("build_spill_count", partition); + let build_spilled_bytes = + MetricBuilder::new(metrics).counter("build_spilled_bytes", partition); + let build_spilled_rows = + MetricBuilder::new(metrics).counter("build_spilled_rows", partition); + let probe_spill_count = + MetricBuilder::new(metrics).counter("probe_spill_count", partition); + let probe_spilled_bytes = + MetricBuilder::new(metrics).counter("probe_spilled_bytes", partition); + let probe_spilled_rows = + MetricBuilder::new(metrics).counter("probe_spilled_rows", partition); + let recursive_repartition_events = MetricBuilder::new(metrics) + .counter("recursive_repartition_events", partition); + let recursive_partitions_created = MetricBuilder::new(metrics) + .counter("recursive_partitions_created", partition); + let recursive_partition_depth = + MetricBuilder::new(metrics).gauge("recursive_partition_depth", partition); + let recursive_repartition_fanout = + 
MetricBuilder::new(metrics).gauge("recursive_repartition_fanout", partition); let input_batches = MetricBuilder::new(metrics).counter("input_batches", partition); @@ -1372,6 +1428,16 @@ impl BuildProbeJoinMetrics { build_input_batches, build_input_rows, build_mem_used, + build_spill_count, + build_spilled_bytes, + build_spilled_rows, + probe_spill_count, + probe_spilled_bytes, + probe_spilled_rows, + recursive_repartition_events, + recursive_partitions_created, + recursive_partition_depth, + recursive_repartition_fanout, join_time, input_batches, input_rows, @@ -1663,7 +1729,7 @@ pub fn update_hash( hashes_buffer: &mut Vec, deleted_offset: usize, fifo_hashmap: bool, -) -> Result<()> { +) -> Result> { // evaluate the keys let keys_values = on .iter() @@ -1688,7 +1754,7 @@ pub fn update_hash( hash_map.update_from_iter(Box::new(hash_values_iter), deleted_offset); } - Ok(()) + Ok(keys_values) } pub(super) fn equal_rows_arr( diff --git a/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs new file mode 100644 index 000000000000..bba0f6f95625 --- /dev/null +++ b/datafusion/physical-plan/src/spill/in_memory_spill_buffer.rs @@ -0,0 +1,46 @@ +use crate::memory::MemoryStream; +use crate::spill::spill_manager::GetSlicedSize; +use arrow::array::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::SendableRecordBatchStream; +use std::sync::Arc; + +#[derive(Debug)] +pub struct InMemorySpillBuffer { + batches: Vec, + total_bytes: usize, +} + +impl InMemorySpillBuffer { + pub fn from_batch(batch: &RecordBatch) -> Result { + Ok(Self { + batches: vec![batch.clone()], + total_bytes: batch.get_sliced_size()?, + }) + } + + pub fn from_batches(batches: &[RecordBatch]) -> Result { + let mut total_bytes = 0; + let mut owned = Vec::with_capacity(batches.len()); + for b in batches { + total_bytes += b.get_sliced_size()?; + owned.push(b.clone()); + } + Ok(Self { + batches: owned, + total_bytes, + }) + } + + pub fn as_stream( + self: Arc, + schema: Arc, + ) -> Result { + let stream = MemoryStream::try_new(self.batches.clone(), schema, None)?; + Ok(Box::pin(stream)) + } + + pub fn size(&self) -> usize { + self.total_bytes + } +} diff --git a/datafusion/physical-plan/src/spill/mod.rs b/datafusion/physical-plan/src/spill/mod.rs index fab62bff840f..782100e6d4cf 100644 --- a/datafusion/physical-plan/src/spill/mod.rs +++ b/datafusion/physical-plan/src/spill/mod.rs @@ -17,6 +17,7 @@ //! 
Defines the spilling functions +pub(crate) mod in_memory_spill_buffer; pub(crate) mod in_progress_spill_file; pub(crate) mod spill_manager; @@ -376,16 +377,17 @@ mod tests { use crate::common::collect; use crate::metrics::ExecutionPlanMetricsSet; use crate::metrics::SpillMetrics; - use crate::spill::spill_manager::SpillManager; + use crate::spill::spill_manager::{SpillLocation, SpillManager}; use crate::test::build_table_i32; use arrow::array::{ArrayRef, Float64Array, Int32Array, ListArray, StringArray}; use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::Result; - use datafusion_execution::runtime_env::RuntimeEnv; + use datafusion_execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; use futures::StreamExt as _; + use datafusion_execution::memory_pool::{FairSpillPool, MemoryPool}; use std::sync::Arc; #[tokio::test] @@ -426,6 +428,71 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_batch_spill_to_memory_and_disk_and_read() -> Result<()> { + let schema: SchemaRef = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from_iter_values(0..1000)), + Arc::new(Int32Array::from_iter_values(1000..2000)), + ], + )?; + + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from_iter_values(2000..4000)), + Arc::new(Int32Array::from_iter_values(4000..6000)), + ], + )?; + + let num_rows = batch1.num_rows() + batch2.num_rows(); + let batches = vec![batch1, batch2]; + + // --- create small memory pool (simulate memory pressure) --- + let memory_limit_bytes = 20 * 1024; // 20KB + let memory_pool: Arc = + Arc::new(FairSpillPool::new(memory_limit_bytes)); + + // Construct SpillManager + let env = RuntimeEnvBuilder::new() + .with_memory_pool(memory_pool) + .build_arc()?; + let metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0); + let spill_manager = SpillManager::new(env, metrics, Arc::clone(&schema)); + + let results = spill_manager.spill_batches_auto(&batches, "TestAutoSpill")?; + assert_eq!(results.len(), 2); + + let mem_count = results + .iter() + .filter(|r| matches!(r, SpillLocation::Memory(_))) + .count(); + let disk_count = results + .iter() + .filter(|r| matches!(r, SpillLocation::Disk(_))) + .count(); + assert!(mem_count >= 1); + assert!(disk_count >= 1); + + let spilled_rows = spill_manager.metrics.spilled_rows.value(); + assert_eq!(spilled_rows, num_rows); + + for spill in results { + let stream = spill_manager.load_spilled_batch(&spill)?; + let collected = collect(stream).await?; + assert!(!collected.is_empty()); + assert_eq!(collected[0].schema(), schema); + } + + Ok(()) + } + #[tokio::test] async fn test_batch_spill_and_read_dictionary_arrays() -> Result<()> { // See https://github.com/apache/datafusion/issues/4658 diff --git a/datafusion/physical-plan/src/spill/spill_manager.rs b/datafusion/physical-plan/src/spill/spill_manager.rs index ad23bd66a021..27176ff02527 100644 --- a/datafusion/physical-plan/src/spill/spill_manager.rs +++ b/datafusion/physical-plan/src/spill/spill_manager.rs @@ -21,14 +21,16 @@ use arrow::array::StringViewArray; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_execution::runtime_env::RuntimeEnv; +use std::slice; use std::sync::Arc; -use datafusion_common::{config::SpillCompression, Result}; +use 
datafusion_common::{config::SpillCompression, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::SendableRecordBatchStream; use super::{in_progress_spill_file::InProgressSpillFile, SpillReaderStream}; use crate::coop::cooperative; +use crate::spill::in_memory_spill_buffer::InMemorySpillBuffer; use crate::{common::spawn_buffered, metrics::SpillMetrics}; /// The `SpillManager` is responsible for the following tasks: @@ -168,6 +170,60 @@ impl SpillManager { Ok(file.map(|f| (f, max_record_batch_size))) } + /// Automatically decides whether to spill the given RecordBatch to memory or disk, + /// depending on available memory pool capacity. + pub(crate) fn spill_batch_auto( + &self, + batch: &RecordBatch, + request_msg: &str, + ) -> Result { + // let Some(file) = self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? else { + // return Err(DataFusionError::Execution( + // "failed to spill batch to disk".into(), + // )); + // }; + // Ok(SpillLocation::Disk(Arc::new(file))) + // // + let size = batch.get_sliced_size()?; + + // Check current memory usage and total limit from the runtime memory pool + let used = self.env.memory_pool.reserved(); + let limit = match self.env.memory_pool.memory_limit() { + datafusion_execution::memory_pool::MemoryLimit::Finite(l) => l, + _ => usize::MAX, + }; + + // If there's enough memory (with a safety margin), keep it in memory + if used + size * 3 / 2 <= limit { + let buf = Arc::new(InMemorySpillBuffer::from_batch(batch)?); + self.metrics.spilled_bytes.add(size); + self.metrics.spilled_rows.add(batch.num_rows()); + Ok(SpillLocation::Memory(buf)) + } else { + // Otherwise spill to disk using the existing SpillManager logic + let Some(file) = + self.spill_record_batch_and_finish(slice::from_ref(batch), request_msg)? + else { + return Err(DataFusionError::Execution( + "failed to spill batch to disk".into(), + )); + }; + Ok(SpillLocation::Disk(Arc::new(file))) + } + } + + pub fn spill_batches_auto( + &self, + batches: &[RecordBatch], + request_msg: &str, + ) -> Result> { + let mut result = Vec::with_capacity(batches.len()); + for batch in batches { + result.push(self.spill_batch_auto(batch, request_msg)?); + } + Ok(result) + } + /// Reads a spill file as a stream. The file must be created by the current `SpillManager`. /// This method will generate output in FIFO order: the batch appended first /// will be read first. @@ -182,6 +238,36 @@ impl SpillManager { Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) } + + pub fn read_spill_as_stream_ref( + &self, + spill_file_path: &RefCountedTempFile, + ) -> Result { + let stream = Box::pin(cooperative(SpillReaderStream::new( + Arc::clone(&self.schema), + spill_file_path.clone_refcounted()?, + ))); + + Ok(spawn_buffered(stream, self.batch_read_buffer_capacity)) + } + + pub fn load_spilled_batch( + &self, + spill: &SpillLocation, + ) -> Result { + match spill { + SpillLocation::Memory(buf) => { + Ok(Arc::clone(&buf).as_stream(Arc::clone(&self.schema))?) 
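+                // For intuition, `spill_batch_auto` above keeps a batch in memory only while
+                // `used + size * 3 / 2 <= limit` (illustrative numbers): with a 20 KiB pool
+                // and 8 KiB already reserved, a 6 KiB batch (8 + 9 = 17 KiB) is kept in
+                // memory and served from this `Memory` arm, while a 10 KiB batch
+                // (8 + 15 = 23 KiB) is written to disk and comes back through the `Disk`
+                // arm below.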
+ } + SpillLocation::Disk(file) => self.read_spill_as_stream_ref(file), + } + } +} + +#[derive(Debug, Clone)] +pub enum SpillLocation { + Memory(Arc), + Disk(Arc), } pub(crate) trait GetSlicedSize { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 70d6caf7642b..07b71ff159ca 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -732,6 +732,7 @@ message PhysicalPlanNode { GenerateSeriesNode generate_series = 33; SortMergeJoinExecNode sort_merge_join = 34; MemoryScanExecNode memory_scan = 35; + GraceHashJoinExecNode grace_hash_join = 36; } } @@ -1074,6 +1075,16 @@ message HashJoinExecNode { repeated uint32 projection = 9; } +message GraceHashJoinExecNode { + PhysicalPlanNode left = 1; + PhysicalPlanNode right = 2; + repeated JoinOn on = 3; + datafusion_common.JoinType join_type = 4; + datafusion_common.NullEquality null_equality = 5; + JoinFilter filter = 6; + repeated uint32 projection = 7; +} + enum StreamPartitionMode { SINGLE_PARTITION = 0; PARTITIONED_EXEC = 1; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 83f662e61112..4d7f8241c710 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -7717,6 +7717,208 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for GraceHashJoinExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.null_equality != 0 { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + if !self.projection.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.GraceHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.null_equality != 0 { + let v = super::datafusion_common::NullEquality::try_from(self.null_equality) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?; + struct_ser.serialize_field("nullEquality", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for GraceHashJoinExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "left", + "right", + "on", + "join_type", + "joinType", + "null_equality", + "nullEquality", + "filter", + "projection", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + On, + JoinType, + NullEquality, + 
Filter, + Projection, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality), + "filter" => Ok(GeneratedField::Filter), + "projection" => Ok(GeneratedField::Projection), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GraceHashJoinExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.GraceHashJoinExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut null_equality__ = None; + let mut filter__ = None; + let mut projection__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEquality => { + if null_equality__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEquality")); + } + null_equality__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = + Some(map_.next_value::>>()? 
+ .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(GraceHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + null_equality: null_equality__.unwrap_or_default(), + filter: filter__, + projection: projection__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.GraceHashJoinExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for GroupingSetNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17022,6 +17224,9 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::MemoryScan(v) => { struct_ser.serialize_field("memoryScan", v)?; } + physical_plan_node::PhysicalPlanType::GraceHashJoin(v) => { + struct_ser.serialize_field("graceHashJoin", v)?; + } } } struct_ser.end() @@ -17087,6 +17292,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "sortMergeJoin", "memory_scan", "memoryScan", + "grace_hash_join", + "graceHashJoin", ]; #[allow(clippy::enum_variant_names)] @@ -17125,6 +17332,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { GenerateSeries, SortMergeJoin, MemoryScan, + GraceHashJoin, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17180,6 +17388,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries), "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin), "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan), + "graceHashJoin" | "grace_hash_join" => Ok(GeneratedField::GraceHashJoin), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17438,6 +17647,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("memoryScan")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan) +; + } + GeneratedField::GraceHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("graceHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GraceHashJoin) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index cc19add6fbe9..d5520c6843b6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1055,7 +1055,7 @@ pub mod table_reference { pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36" )] pub physical_plan_type: ::core::option::Option, } @@ -1133,6 +1133,8 @@ pub mod physical_plan_node { SortMergeJoin(::prost::alloc::boxed::Box), #[prost(message, tag = "35")] MemoryScan(super::MemoryScanExecNode), + #[prost(message, tag = "36")] + GraceHashJoin(::prost::alloc::boxed::Box), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1636,6 +1638,23 @@ pub struct HashJoinExecNode { pub projection: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct GraceHashJoinExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub left: 
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index cc19add6fbe9..d5520c6843b6 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1055,7 +1055,7 @@ pub mod table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36"
     )]
     pub physical_plan_type: ::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1133,6 +1133,8 @@ pub mod physical_plan_node {
         SortMergeJoin(::prost::alloc::boxed::Box<super::SortMergeJoinExecNode>),
         #[prost(message, tag = "35")]
         MemoryScan(super::MemoryScanExecNode),
+        #[prost(message, tag = "36")]
+        GraceHashJoin(::prost::alloc::boxed::Box<super::GraceHashJoinExecNode>),
     }
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -1636,6 +1638,23 @@ pub struct HashJoinExecNode {
     pub projection: ::prost::alloc::vec::Vec<u32>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GraceHashJoinExecNode {
+    #[prost(message, optional, boxed, tag = "1")]
+    pub left: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub right: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, repeated, tag = "3")]
+    pub on: ::prost::alloc::vec::Vec<JoinOn>,
+    #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")]
+    pub join_type: i32,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "5")]
+    pub null_equality: i32,
+    #[prost(message, optional, tag = "6")]
+    pub filter: ::core::option::Option<JoinFilter>,
+    #[prost(uint32, repeated, tag = "7")]
+    pub projection: ::prost::alloc::vec::Vec<u32>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct SymmetricHashJoinExecNode {
     #[prost(message, optional, boxed, tag = "1")]
     pub left: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
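Being a plain prost message, the new node also round-trips through the protobuf wire format. A minimal sketch, not part of the patch; the path and the enum value used are assumptions:

    use datafusion_proto::protobuf::GraceHashJoinExecNode;
    use prost::Message;

    fn wire_roundtrip() -> Result<(), prost::DecodeError> {
        let node = GraceHashJoinExecNode {
            join_type: 1,            // raw i32 value of the JoinType enumeration (assumed here)
            projection: vec![0, 2],  // repeated uint32 column indices, tag 7
            ..Default::default()
        };
        let bytes = node.encode_to_vec();
        let decoded = GraceHashJoinExecNode::decode(bytes.as_slice())?;
        assert_eq!(node, decoded);
        Ok(())
    }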
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index e577de5b1d0e..adcc2d2dffff 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -77,8 +77,8 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
 use datafusion::physical_plan::joins::{
-    CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode,
-    SymmetricHashJoinExec,
+    CrossJoinExec, GraceHashJoinExec, NestedLoopJoinExec, SortMergeJoinExec,
+    StreamJoinPartitionMode, SymmetricHashJoinExec,
 };
 use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
@@ -214,6 +214,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                 runtime,
                 extension_codec,
             ),
+            PhysicalPlanType::GraceHashJoin(grace_hash_join) => self
+                .try_into_grace_hash_join_physical_plan(
+                    grace_hash_join,
+                    ctx,
+                    runtime,
+                    extension_codec,
+                ),
             PhysicalPlanType::SymmetricHashJoin(sym_join) => self
                 .try_into_symmetric_hash_join_physical_plan(
                     sym_join,
@@ -365,6 +372,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
             );
         }
 
+        if let Some(exec) = plan.downcast_ref::<GraceHashJoinExec>() {
+            return protobuf::PhysicalPlanNode::try_from_grace_hash_join_exec(
+                exec,
+                extension_codec,
+            );
+        }
+
         if let Some(exec) = plan.downcast_ref::<SymmetricHashJoinExec>() {
             return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec(
                 exec,
@@ -1266,6 +1280,116 @@ impl protobuf::PhysicalPlanNode {
         )?))
     }
 
+    fn try_into_grace_hash_join_physical_plan(
+        &self,
+        grace_join: &protobuf::GraceHashJoinExecNode,
+        ctx: &SessionContext,
+        runtime: &RuntimeEnv,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let left = into_physical_plan(&grace_join.left, ctx, runtime, extension_codec)?;
+        let right = into_physical_plan(&grace_join.right, ctx, runtime, extension_codec)?;
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        let on = grace_join
+            .on
+            .iter()
+            .map(|col| {
+                let left_expr = parse_physical_expr(
+                    col.left.as_ref().ok_or_else(|| {
+                        proto_error("GraceHashJoinExecNode missing left expr")
+                    })?,
+                    ctx,
+                    left_schema.as_ref(),
+                    extension_codec,
+                )?;
+                let right_expr = parse_physical_expr(
+                    col.right.as_ref().ok_or_else(|| {
+                        proto_error("GraceHashJoinExecNode missing right expr")
+                    })?,
+                    ctx,
+                    right_schema.as_ref(),
+                    extension_codec,
+                )?;
+                Ok((left_expr, right_expr))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let join_type =
+            protobuf::JoinType::try_from(grace_join.join_type).map_err(|_| {
+                proto_error(format!(
+                    "Received a GraceHashJoinExecNode with unknown JoinType {}",
+                    grace_join.join_type
+                ))
+            })?;
+        let null_equality = protobuf::NullEquality::try_from(grace_join.null_equality)
+            .map_err(|_| {
+                proto_error(format!(
+                    "Received a GraceHashJoinExecNode with unknown NullEquality {}",
+                    grace_join.null_equality
+                ))
+            })?;
+        let filter = grace_join
+            .filter
+            .as_ref()
+            .map(|f| {
+                let schema = f
+                    .schema
+                    .as_ref()
+                    .ok_or_else(|| proto_error("Missing JoinFilter schema"))?
+                    .try_into()?;
+
+                let expression = parse_physical_expr(
+                    f.expression.as_ref().ok_or_else(|| {
+                        proto_error("Unexpected empty filter expression")
+                    })?,
+                    ctx,
+                    &schema,
+                    extension_codec,
+                )?;
+                let column_indices = f
+                    .column_indices
+                    .iter()
+                    .map(|i| {
+                        let side = protobuf::JoinSide::try_from(i.side).map_err(|_| {
+                            proto_error(format!(
+                                "Received a GraceHashJoinExecNode message with JoinSide in Filter {}",
+                                i.side
+                            ))
+                        })?;
+
+                        Ok(ColumnIndex {
+                            index: i.index as usize,
+                            side: side.into(),
+                        })
+                    })
+                    .collect::<Result<Vec<ColumnIndex>>>()?;
+
+                Ok(JoinFilter::new(expression, column_indices, Arc::new(schema)))
+            })
+            .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+        let projection = if !grace_join.projection.is_empty() {
+            Some(
+                grace_join
+                    .projection
+                    .iter()
+                    .map(|i| *i as usize)
+                    .collect::<Vec<_>>(),
+            )
+        } else {
+            None
+        };
+
+        Ok(Arc::new(GraceHashJoinExec::try_new(
+            left,
+            right,
+            on,
+            filter,
+            &join_type.into(),
+            projection,
+            null_equality.into(),
+        )?))
+    }
+
     fn try_into_symmetric_hash_join_physical_plan(
         &self,
         sym_join: &protobuf::SymmetricHashJoinExecNode,
@@ -2202,6 +2326,7 @@
             PartitionMode::CollectLeft => protobuf::PartitionMode::CollectLeft,
             PartitionMode::Partitioned => protobuf::PartitionMode::Partitioned,
             PartitionMode::Auto => protobuf::PartitionMode::Auto,
+            PartitionMode::PartitionedSpillable => protobuf::PartitionMode::Partitioned,
         };
 
         Ok(protobuf::PhysicalPlanNode {
@@ -2222,6 +2347,75 @@
         })
     }
 
+    fn try_from_grace_hash_join_exec(
+        exec: &GraceHashJoinExec,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Self> {
+        let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.left().to_owned(),
+            extension_codec,
+        )?;
+        let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.right().to_owned(),
+            extension_codec,
+        )?;
+        let on: Vec<protobuf::JoinOn> = exec
+            .on()
+            .iter()
+            .map(|tuple| {
+                let l = serialize_physical_expr(&tuple.0, extension_codec)?;
+                let r = serialize_physical_expr(&tuple.1, extension_codec)?;
+                Ok::<_, DataFusionError>(protobuf::JoinOn {
+                    left: Some(l),
+                    right: Some(r),
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = exec.null_equality().into();
+        let filter = exec
+            .filter()
+            .as_ref()
+            .map(|f| {
+                let expression =
+                    serialize_physical_expr(f.expression(), extension_codec)?;
+                let column_indices = f
+                    .column_indices()
+                    .iter()
+                    .map(|i| {
+                        let side: protobuf::JoinSide = i.side.to_owned().into();
+                        protobuf::ColumnIndex {
+                            index: i.index as u32,
+                            side: side.into(),
+                        }
+                    })
+                    .collect();
+                let schema = f.schema().as_ref().try_into()?;
+                Ok(protobuf::JoinFilter {
+                    expression: Some(expression),
+                    column_indices,
+                    schema: Some(schema),
+                })
+            })
+            .map_or(Ok(None), |v: Result<protobuf::JoinFilter>| v.map(Some))?;
+
+        Ok(protobuf::PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::GraceHashJoin(Box::new(
+                protobuf::GraceHashJoinExecNode {
+                    left: Some(Box::new(left)),
+                    right: Some(Box::new(right)),
+                    on,
+                    join_type: join_type.into(),
+                    null_equality: null_equality.into(),
+                    filter,
+                    projection: exec.projection.as_ref().map_or_else(Vec::new, |v| {
+                        v.iter().map(|x| *x as u32).collect::<Vec<_>>()
+                    }),
+                },
+            ))),
+        })
+    }
+
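Taken together, the two methods above let a GraceHashJoinExec survive the usual plan-to-bytes round trip. A hedged sketch of how that could be exercised, not part of the patch; it assumes the datafusion_proto::bytes helpers and the GraceHashJoinExec re-export used in the imports above:

    use std::sync::Arc;
    use datafusion::error::Result;
    use datafusion::physical_plan::joins::GraceHashJoinExec;
    use datafusion::physical_plan::ExecutionPlan;
    use datafusion::prelude::SessionContext;
    use datafusion_proto::bytes::{physical_plan_from_bytes, physical_plan_to_bytes};

    // `plan` is assumed to have a GraceHashJoinExec at its root.
    fn roundtrip_via_bytes(plan: Arc<dyn ExecutionPlan>, ctx: &SessionContext) -> Result<()> {
        let bytes = physical_plan_to_bytes(Arc::clone(&plan))?; // serializes via try_from_grace_hash_join_exec
        let decoded = physical_plan_from_bytes(&bytes, ctx)?;   // rebuilds via try_into_grace_hash_join_physical_plan
        assert!(decoded.as_any().is::<GraceHashJoinExec>());
        Ok(())
    }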
     fn try_from_symmetric_hash_join_exec(
         exec: &SymmetricHashJoinExec,
         extension_codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index a5357a132eef..c4fdaa78dba2 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,8 +75,8 @@ use datafusion::physical_plan::expressions::{
 };
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::{
-    HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
-    StreamJoinPartitionMode, SymmetricHashJoinExec,
+    GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
+    SortMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec,
 };
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::placeholder_row::PlaceholderRowExec;
@@ -284,6 +284,41 @@ fn roundtrip_hash_join() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn roundtrip_grace_hash_join() -> Result<()> {
+    let field_a = Field::new("col", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_a]);
+    let on = vec![(
+        Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+        Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
+    )];
+
+    let schema_left = Arc::new(schema_left);
+    let schema_right = Arc::new(schema_right);
+    for join_type in &[
+        JoinType::Inner,
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftAnti,
+        JoinType::RightAnti,
+        JoinType::LeftSemi,
+        JoinType::RightSemi,
+    ] {
+        roundtrip_test(Arc::new(GraceHashJoinExec::try_new(
+            Arc::new(EmptyExec::new(schema_left.clone())),
+            Arc::new(EmptyExec::new(schema_right.clone())),
+            on.clone(),
+            None,
+            join_type,
+            None,
+            NullEquality::NullEqualsNothing,
+        )?))?;
+    }
+    Ok(())
+}
+
 #[test]
 fn roundtrip_nested_loop_join() -> Result<()> {
     let field_a = Field::new("col", DataType::Int64, false);