@@ -112,6 +112,17 @@ fn max_partitions_allowed_for_memory(memory_threshold: usize) -> usize {
112112 highest_power_of_two_leq ( slots)
113113}
114114
115+ fn per_partition_budget_bytes ( memory_threshold : usize , partitions : usize ) -> usize {
116+ let partitions = partitions. max ( 1 ) ;
117+ let mut budget = memory_threshold
118+ . checked_div ( partitions)
119+ . unwrap_or ( memory_threshold) ;
120+ if budget == 0 {
121+ budget = HYBRID_HASH_MIN_PARTITION_BYTES ;
122+ }
123+ budget. max ( HYBRID_HASH_MIN_PARTITION_BYTES )
124+ }
125+
115126#[ inline]
116127fn hhj_debug < F : FnOnce ( ) -> String > ( builder : F ) {
117128 if std:: env:: var ( "DATAFUSION_HHJ_DEBUG" ) . is_ok ( ) {
@@ -202,6 +213,7 @@ impl ProbePartition {
202213
203214/// Runtime state tracked per probe partition.
204215#[ cfg( not( feature = "hybrid_hash_join_scheduler" ) ) ]
216+ #[ cfg( not( feature = "hybrid_hash_join_scheduler" ) ) ]
205217pub ( super ) struct ProbePartitionState {
206218 buffered : ProbePartition ,
207219 batch_position : usize ,
@@ -383,6 +395,8 @@ pub(super) struct PartitionedHashJoinStream {
383395 pub memory_reservation : MemoryReservation ,
384396 /// Tracks how many repartition passes have been attempted
385397 pub partition_pass : usize ,
398+ /// Indicates whether the current pass has already prepared partitions for output
399+ pub partition_pass_output_started : bool ,
386400 /// Runtime environment
387401 pub runtime_env : Arc < RuntimeEnv > ,
388402 /// Scratch space for computing hashes
@@ -609,14 +623,8 @@ impl PartitionedHashJoinStream {
609623 return None ;
610624 }
611625
612- let mut per_partition_budget = self
613- . memory_threshold
614- . checked_div ( self . max_partition_count . max ( 1 ) )
615- . unwrap_or ( self . memory_threshold ) ;
616- if per_partition_budget == 0 {
617- per_partition_budget = HYBRID_HASH_MIN_PARTITION_BYTES ;
618- }
619- per_partition_budget = per_partition_budget. max ( HYBRID_HASH_MIN_PARTITION_BYTES ) ;
626+ let mut per_partition_budget =
627+ per_partition_budget_bytes ( self . memory_threshold , self . num_partitions ) ;
620628
621629 let rows_budget = self
622630 . batch_size
@@ -1229,9 +1237,6 @@ impl PartitionedHashJoinStream {
12291237 Poll :: Ready ( None ) => {
12301238 self . probe_stream_finished = true ;
12311239 for part_id in 0 ..self . num_partitions {
1232- if let Some ( state) = self . probe_states . get_mut ( part_id) {
1233- state. batch_position = 0 ;
1234- }
12351240 self . finalize_spilled_partition ( part_id) ?;
12361241 }
12371242 return Poll :: Ready ( Ok ( ( ) ) ) ;
@@ -1767,6 +1772,7 @@ impl PartitionedHashJoinStream {
17671772 build_spill_manager,
17681773 memory_reservation,
17691774 partition_pass : 0 ,
1775+ partition_pass_output_started : false ,
17701776 runtime_env,
17711777 hashes_buffer : Vec :: new ( ) ,
17721778 right_side_ordered,
@@ -1840,7 +1846,7 @@ impl PartitionedHashJoinStream {
18401846 self . max_partition_count = 1 ;
18411847 }
18421848
1843- let mut allow_repartition = true ;
1849+ let mut allow_repartition = ! self . partition_pass_output_started ;
18441850 loop {
18451851 hhj_debug ( || {
18461852 format ! (
@@ -1883,6 +1889,7 @@ impl PartitionedHashJoinStream {
18831889
18841890 self . num_partitions = next_count;
18851891 self . partition_pass += 1 ;
1892+ self . partition_pass_output_started = false ;
18861893 allow_repartition = true ;
18871894 }
18881895 }
@@ -1966,13 +1973,6 @@ impl PartitionedHashJoinStream {
19661973 repartition_request = Some ( next_count) ;
19671974 break ;
19681975 }
1969- } else {
1970- hhj_debug ( || {
1971- format ! (
1972- "partition {} exceeded global mem but bytes={} under per-part budget; skipping repartition" ,
1973- build_index, partition_estimate
1974- )
1975- } ) ;
19761976 }
19771977 }
19781978 if !self . runtime_env . disk_manager . tmp_files_enabled ( ) {
@@ -1998,13 +1998,6 @@ impl PartitionedHashJoinStream {
19981998 repartition_request = Some ( next_count) ;
19991999 break ;
20002000 }
2001- } else {
2002- hhj_debug ( || {
2003- format ! (
2004- "allocation failure for partition {} but bytes={} under per-part budget; spilling without repartition" ,
2005- build_index, partition_estimate
2006- )
2007- } ) ;
20082001 }
20092002 }
20102003 if !self . runtime_env . disk_manager . tmp_files_enabled ( ) {
@@ -2150,19 +2143,11 @@ impl PartitionedHashJoinStream {
21502143 } ) ;
21512144 }
21522145
2153- if ( max_spilled_bytes > self . memory_threshold || any_spilled) && allow_repartition
2146+ if allow_repartition
2147+ && ( max_spilled_bytes > self . memory_threshold || any_spilled)
2148+ && self . repartition_worthwhile ( max_spilled_bytes)
21542149 {
2155- if !self . repartition_worthwhile ( max_spilled_bytes) {
2156- hhj_debug ( || {
2157- format ! (
2158- "spilled partitions already near budget (max_spilled_bytes={} bytes, memory_threshold={} partitions={}, budget≈{})" ,
2159- max_spilled_bytes,
2160- self . memory_threshold,
2161- self . num_partitions,
2162- ( self . memory_threshold / self . num_partitions. max( 1 ) ) . max( 1 )
2163- )
2164- } ) ;
2165- } else if let Some ( next_count) = self . next_partition_count ( ) {
2150+ if let Some ( next_count) = self . next_partition_count ( ) {
21662151 hhj_debug ( || {
21672152 format ! (
21682153 "try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}" ,
@@ -2173,14 +2158,11 @@ impl PartitionedHashJoinStream {
21732158 )
21742159 } ) ;
21752160 return Ok ( PartitionBuildStatus :: NeedMorePartitions { next_count } ) ;
2176- } else {
2177- hhj_debug ( || {
2178- "spill detected but no further repartition possible" . to_string ( )
2179- } ) ;
21802161 }
21812162 }
21822163
21832164 self . prepare_partition_queue ( ) ;
2165+ self . partition_pass_output_started = true ;
21842166 self . transition_to_next_partition ( ) ;
21852167
21862168 Ok ( PartitionBuildStatus :: Ready ( StatefulStreamResult :: Continue ) )
@@ -2293,6 +2275,11 @@ impl PartitionedHashJoinStream {
22932275 state. active_values = values;
22942276 state. active_hashes = hashes;
22952277 state. active_offset = ( 0 , None ) ;
2278+ if state. batch_position >= state. buffered . batches . len ( ) {
2279+ state. buffered = ProbePartition :: new ( ) ;
2280+ state. batch_position = 0 ;
2281+ state. buffered_rows = 0 ;
2282+ }
22962283 if let Some ( b) = state. active_batch . as_ref ( ) {
22972284 state. consumed_rows =
22982285 state. consumed_rows . saturating_add ( b. num_rows ( ) ) ;
0 commit comments