Skip to content

Commit 6d14d4b

Browse files
committed
Fix duplicating rows with reseting a state
1 parent c78d2e9 commit 6d14d4b

1 file changed

Lines changed: 29 additions & 42 deletions

File tree

datafusion/physical-plan/src/joins/hash_join/partitioned.rs

Lines changed: 29 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ fn max_partitions_allowed_for_memory(memory_threshold: usize) -> usize {
112112
highest_power_of_two_leq(slots)
113113
}
114114

115+
fn per_partition_budget_bytes(memory_threshold: usize, partitions: usize) -> usize {
116+
let partitions = partitions.max(1);
117+
let mut budget = memory_threshold
118+
.checked_div(partitions)
119+
.unwrap_or(memory_threshold);
120+
if budget == 0 {
121+
budget = HYBRID_HASH_MIN_PARTITION_BYTES;
122+
}
123+
budget.max(HYBRID_HASH_MIN_PARTITION_BYTES)
124+
}
125+
115126
#[inline]
116127
fn hhj_debug<F: FnOnce() -> String>(builder: F) {
117128
if std::env::var("DATAFUSION_HHJ_DEBUG").is_ok() {
@@ -202,6 +213,7 @@ impl ProbePartition {
202213

203214
/// Runtime state tracked per probe partition.
204215
#[cfg(not(feature = "hybrid_hash_join_scheduler"))]
216+
#[cfg(not(feature = "hybrid_hash_join_scheduler"))]
205217
pub(super) struct ProbePartitionState {
206218
buffered: ProbePartition,
207219
batch_position: usize,
@@ -383,6 +395,8 @@ pub(super) struct PartitionedHashJoinStream {
383395
pub memory_reservation: MemoryReservation,
384396
/// Tracks how many repartition passes have been attempted
385397
pub partition_pass: usize,
398+
/// Indicates whether the current pass has already prepared partitions for output
399+
pub partition_pass_output_started: bool,
386400
/// Runtime environment
387401
pub runtime_env: Arc<RuntimeEnv>,
388402
/// Scratch space for computing hashes
@@ -609,14 +623,8 @@ impl PartitionedHashJoinStream {
609623
return None;
610624
}
611625

612-
let mut per_partition_budget = self
613-
.memory_threshold
614-
.checked_div(self.max_partition_count.max(1))
615-
.unwrap_or(self.memory_threshold);
616-
if per_partition_budget == 0 {
617-
per_partition_budget = HYBRID_HASH_MIN_PARTITION_BYTES;
618-
}
619-
per_partition_budget = per_partition_budget.max(HYBRID_HASH_MIN_PARTITION_BYTES);
626+
let mut per_partition_budget =
627+
per_partition_budget_bytes(self.memory_threshold, self.num_partitions);
620628

621629
let rows_budget = self
622630
.batch_size
@@ -1229,9 +1237,6 @@ impl PartitionedHashJoinStream {
12291237
Poll::Ready(None) => {
12301238
self.probe_stream_finished = true;
12311239
for part_id in 0..self.num_partitions {
1232-
if let Some(state) = self.probe_states.get_mut(part_id) {
1233-
state.batch_position = 0;
1234-
}
12351240
self.finalize_spilled_partition(part_id)?;
12361241
}
12371242
return Poll::Ready(Ok(()));
@@ -1767,6 +1772,7 @@ impl PartitionedHashJoinStream {
17671772
build_spill_manager,
17681773
memory_reservation,
17691774
partition_pass: 0,
1775+
partition_pass_output_started: false,
17701776
runtime_env,
17711777
hashes_buffer: Vec::new(),
17721778
right_side_ordered,
@@ -1840,7 +1846,7 @@ impl PartitionedHashJoinStream {
18401846
self.max_partition_count = 1;
18411847
}
18421848

1843-
let mut allow_repartition = true;
1849+
let mut allow_repartition = !self.partition_pass_output_started;
18441850
loop {
18451851
hhj_debug(|| {
18461852
format!(
@@ -1883,6 +1889,7 @@ impl PartitionedHashJoinStream {
18831889

18841890
self.num_partitions = next_count;
18851891
self.partition_pass += 1;
1892+
self.partition_pass_output_started = false;
18861893
allow_repartition = true;
18871894
}
18881895
}
@@ -1966,13 +1973,6 @@ impl PartitionedHashJoinStream {
19661973
repartition_request = Some(next_count);
19671974
break;
19681975
}
1969-
} else {
1970-
hhj_debug(|| {
1971-
format!(
1972-
"partition {} exceeded global mem but bytes={} under per-part budget; skipping repartition",
1973-
build_index, partition_estimate
1974-
)
1975-
});
19761976
}
19771977
}
19781978
if !self.runtime_env.disk_manager.tmp_files_enabled() {
@@ -1998,13 +1998,6 @@ impl PartitionedHashJoinStream {
19981998
repartition_request = Some(next_count);
19991999
break;
20002000
}
2001-
} else {
2002-
hhj_debug(|| {
2003-
format!(
2004-
"allocation failure for partition {} but bytes={} under per-part budget; spilling without repartition",
2005-
build_index, partition_estimate
2006-
)
2007-
});
20082001
}
20092002
}
20102003
if !self.runtime_env.disk_manager.tmp_files_enabled() {
@@ -2150,19 +2143,11 @@ impl PartitionedHashJoinStream {
21502143
});
21512144
}
21522145

2153-
if (max_spilled_bytes > self.memory_threshold || any_spilled) && allow_repartition
2146+
if allow_repartition
2147+
&& (max_spilled_bytes > self.memory_threshold || any_spilled)
2148+
&& self.repartition_worthwhile(max_spilled_bytes)
21542149
{
2155-
if !self.repartition_worthwhile(max_spilled_bytes) {
2156-
hhj_debug(|| {
2157-
format!(
2158-
"spilled partitions already near budget (max_spilled_bytes={} bytes, memory_threshold={} partitions={}, budget≈{})",
2159-
max_spilled_bytes,
2160-
self.memory_threshold,
2161-
self.num_partitions,
2162-
(self.memory_threshold / self.num_partitions.max(1)).max(1)
2163-
)
2164-
});
2165-
} else if let Some(next_count) = self.next_partition_count() {
2150+
if let Some(next_count) = self.next_partition_count() {
21662151
hhj_debug(|| {
21672152
format!(
21682153
"try_partition_build_side repartition due to spill (max_spilled_bytes={} threshold={} any_spilled={}) next_count={}",
@@ -2173,14 +2158,11 @@ impl PartitionedHashJoinStream {
21732158
)
21742159
});
21752160
return Ok(PartitionBuildStatus::NeedMorePartitions { next_count });
2176-
} else {
2177-
hhj_debug(|| {
2178-
"spill detected but no further repartition possible".to_string()
2179-
});
21802161
}
21812162
}
21822163

21832164
self.prepare_partition_queue();
2165+
self.partition_pass_output_started = true;
21842166
self.transition_to_next_partition();
21852167

21862168
Ok(PartitionBuildStatus::Ready(StatefulStreamResult::Continue))
@@ -2293,6 +2275,11 @@ impl PartitionedHashJoinStream {
22932275
state.active_values = values;
22942276
state.active_hashes = hashes;
22952277
state.active_offset = (0, None);
2278+
if state.batch_position >= state.buffered.batches.len() {
2279+
state.buffered = ProbePartition::new();
2280+
state.batch_position = 0;
2281+
state.buffered_rows = 0;
2282+
}
22962283
if let Some(b) = state.active_batch.as_ref() {
22972284
state.consumed_rows =
22982285
state.consumed_rows.saturating_add(b.num_rows());

0 commit comments

Comments
 (0)