From 8e96e002781f2727a9efea58a6db66cf55442bf2 Mon Sep 17 00:00:00 2001
From: Denys Tsomenko
Date: Fri, 14 Nov 2025 19:49:21 +0200
Subject: [PATCH 1/2] Enable grace hash join under config

---
 datafusion/common/src/config.rs               |   5 +
 datafusion/execution/src/config.rs            |  11 +
 .../physical-optimizer/src/join_selection.rs  | 100 +++++---
 .../src/joins/grace_hash_join/exec.rs         | 133 +++--------
 .../src/joins/grace_hash_join/stream.rs       | 151 +++---------
 datafusion/physical-plan/src/joins/mod.rs     |   1 +
 datafusion/proto/proto/datafusion.proto       |  11 +
 datafusion/proto/src/generated/pbjson.rs      | 216 ++++++++++++++++++
 datafusion/proto/src/generated/prost.rs       |  21 +-
 datafusion/proto/src/physical_plan/mod.rs     | 195 +++++++++++++++-
 .../tests/cases/roundtrip_physical_plan.rs    |  35 +++
 11 files changed, 620 insertions(+), 259 deletions(-)

diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 6abb2f5c6d3ca..876d0e9b57745 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -761,6 +761,11 @@ config_namespace! {
         /// using the provided `target_partitions` level
         pub repartition_joins: bool, default = true
 
+        /// When set to true, use the grace hash join operator instead of the regular hash join.
+        /// The grace hash join operator repartitions both inputs to disk before performing the join,
+        /// trading additional IO for predictable memory usage on very large joins.
+        pub enable_grace_hash_join: bool, default = false
+
         /// Should DataFusion allow symmetric hash joins for unbounded data sources even when
         /// its inputs do not have any ordering or filtering If the flag is not enabled,
         /// the SymmetricHashJoin operator will be unable to prune its internal buffers,
diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs
index 491b1aca69ea1..ea7a54eb8b87e 100644
--- a/datafusion/execution/src/config.rs
+++ b/datafusion/execution/src/config.rs
@@ -235,6 +235,11 @@ impl SessionConfig {
         self.options.optimizer.repartition_joins
     }
 
+    /// Should spillable hash joins be executed via the grace hash join operator?
+    pub fn enable_grace_hash_join(&self) -> bool {
+        self.options.optimizer.enable_grace_hash_join
+    }
+
     /// Are aggregates repartitioned during execution?
     pub fn repartition_aggregations(&self) -> bool {
         self.options.optimizer.repartition_aggregations
@@ -298,6 +303,12 @@ impl SessionConfig {
         self
     }
 
+    /// Enables or disables the grace hash join operator for spillable hash joins
+    pub fn with_enable_grace_hash_join(mut self, enabled: bool) -> Self {
+        self.options_mut().optimizer.enable_grace_hash_join = enabled;
+        self
+    }
+
     /// Enables or disables the use of repartitioning for aggregations to improve parallelism
     pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self {
         self.options_mut().optimizer.repartition_aggregations = enabled;
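Taken together, the config option and the two SessionConfig hooks above let users opt in per session. A minimal usage sketch, assuming the standard `datafusion` prelude paths and a tokio runtime (neither is part of this patch):

use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // Opt in to grace hash join for this session (the option defaults to false).
    let config = SessionConfig::new().with_enable_grace_hash_join(true);
    let ctx = SessionContext::new_with_config(config);

    // The same switch is reachable through SQL:
    ctx.sql("SET datafusion.optimizer.enable_grace_hash_join = true")
        .await?;
    Ok(())
}
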
diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs
index c2cfca681f667..54b65450374dc 100644
--- a/datafusion/physical-optimizer/src/join_selection.rs
+++ b/datafusion/physical-optimizer/src/join_selection.rs
@@ -27,15 +27,15 @@ use crate::PhysicalOptimizerRule;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::error::Result;
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion_common::{internal_err, JoinSide, JoinType};
+use datafusion_common::{internal_err, DataFusionError, JoinSide, JoinType};
 use datafusion_expr_common::sort_properties::SortProperties;
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::LexOrdering;
 use datafusion_physical_plan::execution_plan::EmissionType;
 use datafusion_physical_plan::joins::utils::ColumnIndex;
 use datafusion_physical_plan::joins::{
-    CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
-    StreamJoinPartitionMode, SymmetricHashJoinExec,
+    CrossJoinExec, GraceHashJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
+    StreamJoinPartitionMode, SymmetricHashJoinExec,
 };
 use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
 use std::sync::Arc;
@@ -134,12 +134,14 @@ impl PhysicalOptimizerRule for JoinSelection {
         let config = &config.optimizer;
         let collect_threshold_byte_size = config.hash_join_single_partition_threshold;
         let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows;
+        let enable_grace_hash_join = config.enable_grace_hash_join;
         new_plan
             .transform_up(|plan| {
                 statistical_join_selection_subrule(
                     plan,
                     collect_threshold_byte_size,
                     collect_threshold_num_rows,
+                    enable_grace_hash_join,
                 )
             })
             .data()
@@ -187,33 +189,23 @@ pub(crate) fn try_collect_left(
             if hash_join.join_type().supports_swap()
                 && should_swap_join_order(&**left, &**right)?
             {
-                Ok(Some(hash_join.swap_inputs(PartitionMode::CollectLeft)?))
+                match hash_join.swap_inputs(PartitionMode::CollectLeft) {
+                    Ok(plan) => Ok(Some(plan)),
+                    Err(err) if is_missing_join_columns(&err) => Ok(None),
+                    Err(err) => Err(err),
+                }
             } else {
-                Ok(Some(Arc::new(HashJoinExec::try_new(
-                    Arc::clone(left),
-                    Arc::clone(right),
-                    hash_join.on().to_vec(),
-                    hash_join.filter().cloned(),
-                    hash_join.join_type(),
-                    hash_join.projection.clone(),
-                    PartitionMode::CollectLeft,
-                    hash_join.null_equality(),
-                )?)))
+                build_collect_left_exec(hash_join, left, right)
             }
         }
-        (true, false) => Ok(Some(Arc::new(HashJoinExec::try_new(
-            Arc::clone(left),
-            Arc::clone(right),
-            hash_join.on().to_vec(),
-            hash_join.filter().cloned(),
-            hash_join.join_type(),
-            hash_join.projection.clone(),
-            PartitionMode::CollectLeft,
-            hash_join.null_equality(),
-        )?))),
+        (true, false) => build_collect_left_exec(hash_join, left, right),
         (false, true) => {
             if hash_join.join_type().supports_swap() {
-                hash_join.swap_inputs(PartitionMode::CollectLeft).map(Some)
+                match hash_join.swap_inputs(PartitionMode::CollectLeft) {
+                    Ok(plan) => Ok(Some(plan)),
+                    Err(err) if is_missing_join_columns(&err) => Ok(None),
+                    Err(err) => Err(err),
+                }
             } else {
                 Ok(None)
             }
@@ -222,6 +214,36 @@ pub(crate) fn try_collect_left(
     }
 }
 
+
+fn is_missing_join_columns(err: &DataFusionError) -> bool {
+    matches!(
+        err,
+        DataFusionError::Plan(msg)
+            if msg.contains("The left or right side of the join does not have all columns")
+    )
+}
+
+fn build_collect_left_exec(
+    hash_join: &HashJoinExec,
+    left: &Arc<dyn ExecutionPlan>,
+    right: &Arc<dyn ExecutionPlan>,
+) -> Result<Option<Arc<dyn ExecutionPlan>>> {
+    match HashJoinExec::try_new(
+        Arc::clone(left),
+        Arc::clone(right),
+        hash_join.on().to_vec(),
+        hash_join.filter().cloned(),
+        hash_join.join_type(),
+        hash_join.projection.clone(),
+        PartitionMode::CollectLeft,
+        hash_join.null_equality(),
+    ) {
+        Ok(exec) => Ok(Some(Arc::new(exec))),
+        Err(err) if is_missing_join_columns(&err) => Ok(None),
+        Err(err) => Err(err),
+    }
+}
+
 /// Creates a partitioned hash join execution plan, swapping inputs if beneficial.
 ///
 /// Checks if the join order should be swapped based on the join type and input statistics.
@@ -229,11 +251,30 @@ pub(crate) fn try_collect_left(
 /// creates a standard partitioned hash join.
 pub(crate) fn partitioned_hash_join(
     hash_join: &HashJoinExec,
+    enable_grace: bool,
 ) -> Result<Arc<dyn ExecutionPlan>> {
     let left = hash_join.left();
     let right = hash_join.right();
-    if hash_join.join_type().supports_swap() && should_swap_join_order(&**left, &**right)?
-    {
+    let should_swap = hash_join.join_type().supports_swap()
+        && should_swap_join_order(&**left, &**right)?;
+    if enable_grace {
+        let grace = Arc::new(GraceHashJoinExec::try_new(
+            Arc::clone(left),
+            Arc::clone(right),
+            hash_join.on().to_vec(),
+            hash_join.filter().cloned(),
+            hash_join.join_type(),
+            hash_join.projection.clone(),
+            hash_join.null_equality(),
+        )?);
+        return if should_swap {
+            grace.swap_inputs(PartitionMode::Partitioned)
+        } else {
+            Ok(grace)
+        };
+    }
+
+    if should_swap {
         hash_join.swap_inputs(PartitionMode::Partitioned)
     } else {
         Ok(Arc::new(HashJoinExec::try_new(
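For context, `should_swap_join_order` (used above, unchanged by this patch) prefers building the hash table on the smaller input. A deliberately simplified stand-in for that heuristic — real DataFusion statistics carry byte-size and row-count estimates, so this reduced form is only an illustration:

/// Simplified stand-in for DataFusion's swap heuristic: build on the smaller side.
fn should_swap_join_order(left_bytes: usize, right_bytes: usize) -> bool {
    // Swapping makes the (smaller) right side the new build side.
    left_bytes > right_bytes
}

fn main() {
    // A 1 GiB build side against a 64 GiB probe side: keep the order.
    assert!(!should_swap_join_order(1 << 30, 64 << 30));
    // The reverse ordering should be swapped.
    assert!(should_swap_join_order(64 << 30, 1 << 30));
}
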
@@ -255,6 +296,7 @@ fn statistical_join_selection_subrule(
     plan: Arc<dyn ExecutionPlan>,
     collect_threshold_byte_size: usize,
     collect_threshold_num_rows: usize,
+    enable_grace_hash_join: bool,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
     let transformed =
         if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
@@ -266,12 +308,12 @@ fn statistical_join_selection_subrule(
                 PartitionMode::Auto => try_collect_left(
                     hash_join,
                     false,
                     collect_threshold_byte_size,
                     collect_threshold_num_rows,
                 )?
                 .map_or_else(
-                    || partitioned_hash_join(hash_join).map(Some),
+                    || partitioned_hash_join(hash_join, enable_grace_hash_join).map(Some),
                     |v| Ok(Some(v)),
                 )?,
                 PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
                     .map_or_else(
-                        || partitioned_hash_join(hash_join).map(Some),
+                        || partitioned_hash_join(hash_join, enable_grace_hash_join).map(Some),
                         |v| Ok(Some(v)),
                     )?,
                 PartitionMode::Partitioned => {
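With the selection rule in place, the optimizer now emits the new operator where it previously emitted a partitioned HashJoinExec. Constructing one by hand mirrors the roundtrip test at the end of this patch; a minimal sketch, assuming the crate re-export paths below and `EmptyExec` as placeholder inputs (both assumptions for illustration only):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion::common::{JoinType, NullEquality};
use datafusion::error::Result;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::expressions::Column;
use datafusion::physical_plan::joins::GraceHashJoinExec;

fn build_grace_join() -> Result<Arc<GraceHashJoinExec>> {
    let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Int64, false)]));
    // Equi-join keys: both sides join on "col" (column index 0).
    let on = vec![(
        Arc::new(Column::new("col", 0)) as _,
        Arc::new(Column::new("col", 0)) as _,
    )];
    Ok(Arc::new(GraceHashJoinExec::try_new(
        Arc::new(EmptyExec::new(Arc::clone(&schema))), // left (build) input
        Arc::new(EmptyExec::new(schema)),              // right (probe) input
        on,
        None,             // no additional join filter
        &JoinType::Inner,
        None,             // no output projection
        NullEquality::NullEqualsNothing,
    )?))
}
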
diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs
index 2c9482f93f892..47530688c0b98 100644
--- a/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs
+++ b/datafusion/physical-plan/src/joins/grace_hash_join/exec.rs
@@ -20,30 +20,27 @@ use crate::filter_pushdown::{
     ChildPushdownResult, FilterDescription, FilterPushdownPhase,
     FilterPushdownPropagation,
 };
-use crate::joins::utils::{
-    reorder_output_after_swap, swap_join_projection, OnceFut,
-};
+use crate::joins::utils::{reorder_output_after_swap, swap_join_projection, OnceFut};
 use crate::joins::{JoinOn, JoinOnRef, PartitionMode};
 use crate::projection::{
     try_embed_projection, try_pushdown_through_join, EmbeddedProjection, JoinData,
     ProjectionExec,
 };
-use crate::spill::get_record_batch_memory_size;
 use crate::{
     common::can_project,
     joins::utils::{
         build_join_schema, check_join_is_valid, estimate_join_statistics,
-        symmetric_join_output_partitioning,
-        BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
+        symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex,
+        JoinFilter,
     },
     metrics::{ExecutionPlanMetricsSet, MetricsSet},
-    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
-    PlanProperties, SendableRecordBatchStream, Statistics,
+    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, PlanProperties,
+    SendableRecordBatchStream, Statistics,
 };
 use crate::{ExecutionPlanProperties, SpillManager};
 use std::fmt;
 use std::fmt::Formatter;
-use std::sync::{Arc, OnceLock};
+use std::sync::Arc;
 use std::{any::Any, vec};
 
 use arrow::array::UInt32Array;
@@ -52,21 +49,16 @@ use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::{
-    internal_err, plan_err, project_schema, JoinSide, JoinType,
-    NullEquality, Result,
+    internal_err, plan_err, project_schema, JoinSide, JoinType, NullEquality, Result,
 };
 use datafusion_execution::TaskContext;
-use datafusion_functions_aggregate_common::min_max::{MaxAccumulator, MinAccumulator};
 use datafusion_physical_expr::equivalence::{
     join_equivalence_properties, ProjectionMapping,
 };
 use datafusion_physical_expr::expressions::{lit, DynamicFilterPhysicalExpr};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
 
-use crate::joins::grace_hash_join::stream::{
-    GraceAccumulator, GraceHashJoinStream, SpillFut,
-};
-use crate::joins::hash_join::shared_bounds::SharedBoundsAccumulator;
+use crate::joins::grace_hash_join::stream::{GraceHashJoinStream, SpillFut};
 use crate::metrics::SpillMetrics;
 use crate::spill::spill_manager::SpillLocation;
 use ahash::RandomState;
@@ -104,20 +96,8 @@ pub struct GraceHashJoinExec {
     pub null_equality: NullEquality,
     /// Cache holding plan properties like equivalences, output partitioning etc.
     cache: PlanProperties,
-    /// Dynamic filter for pushing down to the probe side
-    /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result.
-    /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates.
-    dynamic_filter: Option<HashJoinExecDynamicFilter>,
-    accumulator: Arc<GraceAccumulator>,
-}
-
-#[derive(Clone)]
-struct HashJoinExecDynamicFilter {
-    /// Dynamic filter that we'll update with the results of the build side once that is done.
-    filter: Arc<DynamicFilterPhysicalExpr>,
-    /// Bounds accumulator to keep track of the min/max bounds on the join keys for each partition.
-    /// It is lazily initialized during execution to make sure we use the actual execution time partition counts.
-    bounds_accumulator: OnceLock<Arc<SharedBoundsAccumulator>>,
+    /// Indicates whether dynamic filter pushdown is enabled for this join.
+    dynamic_filter_enabled: bool,
 }
 
 impl fmt::Debug for GraceHashJoinExec {
@@ -135,7 +115,7 @@ impl fmt::Debug for GraceHashJoinExec {
             .field("column_indices", &self.column_indices)
             .field("null_equality", &self.null_equality)
             .field("cache", &self.cache)
-            // Explicitly exclude dynamic_filter to avoid runtime state differences in tests
+            // Intentionally omit dynamic_filter_enabled to keep debug output stable
             .finish()
     }
 }
@@ -186,9 +166,6 @@ impl GraceHashJoinExec {
             &on,
             projection.as_ref(),
         )?;
-        let partitions = left.output_partitioning().partition_count();
-        let accumulator = GraceAccumulator::new(partitions);
-
         let metrics = ExecutionPlanMetricsSet::new();
         // Initialize both dynamic filter and bounds accumulator to None
         // They will be set later if dynamic filtering is enabled
@@ -205,8 +182,7 @@ impl GraceHashJoinExec {
             column_indices,
             null_equality,
             cache,
-            dynamic_filter: None,
-            accumulator,
+            dynamic_filter_enabled: false,
         })
     }
@@ -570,9 +546,8 @@ impl ExecutionPlan for GraceHashJoinExec {
                 &self.on,
                 self.projection.as_ref(),
             )?,
-            // Keep the dynamic filter, bounds accumulator will be reset
-            dynamic_filter: self.dynamic_filter.clone(),
-            accumulator: Arc::clone(&self.accumulator),
+            // Preserve dynamic filter enablement; state will be refreshed as needed
+            dynamic_filter_enabled: self.dynamic_filter_enabled,
         }))
     }
@@ -590,9 +565,8 @@ impl ExecutionPlan for GraceHashJoinExec {
             column_indices: self.column_indices.clone(),
             null_equality: self.null_equality,
             cache: self.cache.clone(),
-            // Reset dynamic filter and bounds accumulator to initial state
-            dynamic_filter: None,
-            accumulator: Arc::clone(&self.accumulator),
+            // Reset dynamic filter state to initial configuration
+            dynamic_filter_enabled: false,
         }))
     }
@@ -611,7 +585,7 @@ impl ExecutionPlan for GraceHashJoinExec {
             );
         }
 
-        let enable_dynamic_filter_pushdown = self.dynamic_filter.is_some();
+        let enable_dynamic_filter_pushdown = self.dynamic_filter_enabled;
 
         let join_metrics = Arc::new(BuildProbeJoinMetrics::new(partition, &self.metrics));
@@ -655,7 +629,6 @@ impl ExecutionPlan for GraceHashJoinExec {
         let on = self.on.clone();
         let spill_left_clone = Arc::clone(&spill_left);
         let spill_right_clone = Arc::clone(&spill_right);
-        let accumulator_clone = Arc::clone(&self.accumulator);
         let join_metrics_clone = Arc::clone(&join_metrics);
         let spill_fut = OnceFut::new(async move {
             let (left_idx, right_idx) = partition_and_spill(
@@ -671,14 +644,16 @@ impl ExecutionPlan for GraceHashJoinExec {
                 partition,
             )
             .await?;
-            accumulator_clone
-                .report_partition(partition, left_idx.clone(), right_idx.clone())
-                .await;
             Ok(SpillFut::new(partition, left_idx, right_idx))
         });
 
+        let left_input_schema = self.left.schema();
+        let right_input_schema = self.right.schema();
+
         Ok(Box::pin(GraceHashJoinStream::new(
             self.schema(),
+            left_input_schema,
+            right_input_schema,
             spill_fut,
             spill_left,
             spill_right,
@@ -690,7 +665,6 @@ impl ExecutionPlan for GraceHashJoinExec {
             column_indices_after_projection,
             join_metrics,
             context,
-            Arc::clone(&self.accumulator),
         )))
     }
@@ -825,9 +799,7 @@ impl ExecutionPlan for GraceHashJoinExec {
         // Note that we don't check PushdDownPredicate::discrimnant because even if nothing said
         // "yes, I can fully evaluate this filter" things might still use it for statistics -> it's worth updating
         let predicate = Arc::clone(&filter.predicate);
-        if let Ok(dynamic_filter) =
-            Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
-        {
+        if Arc::downcast::<DynamicFilterPhysicalExpr>(predicate).is_ok() {
             // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
             let new_node = Arc::new(GraceHashJoinExec {
                 left: Arc::clone(&self.left),
@@ -842,11 +814,7 @@ impl ExecutionPlan for GraceHashJoinExec {
                 column_indices: self.column_indices.clone(),
                 null_equality: self.null_equality,
                 cache: self.cache.clone(),
-                dynamic_filter: Some(HashJoinExecDynamicFilter {
-                    filter: dynamic_filter,
-                    bounds_accumulator: OnceLock::new(),
-                }),
-                accumulator: Arc::clone(&self.accumulator),
+                dynamic_filter_enabled: true,
             });
             result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
         }
@@ -855,7 +823,6 @@ impl ExecutionPlan for GraceHashJoinExec {
     }
 }
 
-
 #[allow(clippy::too_many_arguments)]
 pub async fn partition_and_spill(
     random_state: RandomState,
@@ -974,8 +941,8 @@ async fn partition_and_spill_one_side(
     // Prepare indexes
     let mut result = Vec::with_capacity(partitions.len());
-    for (i, writer) in partitions.into_iter().enumerate() {
-        result.push(writer.finish(i)?);
+    for writer in partitions.into_iter() {
+        result.push(writer.finish()?);
     }
     // println!("spill_manager {:?}", spill_manager.metrics);
     Ok(result)
 }
 
@@ -984,8 +951,6 @@ async fn partition_and_spill_one_side(
 #[derive(Debug)]
 pub struct PartitionWriter {
     spill_manager: Arc<SpillManager>,
-    total_rows: usize,
-    total_bytes: usize,
     chunks: Vec<SpillLocation>,
 }
 
@@ -993,8 +958,6 @@ impl PartitionWriter {
     pub fn new(spill_manager: Arc<SpillManager>) -> Self {
         Self {
             spill_manager,
-            total_rows: 0,
-            total_bytes: 0,
             chunks: vec![],
         }
     }
@@ -1005,18 +968,13 @@ impl PartitionWriter {
         request_msg: &str,
     ) -> Result<()> {
         let loc = self.spill_manager.spill_batch_auto(batch, request_msg)?;
-        self.total_rows += batch.num_rows();
-        self.total_bytes += get_record_batch_memory_size(batch);
         self.chunks.push(loc);
         Ok(())
     }
 
-    pub fn finish(self, part_id: usize) -> Result<PartitionIndex> {
+    pub fn finish(self) -> Result<PartitionIndex> {
         Ok(PartitionIndex {
-            part_id,
             chunks: self.chunks,
-            total_rows: self.total_rows,
-            total_bytes: self.total_bytes,
         })
     }
 }
@@ -1031,15 +989,6 @@ impl PartitionWriter {
 /// Partition 3 -> [ spill_chunk_3_0.arrow, spill_chunk_3_1.arrow ]
 #[derive(Debug, Clone)]
 pub struct PartitionIndex {
-    /// Unique partition identifier (0..N-1)
-    pub part_id: usize,
-
-    /// Total number of rows in this partition
-    pub total_rows: usize,
-
-    /// Total size in bytes of all batches in this partition
-    pub total_bytes: usize,
-
     /// Collection of spill locations (each corresponds to one batch written
     /// by [`PartitionWriter::spill_batch_auto`])
     pub chunks: Vec<SpillLocation>,
 }
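The partitioning phase above is the heart of the grace strategy: every build and probe row is routed with the same hash function, so rows that can join always land in the same partition pair, and each pair can later be joined independently within bounded memory. A reduced sketch of that routing invariant — a plain function over a single key, not the `PartitionWriter` spill path used above:

use std::hash::BuildHasher;

use ahash::RandomState;

/// Route a join key to one of `num_partitions` buckets.
/// Both sides MUST share the same `RandomState`; otherwise matching keys
/// could land in different partitions and the join would silently drop rows.
fn route_to_partition(state: &RandomState, key: i64, num_partitions: usize) -> usize {
    (state.hash_one(key) as usize) % num_partitions
}

fn main() {
    let state = RandomState::with_seeds(1, 2, 3, 4); // fixed seeds for the demo
    let num_partitions = 8;

    // The invariant grace hash join relies on: equal keys, equal partition.
    let build_side = route_to_partition(&state, 42, num_partitions);
    let probe_side = route_to_partition(&state, 42, num_partitions);
    assert_eq!(build_side, probe_side);
}
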
@@ -1049,9 +998,7 @@ pub struct PartitionIndex {
 mod tests {
     use super::*;
     use crate::test::TestMemoryExec;
-    use crate::{
-        common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
-    };
+    use crate::{common, expressions::Column, repartition::RepartitionExec};
 
     use crate::joins::HashJoinExec;
     use arrow::array::{ArrayRef, Int32Array};
@@ -1093,28 +1040,8 @@ mod tests {
         Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None).unwrap())
     }
 
-    fn build_table(
-        a: (&str, &Vec<i32>),
-        b: (&str, &Vec<i32>),
-        c: (&str, &Vec<i32>),
-    ) -> Arc<dyn ExecutionPlan> {
-        let batch = build_table_i32(a, b, c);
-        let schema = batch.schema();
-        TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap()
-    }
-
     #[tokio::test]
     async fn simple_grace_hash_join() -> Result<()> {
-        // let left = build_table(
-        //     ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-        //     ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-        //     ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-        // );
-        // let right = build_table(
-        //     ("a2", &vec![1, 2]),
-        //     ("b2", &vec![1, 2]),
-        //     ("c2", &vec![14, 15]),
-        // );
         let left = build_large_table("a1", "b1", "c1", 2000000);
         let right = build_large_table("a2", "b2", "c2", 5000000);
         let on = vec![(
@@ -1168,7 +1095,7 @@ mod tests {
             batches.extend(v);
         }
         let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
-        println!("TOTAL ROWS = {}", total_rows);
+        assert_eq!(total_rows, 1_000_000);
         // print_batches(&*batches).unwrap();
 
         // Asserting that operator-level reservation attempting to overallocate
@@ -1231,7 +1158,7 @@ mod tests {
             );
         }
         let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
-        println!("TOTAL ROWS = {}", total_rows);
+        assert_eq!(total_rows, 1_000_000);
         // print_batches(&*batches).unwrap();
 
         Ok(())
diff --git a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs
index d028b0e8bcf87..da0de10a51889 100644
--- a/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/grace_hash_join/stream.rs
@@ -34,22 +34,17 @@ use crate::empty::EmptyExec;
 use crate::joins::grace_hash_join::exec::PartitionIndex;
 use crate::joins::{HashJoinExec, PartitionMode};
 use crate::test::TestMemoryExec;
-use arrow::datatypes::{Schema, SchemaRef};
+use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{JoinType, NullEquality, Result};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::PhysicalExprRef;
 use futures::{ready, Stream, StreamExt};
-use tokio::sync::Mutex;
 
 enum GraceJoinState {
     /// Waiting for the partitioning phase (Phase 1) to finish
     WaitPartitioning,
 
-    WaitAllPartitions {
-        wait_all_fut: Option<OnceFut<Vec<SpillFut>>>,
-    },
-
     /// Currently joining partition `current`
     JoinPartition {
         current: usize,
@@ -64,6 +59,8 @@ enum GraceJoinState {
 
 pub struct GraceHashJoinStream {
     schema: SchemaRef,
+    left_input_schema: SchemaRef,
+    right_input_schema: SchemaRef,
     spill_fut: OnceFut<SpillFut>,
     spill_left: Arc<SpillManager>,
     spill_right: Arc<SpillManager>,
@@ -75,27 +72,21 @@ pub struct GraceHashJoinStream {
     column_indices: Vec<ColumnIndex>,
     join_metrics: Arc<BuildProbeJoinMetrics>,
     context: Arc<TaskContext>,
-    accumulator: Arc<GraceAccumulator>,
     state: GraceJoinState,
 }
 
 #[derive(Debug, Clone)]
 pub struct SpillFut {
-    partition: usize,
     left: Vec<PartitionIndex>,
     right: Vec<PartitionIndex>,
 }
 
 impl SpillFut {
     pub(crate) fn new(
-        partition: usize,
+        _partition: usize,
         left: Vec<PartitionIndex>,
         right: Vec<PartitionIndex>,
     ) -> Self {
-        SpillFut {
-            partition,
-            left,
-            right,
-        }
+        SpillFut { left, right }
     }
 }
 
@@ -108,6 +99,8 @@ impl RecordBatchStream for GraceHashJoinStream {
 impl GraceHashJoinStream {
     pub fn new(
         schema: SchemaRef,
+        left_input_schema: SchemaRef,
+        right_input_schema: SchemaRef,
         spill_fut: OnceFut<SpillFut>,
         spill_left: Arc<SpillManager>,
         spill_right: Arc<SpillManager>,
@@ -119,10 +112,11 @@ impl GraceHashJoinStream {
         column_indices: Vec<ColumnIndex>,
         join_metrics: Arc<BuildProbeJoinMetrics>,
         context: Arc<TaskContext>,
-        accumulator: Arc<GraceAccumulator>,
     ) -> Self {
         Self {
             schema,
+            left_input_schema,
+            right_input_schema,
             spill_fut,
             spill_left,
             spill_right,
@@ -134,7 +128,6 @@ impl GraceHashJoinStream {
             column_indices,
             join_metrics,
             context,
-            accumulator,
             state: GraceJoinState::WaitPartitioning,
         }
     }
@@ -148,47 +141,16 @@ impl GraceHashJoinStream {
             match &mut self.state {
                 GraceJoinState::WaitPartitioning => {
                     let shared = ready!(self.spill_fut.get_shared(cx))?;
-
-                    let acc = Arc::clone(&self.accumulator);
-                    let left = shared.left.clone();
-                    let right = shared.right.clone();
-                    // Use 0 partition as the main
-                    let wait_all_fut = if shared.partition == 0 {
-                        OnceFut::new(async move {
-                            acc.report_partition(shared.partition, left, right).await;
-                            let all = acc.wait_all().await;
-                            Ok(all)
-                        })
-                    } else {
-                        OnceFut::new(async move {
-                            acc.report_partition(shared.partition, left, right).await;
-                            acc.wait_ready().await;
-                            Ok(vec![])
-                        })
-                    };
-                    self.state = GraceJoinState::WaitAllPartitions {
-                        wait_all_fut: Some(wait_all_fut),
+                    let parts = Arc::new(vec![(*shared).clone()]);
+                    self.state = GraceJoinState::JoinPartition {
+                        current: 0,
+                        all_parts: parts,
+                        current_stream: None,
+                        left_fut: None,
+                        right_fut: None,
                     };
                     continue;
                 }
-                GraceJoinState::WaitAllPartitions { wait_all_fut } => {
-                    if let Some(fut) = wait_all_fut {
-                        let all_arc = ready!(fut.get_shared(cx))?;
-                        let mut all = (*all_arc).clone();
-                        all.sort_by_key(|s| s.partition);
-
-                        self.state = GraceJoinState::JoinPartition {
-                            current: 0,
-                            all_parts: Arc::from(all),
-                            current_stream: None,
-                            left_fut: None,
-                            right_fut: None,
-                        };
-                        continue;
-                    } else {
-                        return Poll::Pending;
-                    }
-                }
                 GraceJoinState::JoinPartition {
                     current,
                     all_parts,
@@ -223,6 +185,8 @@ impl GraceHashJoinStream {
 
                         let stream = build_in_memory_join_stream(
                             Arc::clone(&self.schema),
+                            Arc::clone(&self.left_input_schema),
+                            Arc::clone(&self.right_input_schema),
                             left_batches,
                             right_batches,
                             &self.on_left,
@@ -282,6 +246,8 @@ fn load_partition_async(
 /// Build an in-memory HashJoinExec for one pair of spilled partitions
 fn build_in_memory_join_stream(
     output_schema: SchemaRef,
+    left_input_schema: SchemaRef,
+    right_input_schema: SchemaRef,
     left_batches: Vec<RecordBatch>,
     right_batches: Vec<RecordBatch>,
     on_left: &[PhysicalExprRef],
@@ -297,22 +263,15 @@ fn build_in_memory_join_stream(
         return EmptyExec::new(output_schema).execute(0, Arc::clone(context));
     }
 
-    let left_schema = left_batches
-        .first()
-        .map(|b| b.schema())
-        .unwrap_or_else(|| Arc::new(Schema::empty()));
-
-    let right_schema = right_batches
-        .first()
-        .map(|b| b.schema())
-        .unwrap_or_else(|| Arc::new(Schema::empty()));
-
     // Build memory execution nodes for each side
-    let left_plan: Arc<dyn ExecutionPlan> =
-        Arc::new(TestMemoryExec::try_new(&[left_batches], left_schema, None)?);
+    let left_plan: Arc<dyn ExecutionPlan> = Arc::new(TestMemoryExec::try_new(
+        &[left_batches],
+        left_input_schema,
+        None,
+    )?);
     let right_plan: Arc<dyn ExecutionPlan> = Arc::new(TestMemoryExec::try_new(
         &[right_batches],
-        right_schema,
+        right_input_schema,
         None,
     )?);
 
@@ -349,61 +308,3 @@ impl Stream for GraceHashJoinStream {
         self.poll_next_impl(cx)
     }
 }
-
-#[derive(Debug)]
-pub struct GraceAccumulator {
-    expected: usize,
-    collected: Mutex<Vec<SpillFut>>,
-    notify: tokio::sync::Notify,
-}
-
-impl GraceAccumulator {
-    pub fn new(expected: usize) -> Arc<Self> {
-        Arc::new(Self {
-            expected,
-            collected: Mutex::new(vec![]),
-            notify: tokio::sync::Notify::new(),
-        })
-    }
-
-    pub async fn report_partition(
-        &self,
-        part_id: usize,
-        left_idx: Vec<PartitionIndex>,
-        right_idx: Vec<PartitionIndex>,
-    ) {
-        let mut guard = self.collected.lock().await;
-        if let Some(pos) = guard.iter().position(|s| s.partition == part_id) {
-            guard[pos] = SpillFut::new(part_id, left_idx, right_idx);
-        } else {
-            guard.push(SpillFut::new(part_id, left_idx, right_idx));
-        }
-
-        if guard.len() == self.expected {
-            self.notify.notify_waiters();
-        }
-    }
-
-    pub async fn wait_all(&self) -> Vec<SpillFut> {
-        loop {
-            {
-                let guard = self.collected.lock().await;
-                if guard.len() == self.expected {
-                    return guard.clone();
-                }
-            }
-            self.notify.notified().await;
-        }
-    }
-
-    pub async fn wait_ready(&self) {
-        loop {
-            {
-                let guard = self.collected.lock().await;
-                if guard.len() == self.expected {
-                    return;
-                }
-            }
-            self.notify.notified().await;
-        }
-    }
-}
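With the accumulator removed, the stream reduces to a plain two-phase state machine: wait for the spill future, then join the spilled partitions one by one. For readers less familiar with hand-rolled `Stream` implementations, here is a self-contained toy with the same shape; it depends only on the `futures` crate, and the phase names and item type are invented for the demo:

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{ready, Stream, StreamExt};

enum Phase {
    // Phase 1: wait for an inner stream to produce a count (our "spill manifest").
    Waiting(futures::stream::Iter<std::vec::IntoIter<u64>>),
    // Phase 2: drain the items discovered in phase 1.
    Draining(std::vec::IntoIter<u64>),
    Done,
}

struct TwoPhase {
    phase: Phase,
}

impl Stream for TwoPhase {
    type Item = u64;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
        let this = self.get_mut(); // TwoPhase is Unpin, so this is safe
        loop {
            let next_phase = match &mut this.phase {
                // `ready!` returns Poll::Pending early, exactly like the
                // `spill_fut.get_shared(cx)` call in GraceHashJoinStream.
                Phase::Waiting(inner) => match ready!(inner.poll_next_unpin(cx)) {
                    Some(n) => Phase::Draining((0..n).collect::<Vec<_>>().into_iter()),
                    None => Phase::Done,
                },
                Phase::Draining(items) => match items.next() {
                    Some(item) => return Poll::Ready(Some(item)),
                    None => Phase::Done,
                },
                Phase::Done => return Poll::Ready(None),
            };
            this.phase = next_phase; // state changed; loop and poll again
        }
    }
}

fn main() {
    let stream = TwoPhase {
        phase: Phase::Waiting(futures::stream::iter(vec![3u64])),
    };
    let items = futures::executor::block_on(stream.collect::<Vec<_>>());
    assert_eq!(items, vec![0, 1, 2]);
}
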
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 8ee4c3de430a3..ac1ad30c247b6 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -20,6 +20,7 @@ use arrow::array::BooleanBufferBuilder;
 pub use cross_join::CrossJoinExec;
 use datafusion_physical_expr::PhysicalExprRef;
+pub use grace_hash_join::GraceHashJoinExec;
 pub use hash_join::HashJoinExec;
 pub use nested_loop_join::NestedLoopJoinExec;
 use parking_lot::Mutex;
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 70d6caf7642bc..07b71ff159ca1 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -732,6 +732,7 @@ message PhysicalPlanNode {
     GenerateSeriesNode generate_series = 33;
     SortMergeJoinExecNode sort_merge_join = 34;
     MemoryScanExecNode memory_scan = 35;
+    GraceHashJoinExecNode grace_hash_join = 36;
   }
 }
 
@@ -1074,6 +1075,16 @@ message HashJoinExecNode {
   repeated uint32 projection = 9;
 }
 
+message GraceHashJoinExecNode {
+  PhysicalPlanNode left = 1;
+  PhysicalPlanNode right = 2;
+  repeated JoinOn on = 3;
+  datafusion_common.JoinType join_type = 4;
+  datafusion_common.NullEquality null_equality = 5;
+  JoinFilter filter = 6;
+  repeated uint32 projection = 7;
+}
+
 enum StreamPartitionMode {
   SINGLE_PARTITION = 0;
   PARTITIONED_EXEC = 1;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 83f662e611120..4d7f8241c7100 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -7717,6 +7717,208 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode {
         deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for GraceHashJoinExecNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.left.is_some() {
+            len += 1;
+        }
+        if self.right.is_some() {
+            len += 1;
+        }
+        if !self.on.is_empty() {
+            len += 1;
+        }
+        if self.join_type != 0 {
+            len += 1;
+        }
+        if self.null_equality != 0 {
+            len += 1;
+        }
+        if self.filter.is_some() {
+            len += 1;
+        }
+        if !self.projection.is_empty() {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.GraceHashJoinExecNode", len)?;
+        if let Some(v) = self.left.as_ref() {
+            struct_ser.serialize_field("left", v)?;
+        }
+        if let Some(v) = self.right.as_ref() {
+            struct_ser.serialize_field("right", v)?;
+        }
+        if !self.on.is_empty() {
+            struct_ser.serialize_field("on", &self.on)?;
+        }
+        if self.join_type != 0 {
+            let v = super::datafusion_common::JoinType::try_from(self.join_type)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?;
+            struct_ser.serialize_field("joinType", &v)?;
+        }
+        if self.null_equality != 0 {
+            let v = super::datafusion_common::NullEquality::try_from(self.null_equality)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.null_equality)))?;
+            struct_ser.serialize_field("nullEquality", &v)?;
+        }
+        if let Some(v) = self.filter.as_ref() {
+            struct_ser.serialize_field("filter", v)?;
+        }
+        if !self.projection.is_empty() {
+            struct_ser.serialize_field("projection", &self.projection)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for GraceHashJoinExecNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "left",
+            "right",
+            "on",
+            "join_type",
+            "joinType",
+            "null_equality",
+            "nullEquality",
+            "filter",
+            "projection",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Left,
+            Right,
+            On,
+            JoinType,
+            NullEquality,
+            Filter,
+            Projection,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "left" => Ok(GeneratedField::Left),
+                            "right" => Ok(GeneratedField::Right),
+                            "on" => Ok(GeneratedField::On),
+                            "joinType" | "join_type" => Ok(GeneratedField::JoinType),
+                            "nullEquality" | "null_equality" => Ok(GeneratedField::NullEquality),
+                            "filter" => Ok(GeneratedField::Filter),
+                            "projection" => Ok(GeneratedField::Projection),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = GraceHashJoinExecNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.GraceHashJoinExecNode")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<GraceHashJoinExecNode, V::Error>
+            where
+                V: serde::de::MapAccess<'de>,
+            {
+                let mut left__ = None;
+                let mut right__ = None;
+                let mut on__ = None;
+                let mut join_type__ = None;
+                let mut null_equality__ = None;
+                let mut filter__ = None;
+                let mut projection__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Left => {
+                            if left__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("left"));
+                            }
+                            left__ = map_.next_value()?;
+                        }
+                        GeneratedField::Right => {
+                            if right__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("right"));
+                            }
+                            right__ = map_.next_value()?;
+                        }
+                        GeneratedField::On => {
+                            if on__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("on"));
+                            }
+                            on__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::JoinType => {
+                            if join_type__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("joinType"));
+                            }
+                            join_type__ = Some(map_.next_value::<super::datafusion_common::JoinType>()? as i32);
+                        }
+                        GeneratedField::NullEquality => {
+                            if null_equality__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("nullEquality"));
+                            }
+                            null_equality__ = Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
+                        }
+                        GeneratedField::Filter => {
+                            if filter__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("filter"));
+                            }
+                            filter__ = map_.next_value()?;
+                        }
+                        GeneratedField::Projection => {
+                            if projection__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("projection"));
+                            }
+                            projection__ =
+                                Some(map_.next_value::<Vec<::pbjson::private::NumberDeserialize<u32>>>()?
+                                    .into_iter().map(|x| x.0).collect())
+                            ;
+                        }
+                    }
+                }
+                Ok(GraceHashJoinExecNode {
+                    left: left__,
+                    right: right__,
+                    on: on__.unwrap_or_default(),
+                    join_type: join_type__.unwrap_or_default(),
+                    null_equality: null_equality__.unwrap_or_default(),
+                    filter: filter__,
+                    projection: projection__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.GraceHashJoinExecNode", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for GroupingSetNode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -17022,6 +17224,9 @@ impl serde::Serialize for PhysicalPlanNode {
                 physical_plan_node::PhysicalPlanType::MemoryScan(v) => {
                     struct_ser.serialize_field("memoryScan", v)?;
                 }
+                physical_plan_node::PhysicalPlanType::GraceHashJoin(v) => {
+                    struct_ser.serialize_field("graceHashJoin", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -17087,6 +17292,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             "sortMergeJoin",
             "memory_scan",
             "memoryScan",
+            "grace_hash_join",
+            "graceHashJoin",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -17125,6 +17332,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             GenerateSeries,
             SortMergeJoin,
             MemoryScan,
+            GraceHashJoin,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -17180,6 +17388,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                         "generateSeries" | "generate_series" => Ok(GeneratedField::GenerateSeries),
                         "sortMergeJoin" | "sort_merge_join" => Ok(GeneratedField::SortMergeJoin),
                         "memoryScan" | "memory_scan" => Ok(GeneratedField::MemoryScan),
+                        "graceHashJoin" | "grace_hash_join" => Ok(GeneratedField::GraceHashJoin),
                         _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                     }
                 }
@@ -17438,6 +17647,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                             return Err(serde::de::Error::duplicate_field("memoryScan"));
                         }
                         physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::MemoryScan)
+;
+                    }
+                    GeneratedField::GraceHashJoin => {
+                        if physical_plan_type__.is_some() {
+                            return Err(serde::de::Error::duplicate_field("graceHashJoin"));
+                        }
+                        physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GraceHashJoin)
 ;
                     }
                 }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index cc19add6fbe9e..d5520c6843b69 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1055,7 +1055,7 @@ pub mod table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36"
     )]
     pub physical_plan_type: ::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1133,6 +1133,8 @@ pub mod physical_plan_node {
         SortMergeJoin(::prost::alloc::boxed::Box<super::SortMergeJoinExecNode>),
         #[prost(message, tag = "35")]
         MemoryScan(super::MemoryScanExecNode),
+        #[prost(message, tag = "36")]
+        GraceHashJoin(::prost::alloc::boxed::Box<super::GraceHashJoinExecNode>),
     }
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -1636,6 +1638,23 @@ pub struct HashJoinExecNode {
     pub projection: ::prost::alloc::vec::Vec<u32>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct GraceHashJoinExecNode {
+    #[prost(message, optional, boxed, tag = "1")]
+    pub left: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub right: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, repeated, tag = "3")]
+    pub on: ::prost::alloc::vec::Vec<JoinOn>,
+    #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")]
+    pub join_type: i32,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "5")]
+    pub null_equality: i32,
+    #[prost(message, optional, tag = "6")]
+    pub filter: ::core::option::Option<JoinFilter>,
+    #[prost(uint32, repeated, tag = "7")]
+    pub projection: ::prost::alloc::vec::Vec<u32>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct SymmetricHashJoinExecNode {
     #[prost(message, optional, boxed, tag = "1")]
     pub left: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index e577de5b1d0e0..556e76f2642f6 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -77,7 +77,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
 use datafusion::physical_plan::joins::{
-    CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode,
+    CrossJoinExec, GraceHashJoinExec, NestedLoopJoinExec, SortMergeJoinExec, StreamJoinPartitionMode,
     SymmetricHashJoinExec,
 };
 use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
@@ -214,6 +214,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
                 runtime,
                 extension_codec,
             ),
+            PhysicalPlanType::GraceHashJoin(grace_hash_join) => self
+                .try_into_grace_hash_join_physical_plan(
+                    grace_hash_join,
+                    ctx,
+                    runtime,
+                    extension_codec,
+                ),
             PhysicalPlanType::SymmetricHashJoin(sym_join) => self
                 .try_into_symmetric_hash_join_physical_plan(
                     sym_join,
@@ -365,6 +372,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
             );
         }
 
+        if let Some(exec) = plan.downcast_ref::<GraceHashJoinExec>() {
+            return protobuf::PhysicalPlanNode::try_from_grace_hash_join_exec(
+                exec,
+                extension_codec,
+            );
+        }
+
         if let Some(exec) = plan.downcast_ref::<SymmetricHashJoinExec>() {
             return protobuf::PhysicalPlanNode::try_from_symmetric_hash_join_exec(
                 exec,
@@ -1266,6 +1280,116 @@ impl protobuf::PhysicalPlanNode {
         )?))
     }
 
+    fn try_into_grace_hash_join_physical_plan(
+        &self,
+        grace_join: &protobuf::GraceHashJoinExecNode,
+        ctx: &SessionContext,
+        runtime: &RuntimeEnv,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let left = into_physical_plan(&grace_join.left, ctx, runtime, extension_codec)?;
+        let right = into_physical_plan(&grace_join.right, ctx, runtime, extension_codec)?;
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        let on = grace_join
+            .on
+            .iter()
+            .map(|col| {
+                let left_expr = parse_physical_expr(
+                    col.left.as_ref().ok_or_else(|| {
+                        proto_error("GraceHashJoinExecNode missing left expr")
+                    })?,
+                    ctx,
+                    left_schema.as_ref(),
+                    extension_codec,
+                )?;
+                let right_expr = parse_physical_expr(
+                    col.right.as_ref().ok_or_else(|| {
+                        proto_error("GraceHashJoinExecNode missing right expr")
+                    })?,
+                    ctx,
+                    right_schema.as_ref(),
+                    extension_codec,
+                )?;
+                Ok((left_expr, right_expr))
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let join_type =
+            protobuf::JoinType::try_from(grace_join.join_type).map_err(|_| {
+                proto_error(format!(
+                    "Received a GraceHashJoinExecNode with unknown JoinType {}",
+                    grace_join.join_type
+                ))
+            })?;
+        let null_equality =
+            protobuf::NullEquality::try_from(grace_join.null_equality)
+                .map_err(|_| {
+                    proto_error(format!(
+                        "Received a GraceHashJoinExecNode with unknown NullEquality {}",
+                        grace_join.null_equality
+                    ))
+                })?;
+        let filter = grace_join
+            .filter
+            .as_ref()
+            .map(|f| {
+                let schema = f
+                    .schema
+                    .as_ref()
+                    .ok_or_else(|| proto_error("Missing JoinFilter schema"))?
+                    .try_into()?;
+
+                let expression = parse_physical_expr(
+                    f.expression.as_ref().ok_or_else(|| {
+                        proto_error("Unexpected empty filter expression")
+                    })?,
+                    ctx,
+                    &schema,
+                    extension_codec,
+                )?;
+                let column_indices = f
+                    .column_indices
+                    .iter()
+                    .map(|i| {
+                        let side = protobuf::JoinSide::try_from(i.side).map_err(|_| {
+                            proto_error(format!(
+                                "Received a GraceHashJoinExecNode message with JoinSide in Filter {}",
+                                i.side
+                            ))
+                        })?;
+
+                        Ok(ColumnIndex {
+                            index: i.index as usize,
+                            side: side.into(),
+                        })
+                    })
+                    .collect::<Result<Vec<ColumnIndex>>>()?;
+
+                Ok(JoinFilter::new(expression, column_indices, Arc::new(schema)))
+            })
+            .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+        let projection = if !grace_join.projection.is_empty() {
+            Some(
+                grace_join
+                    .projection
+                    .iter()
+                    .map(|i| *i as usize)
+                    .collect::<Vec<_>>(),
+            )
+        } else {
+            None
+        };
+
+        Ok(Arc::new(GraceHashJoinExec::try_new(
+            left,
+            right,
+            on,
+            filter,
+            &join_type.into(),
+            projection,
+            null_equality.into(),
+        )?))
+    }
+
     fn try_into_symmetric_hash_join_physical_plan(
         &self,
         sym_join: &protobuf::SymmetricHashJoinExecNode,
@@ -2222,6 +2346,75 @@ impl protobuf::PhysicalPlanNode {
         })
     }
 
+    fn try_from_grace_hash_join_exec(
+        exec: &GraceHashJoinExec,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Self> {
+        let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.left().to_owned(),
+            extension_codec,
+        )?;
+        let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.right().to_owned(),
+            extension_codec,
+        )?;
+        let on: Vec<protobuf::JoinOn> = exec
+            .on()
+            .iter()
+            .map(|tuple| {
+                let l = serialize_physical_expr(&tuple.0, extension_codec)?;
+                let r = serialize_physical_expr(&tuple.1, extension_codec)?;
+                Ok::<_, DataFusionError>(protobuf::JoinOn {
+                    left: Some(l),
+                    right: Some(r),
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+        let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = exec.null_equality().into();
+        let filter = exec
+            .filter()
+            .as_ref()
+            .map(|f| {
+                let expression =
+                    serialize_physical_expr(f.expression(), extension_codec)?;
+                let column_indices = f
+                    .column_indices()
+                    .iter()
+                    .map(|i| {
+                        let side: protobuf::JoinSide = i.side.to_owned().into();
+                        protobuf::ColumnIndex {
+                            index: i.index as u32,
+                            side: side.into(),
+                        }
+                    })
+                    .collect();
+                let schema = f.schema().as_ref().try_into()?;
+                Ok(protobuf::JoinFilter {
+                    expression: Some(expression),
+                    column_indices,
+                    schema: Some(schema),
+                })
+            })
+            .map_or(Ok(None), |v: Result<protobuf::JoinFilter>| v.map(Some))?;
+
+        Ok(protobuf::PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::GraceHashJoin(Box::new(
+                protobuf::GraceHashJoinExecNode {
+                    left: Some(Box::new(left)),
+                    right: Some(Box::new(right)),
+                    on,
+                    join_type: join_type.into(),
+                    null_equality: null_equality.into(),
+                    filter,
+                    projection: exec.projection.as_ref().map_or_else(Vec::new, |v| {
+                        v.iter().map(|x| *x as u32).collect::<Vec<_>>()
+                    }),
+                },
+            ))),
+        })
+    }
+
     fn try_from_symmetric_hash_join_exec(
         exec: &SymmetricHashJoinExec,
         extension_codec: &dyn PhysicalExtensionCodec,
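With both conversion directions registered, a plan containing the new operator can round-trip through the protobuf representation. A sketch of exercising that from user code — `physical_plan_to_bytes`/`physical_plan_from_bytes` are the existing datafusion-proto entry points, and the plan construction itself is assumed to happen elsewhere:

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::bytes::{physical_plan_from_bytes, physical_plan_to_bytes};

/// `plan` is assumed to be a physical plan whose tree contains a
/// GraceHashJoinExec, e.g. produced by the optimizer with the flag enabled.
fn roundtrip(ctx: &SessionContext, plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    let bytes = physical_plan_to_bytes(plan.clone())?;
    let decoded = physical_plan_from_bytes(&bytes, ctx)?;
    // Schemas must survive the roundtrip unchanged.
    assert_eq!(plan.schema(), decoded.schema());
    Ok(())
}
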
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index a5357a132eef2..b3f9172a273e4 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -284,6 +284,41 @@ fn roundtrip_hash_join() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn roundtrip_grace_hash_join() -> Result<()> {
+    let field_a = Field::new("col", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_a]);
+    let on = vec![(
+        Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+        Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
+    )];
+
+    let schema_left = Arc::new(schema_left);
+    let schema_right = Arc::new(schema_right);
+    for join_type in &[
+        JoinType::Inner,
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftAnti,
+        JoinType::RightAnti,
+        JoinType::LeftSemi,
+        JoinType::RightSemi,
+    ] {
+        roundtrip_test(Arc::new(GraceHashJoinExec::try_new(
+            Arc::new(EmptyExec::new(schema_left.clone())),
+            Arc::new(EmptyExec::new(schema_right.clone())),
+            on.clone(),
+            None,
+            join_type,
+            None,
+            NullEquality::NullEqualsNothing,
+        )?))?;
+    }
+    Ok(())
+}
+
 #[test]
 fn roundtrip_nested_loop_join() -> Result<()> {
     let field_a = Field::new("col", DataType::Int64, false);

From ab4e5fa546500edfb9eb0d140b35bfc360d4b3be Mon Sep 17 00:00:00 2001
From: Denys Tsomenko
Date: Sat, 15 Nov 2025 16:30:20 +0200
Subject: [PATCH 2/2] [Test] Route Auto and CollectLeft hash joins through grace hash join when enabled

---
 .../physical-optimizer/src/join_selection.rs | 62 ++++++++++++++-----
 1 file changed, 47 insertions(+), 15 deletions(-)

diff --git a/datafusion/physical-optimizer/src/join_selection.rs b/datafusion/physical-optimizer/src/join_selection.rs
index 54b65450374dc..35e2a50116afe 100644
--- a/datafusion/physical-optimizer/src/join_selection.rs
+++ b/datafusion/physical-optimizer/src/join_selection.rs
@@ -301,21 +301,53 @@ fn statistical_join_selection_subrule(
     let transformed =
         if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
             match hash_join.partition_mode() {
-                PartitionMode::Auto => try_collect_left(
-                    hash_join,
-                    false,
-                    collect_threshold_byte_size,
-                    collect_threshold_num_rows,
-                )?
-                .map_or_else(
-                    || partitioned_hash_join(hash_join, enable_grace_hash_join).map(Some),
-                    |v| Ok(Some(v)),
-                )?,
-                PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
-                    .map_or_else(
-                        || partitioned_hash_join(hash_join, enable_grace_hash_join).map(Some),
-                        |v| Ok(Some(v)),
-                    )?,
+                PartitionMode::Auto => {
+                    if enable_grace_hash_join {
+                        Some(partitioned_hash_join(hash_join, enable_grace_hash_join)?)
+                    } else {
+                        try_collect_left(
+                            hash_join,
+                            false,
+                            collect_threshold_byte_size,
+                            collect_threshold_num_rows,
+                        )?
+                        .map_or_else(
+                            || {
+                                partitioned_hash_join(hash_join, enable_grace_hash_join)
+                                    .map(Some)
+                            },
+                            |v| Ok(Some(v)),
+                        )?
+                    }
+                }
+                PartitionMode::CollectLeft => {
+                    if enable_grace_hash_join {
+                        Some(partitioned_hash_join(hash_join, enable_grace_hash_join)?)
+                    } else {
+                        try_collect_left(hash_join, true, 0, 0)?
+                            .map_or_else(
+                                || {
+                                    partitioned_hash_join(hash_join, enable_grace_hash_join)
+                                        .map(Some)
+                                },
+                                |v| Ok(Some(v)),
+                            )?
+                    }
+                }
                 PartitionMode::Partitioned => {
                     let left = hash_join.left();
                     let right = hash_join.right();
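For reviewers who want to see the optimizer pick the new operator end to end, a quick verification sketch. The prelude paths, the VALUES-backed tables, and the exact operator name printed by EXPLAIN are assumptions based on this patch rather than guaranteed output; substitute real tables for a meaningful plan:

use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let config = SessionConfig::new()
        .with_enable_grace_hash_join(true)
        .with_target_partitions(8);
    let ctx = SessionContext::new_with_config(config);

    ctx.sql("CREATE TABLE t1 AS VALUES (1, 'a'), (2, 'b')").await?;
    ctx.sql("CREATE TABLE t2 AS VALUES (1, 'x'), (2, 'y')").await?;

    // With the flag enabled, the physical plan for the equi-join below is
    // expected to contain a GraceHashJoinExec node instead of a HashJoinExec.
    let df = ctx
        .sql("EXPLAIN SELECT * FROM t1 JOIN t2 ON t1.column1 = t2.column1")
        .await?;
    df.show().await?;
    Ok(())
}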