From d506884fa1feaf236e671aba3e518e38fb53922b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:17:48 -0600 Subject: [PATCH 1/9] feat: move bytes_to_i128 to common crate --- native/common/src/lib.rs | 2 ++ native/common/src/utils.rs | 37 ++++++++++++++++++++++++++++++ native/core/Cargo.toml | 1 + native/core/src/execution/utils.rs | 21 +---------------- 4 files changed, 41 insertions(+), 20 deletions(-) create mode 100644 native/common/src/utils.rs diff --git a/native/common/src/lib.rs b/native/common/src/lib.rs index 9319d7347f..86ae7704f0 100644 --- a/native/common/src/lib.rs +++ b/native/common/src/lib.rs @@ -17,6 +17,8 @@ mod error; mod query_context; +mod utils; pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; +pub use utils::bytes_to_i128; diff --git a/native/common/src/utils.rs b/native/common/src/utils.rs new file mode 100644 index 0000000000..12283db30d --- /dev/null +++ b/native/common/src/utils.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Converts a slice of bytes to i128. 
The bytes are serialized in big-endian order by +/// `BigInteger.toByteArray()` in Java. +pub fn bytes_to_i128(slice: &[u8]) -> i128 { + let mut bytes = [0; 16]; + let mut i = 0; + while i != 16 && i != slice.len() { + bytes[i] = slice[slice.len() - 1 - i]; + i += 1; + } + + // if the decimal is negative, we need to flip all the bits + if (slice[0] as i8) < 0 { + while i < 16 { + bytes[i] = !bytes[i]; + i += 1; + } + } + + i128::from_le_bytes(bytes) +} diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 3f305a631d..1da8bed207 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -64,6 +64,7 @@ datafusion-spark = { workspace = true } once_cell = "1.18.0" crc32fast = "1.3.2" simd-adler32 = "0.3.7" +datafusion-comet-common = { workspace = true } datafusion-comet-spark-expr = { workspace = true } datafusion-comet-jni-bridge = { workspace = true } datafusion-comet-proto = { workspace = true } diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index f95423aa70..2fe6f8758f 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -97,23 +97,4 @@ impl SparkArrowConvert for ArrayData { } } -/// Converts a slice of bytes to i128. The bytes are serialized in big-endian order by -/// `BigInteger.toByteArray()` in Java. 
-pub fn bytes_to_i128(slice: &[u8]) -> i128 { - let mut bytes = [0; 16]; - let mut i = 0; - while i != 16 && i != slice.len() { - bytes[i] = slice[slice.len() - 1 - i]; - i += 1; - } - - // if the decimal is negative, we need to flip all the bits - if (slice[0] as i8) < 0 { - while i < 16 { - bytes[i] = !bytes[i]; - i += 1; - } - } - - i128::from_le_bytes(bytes) -} +pub use datafusion_comet_common::bytes_to_i128; From 34cb75ac571ac2eebb85d0e2ae5ca84beb583623 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:18:59 -0600 Subject: [PATCH 2/9] feat: move tracing module to common crate Move tracing module from native/core to native/common, replacing once_cell::sync::Lazy with std::sync::LazyLock and making all items pub. Replace the core version with a glob re-export. --- native/common/src/lib.rs | 1 + native/common/src/tracing.rs | 141 +++++++++++++++++++++++++++ native/core/src/execution/tracing.rs | 126 +----------------------- 3 files changed, 143 insertions(+), 125 deletions(-) create mode 100644 native/common/src/tracing.rs diff --git a/native/common/src/lib.rs b/native/common/src/lib.rs index 86ae7704f0..a9549badb1 100644 --- a/native/common/src/lib.rs +++ b/native/common/src/lib.rs @@ -17,6 +17,7 @@ mod error; mod query_context; +pub mod tracing; mod utils; pub use error::{decimal_overflow_error, SparkError, SparkErrorWithContext, SparkResult}; diff --git a/native/common/src/tracing.rs b/native/common/src/tracing.rs new file mode 100644 index 0000000000..76598fd5ac --- /dev/null +++ b/native/common/src/tracing.rs @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::instant::Instant; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Write}; +use std::sync::{Arc, LazyLock, Mutex}; + +pub static RECORDER: LazyLock = LazyLock::new(Recorder::new); + +/// Log events using Chrome trace format JSON +/// https://github.com/catapult-project/catapult/blob/main/tracing/README.md +pub struct Recorder { + now: Instant, + writer: Arc>>, +} + +impl Recorder { + pub fn new() -> Self { + let file = OpenOptions::new() + .create(true) + .append(true) + .open("comet-event-trace.json") + .expect("Error writing tracing"); + + let mut writer = BufWriter::new(file); + + // Write start of JSON array. Note that there is no requirement to write + // the closing ']'. 
+ writer + .write_all("[ ".as_bytes()) + .expect("Error writing tracing"); + Self { + now: Instant::now(), + writer: Arc::new(Mutex::new(writer)), + } + } + pub fn begin_task(&self, name: &str) { + self.log_event(name, "B") + } + + pub fn end_task(&self, name: &str) { + self.log_event(name, "E") + } + + pub fn log_memory_usage(&self, name: &str, usage_bytes: u64) { + let usage_mb = (usage_bytes as f64 / 1024.0 / 1024.0) as usize; + let json = format!( + "{{ \"name\": \"{name}\", \"cat\": \"PERF\", \"ph\": \"C\", \"pid\": 1, \"tid\": {}, \"ts\": {}, \"args\": {{ \"{name}\": {usage_mb} }} }},\n", + Self::get_thread_id(), + self.now.elapsed().as_micros() + ); + let mut writer = self.writer.lock().unwrap(); + writer + .write_all(json.as_bytes()) + .expect("Error writing tracing"); + } + + fn log_event(&self, name: &str, ph: &str) { + let json = format!( + "{{ \"name\": \"{}\", \"cat\": \"PERF\", \"ph\": \"{ph}\", \"pid\": 1, \"tid\": {}, \"ts\": {} }},\n", + name, + Self::get_thread_id(), + self.now.elapsed().as_micros() + ); + let mut writer = self.writer.lock().unwrap(); + writer + .write_all(json.as_bytes()) + .expect("Error writing tracing"); + } + + fn get_thread_id() -> u64 { + let thread_id = std::thread::current().id(); + format!("{thread_id:?}") + .trim_start_matches("ThreadId(") + .trim_end_matches(")") + .parse() + .expect("Error parsing thread id") + } +} + +pub fn trace_begin(name: &str) { + RECORDER.begin_task(name); +} + +pub fn trace_end(name: &str) { + RECORDER.end_task(name); +} + +pub fn log_memory_usage(name: &str, value: u64) { + RECORDER.log_memory_usage(name, value); +} + +pub fn with_trace(label: &str, tracing_enabled: bool, f: F) -> T +where + F: FnOnce() -> T, +{ + if tracing_enabled { + trace_begin(label); + } + + let result = f(); + + if tracing_enabled { + trace_end(label); + } + + result +} + +pub async fn with_trace_async(label: &str, tracing_enabled: bool, f: F) -> T +where + F: FnOnce() -> Fut, + Fut: std::future::Future, +{ + if 
tracing_enabled { + trace_begin(label); + } + + let result = f().await; + + if tracing_enabled { + trace_end(label); + } + + result +} diff --git a/native/core/src/execution/tracing.rs b/native/core/src/execution/tracing.rs index 01351565f5..b02006efb9 100644 --- a/native/core/src/execution/tracing.rs +++ b/native/core/src/execution/tracing.rs @@ -15,128 +15,4 @@ // specific language governing permissions and limitations // under the License. -use datafusion::common::instant::Instant; -use once_cell::sync::Lazy; -use std::fs::{File, OpenOptions}; -use std::io::{BufWriter, Write}; -use std::sync::{Arc, Mutex}; - -pub(crate) static RECORDER: Lazy = Lazy::new(Recorder::new); - -/// Log events using Chrome trace format JSON -/// https://github.com/catapult-project/catapult/blob/main/tracing/README.md -pub struct Recorder { - now: Instant, - writer: Arc>>, -} - -impl Recorder { - pub fn new() -> Self { - let file = OpenOptions::new() - .create(true) - .append(true) - .open("comet-event-trace.json") - .expect("Error writing tracing"); - - let mut writer = BufWriter::new(file); - - // Write start of JSON array. Note that there is no requirement to write - // the closing ']'. 
- writer - .write_all("[ ".as_bytes()) - .expect("Error writing tracing"); - Self { - now: Instant::now(), - writer: Arc::new(Mutex::new(writer)), - } - } - pub fn begin_task(&self, name: &str) { - self.log_event(name, "B") - } - - pub fn end_task(&self, name: &str) { - self.log_event(name, "E") - } - - pub fn log_memory_usage(&self, name: &str, usage_bytes: u64) { - let usage_mb = (usage_bytes as f64 / 1024.0 / 1024.0) as usize; - let json = format!( - "{{ \"name\": \"{name}\", \"cat\": \"PERF\", \"ph\": \"C\", \"pid\": 1, \"tid\": {}, \"ts\": {}, \"args\": {{ \"{name}\": {usage_mb} }} }},\n", - Self::get_thread_id(), - self.now.elapsed().as_micros() - ); - let mut writer = self.writer.lock().unwrap(); - writer - .write_all(json.as_bytes()) - .expect("Error writing tracing"); - } - - fn log_event(&self, name: &str, ph: &str) { - let json = format!( - "{{ \"name\": \"{}\", \"cat\": \"PERF\", \"ph\": \"{ph}\", \"pid\": 1, \"tid\": {}, \"ts\": {} }},\n", - name, - Self::get_thread_id(), - self.now.elapsed().as_micros() - ); - let mut writer = self.writer.lock().unwrap(); - writer - .write_all(json.as_bytes()) - .expect("Error writing tracing"); - } - - fn get_thread_id() -> u64 { - let thread_id = std::thread::current().id(); - format!("{thread_id:?}") - .trim_start_matches("ThreadId(") - .trim_end_matches(")") - .parse() - .expect("Error parsing thread id") - } -} - -pub(crate) fn trace_begin(name: &str) { - RECORDER.begin_task(name); -} - -pub(crate) fn trace_end(name: &str) { - RECORDER.end_task(name); -} - -pub(crate) fn log_memory_usage(name: &str, value: u64) { - RECORDER.log_memory_usage(name, value); -} - -pub(crate) fn with_trace(label: &str, tracing_enabled: bool, f: F) -> T -where - F: FnOnce() -> T, -{ - if tracing_enabled { - trace_begin(label); - } - - let result = f(); - - if tracing_enabled { - trace_end(label); - } - - result -} - -pub(crate) async fn with_trace_async(label: &str, tracing_enabled: bool, f: F) -> T -where - F: FnOnce() -> Fut, - Fut: 
std::future::Future, -{ - if tracing_enabled { - trace_begin(label); - } - - let result = f().await; - - if tracing_enabled { - trace_end(label); - } - - result -} +pub(crate) use datafusion_comet_common::tracing::*; From d494fd98d00abc2d2946cf622f68270e96fbd3ed Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:21:28 -0600 Subject: [PATCH 3/9] feat: add datafusion-comet-shuffle crate skeleton --- native/Cargo.toml | 5 +- native/shuffle/Cargo.toml | 65 ++++++++++++++++++++++++ native/shuffle/benches/row_columnar.rs | 18 +++++++ native/shuffle/benches/shuffle_writer.rs | 18 +++++++ native/shuffle/src/lib.rs | 16 ++++++ 5 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 native/shuffle/Cargo.toml create mode 100644 native/shuffle/benches/row_columnar.rs create mode 100644 native/shuffle/benches/shuffle_writer.rs create mode 100644 native/shuffle/src/lib.rs diff --git a/native/Cargo.toml b/native/Cargo.toml index 693221b157..e75c1fd241 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -16,8 +16,8 @@ # under the License. 
[workspace] -default-members = ["core", "spark-expr", "common", "proto", "jni-bridge"] -members = ["core", "spark-expr", "common", "proto", "jni-bridge", "hdfs", "fs-hdfs"] +default-members = ["core", "spark-expr", "common", "proto", "jni-bridge", "shuffle"] +members = ["core", "spark-expr", "common", "proto", "jni-bridge", "shuffle", "hdfs", "fs-hdfs"] resolver = "2" [workspace.package] @@ -46,6 +46,7 @@ datafusion-comet-spark-expr = { path = "spark-expr" } datafusion-comet-common = { path = "common" } datafusion-comet-jni-bridge = { path = "jni-bridge" } datafusion-comet-proto = { path = "proto" } +datafusion-comet-shuffle = { path = "shuffle" } chrono = { version = "0.4", default-features = false, features = ["clock"] } chrono-tz = { version = "0.10" } futures = "0.3.32" diff --git a/native/shuffle/Cargo.toml b/native/shuffle/Cargo.toml new file mode 100644 index 0000000000..b446cfe3e2 --- /dev/null +++ b/native/shuffle/Cargo.toml @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-comet-shuffle" +description = "Apache DataFusion Comet: shuffle writer and reader" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +publish = false + +[dependencies] +arrow = { workspace = true } +async-trait = { workspace = true } +bytes = { workspace = true } +crc32fast = "1.3.2" +datafusion = { workspace = true } +datafusion-comet-common = { workspace = true } +datafusion-comet-jni-bridge = { workspace = true } +datafusion-comet-spark-expr = { workspace = true } +futures = { workspace = true } +itertools = "0.14.0" +jni = "0.21" +lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] } +simd-adler32 = "0.3.7" +snap = "1.1" +tokio = { version = "1", features = ["rt-multi-thread"] } +zstd = "0.13.3" + +[dev-dependencies] +criterion = { version = "0.7", features = ["async", "async_tokio", "async_std"] } +datafusion = { workspace = true, features = ["parquet_encryption", "sql"] } +itertools = "0.14.0" +tempfile = "3.26.0" + +[lib] +name = "datafusion_comet_shuffle" +path = "src/lib.rs" + +[[bench]] +name = "shuffle_writer" +harness = false + +[[bench]] +name = "row_columnar" +harness = false diff --git a/native/shuffle/benches/row_columnar.rs b/native/shuffle/benches/row_columnar.rs new file mode 100644 index 0000000000..f0e5818f8b --- /dev/null +++ b/native/shuffle/benches/row_columnar.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +fn main() {} diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs new file mode 100644 index 0000000000..f0e5818f8b --- /dev/null +++ b/native/shuffle/benches/shuffle_writer.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +fn main() {} diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs new file mode 100644 index 0000000000..b248758bc1 --- /dev/null +++ b/native/shuffle/src/lib.rs @@ -0,0 +1,16 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. From 3ce760d0403e2c5f4ccb8f2226c90219485be37b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:25:51 -0600 Subject: [PATCH 4/9] feat: move shuffle source files to new crate --- native/Cargo.lock | 30 +- native/shuffle/Cargo.toml | 1 + native/shuffle/src/codec.rs | 239 +++ native/shuffle/src/comet_partitioning.rs | 71 + native/shuffle/src/lib.rs | 12 + native/shuffle/src/metrics.rs | 61 + native/shuffle/src/partitioners/mod.rs | 35 + .../src/partitioners/multi_partition.rs | 640 +++++++ .../partitioned_batch_iterator.rs | 110 ++ .../src/partitioners/single_partition.rs | 192 ++ native/shuffle/src/shuffle_writer.rs | 696 +++++++ native/shuffle/src/spark_unsafe/list.rs | 485 +++++ native/shuffle/src/spark_unsafe/map.rs | 121 ++ native/shuffle/src/spark_unsafe/mod.rs | 20 + native/shuffle/src/spark_unsafe/row.rs | 1696 +++++++++++++++++ .../shuffle/src/writers/buf_batch_writer.rs | 142 ++ native/shuffle/src/writers/mod.rs | 22 + .../shuffle/src/writers/partition_writer.rs | 124 ++ 18 files changed, 4695 insertions(+), 2 deletions(-) create mode 100644 native/shuffle/src/codec.rs create mode 100644 native/shuffle/src/comet_partitioning.rs create mode 100644 native/shuffle/src/metrics.rs create mode 100644 native/shuffle/src/partitioners/mod.rs create mode 100644 native/shuffle/src/partitioners/multi_partition.rs create mode 100644 
native/shuffle/src/partitioners/partitioned_batch_iterator.rs create mode 100644 native/shuffle/src/partitioners/single_partition.rs create mode 100644 native/shuffle/src/shuffle_writer.rs create mode 100644 native/shuffle/src/spark_unsafe/list.rs create mode 100644 native/shuffle/src/spark_unsafe/map.rs create mode 100644 native/shuffle/src/spark_unsafe/mod.rs create mode 100644 native/shuffle/src/spark_unsafe/row.rs create mode 100644 native/shuffle/src/writers/buf_batch_writer.rs create mode 100644 native/shuffle/src/writers/mod.rs create mode 100644 native/shuffle/src/writers/partition_writer.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 5f99c614b3..2df7677088 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1837,6 +1837,7 @@ dependencies = [ "crc32fast", "criterion", "datafusion", + "datafusion-comet-common", "datafusion-comet-jni-bridge", "datafusion-comet-objectstore-hdfs", "datafusion-comet-proto", @@ -1885,7 +1886,7 @@ dependencies = [ [[package]] name = "datafusion-comet-common" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "datafusion", @@ -1911,7 +1912,7 @@ dependencies = [ [[package]] name = "datafusion-comet-jni-bridge" -version = "0.14.0" +version = "0.15.0" dependencies = [ "arrow", "assertables", @@ -1949,6 +1950,31 @@ dependencies = [ "prost-build", ] +[[package]] +name = "datafusion-comet-shuffle" +version = "0.15.0" +dependencies = [ + "arrow", + "async-trait", + "bytes", + "crc32fast", + "criterion", + "datafusion", + "datafusion-comet-common", + "datafusion-comet-jni-bridge", + "datafusion-comet-spark-expr", + "futures", + "itertools 0.14.0", + "jni", + "log", + "lz4_flex 0.13.0", + "simd-adler32", + "snap", + "tempfile", + "tokio", + "zstd", +] + [[package]] name = "datafusion-comet-spark-expr" version = "0.15.0" diff --git a/native/shuffle/Cargo.toml b/native/shuffle/Cargo.toml index b446cfe3e2..e28827edc2 100644 --- a/native/shuffle/Cargo.toml +++ b/native/shuffle/Cargo.toml @@ -40,6 +40,7 @@ 
datafusion-comet-spark-expr = { workspace = true } futures = { workspace = true } itertools = "0.14.0" jni = "0.21" +log = "0.4" lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] } simd-adler32 = "0.3.7" snap = "1.1" diff --git a/native/shuffle/src/codec.rs b/native/shuffle/src/codec.rs new file mode 100644 index 0000000000..c18489115a --- /dev/null +++ b/native/shuffle/src/codec.rs @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use datafusion_comet_jni_bridge::errors::{CometError, CometResult}; +use arrow::array::RecordBatch; +use arrow::datatypes::Schema; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::StreamWriter; +use bytes::Buf; +use crc32fast::Hasher; +use datafusion::common::DataFusionError; +use datafusion::error::Result; +use datafusion::physical_plan::metrics::Time; +use simd_adler32::Adler32; +use std::io::{Cursor, Seek, SeekFrom, Write}; + +#[derive(Debug, Clone)] +pub enum CompressionCodec { + None, + Lz4Frame, + Zstd(i32), + Snappy, +} + +#[derive(Clone)] +pub struct ShuffleBlockWriter { + codec: CompressionCodec, + header_bytes: Vec, +} + +impl ShuffleBlockWriter { + pub fn try_new(schema: &Schema, codec: CompressionCodec) -> Result { + let header_bytes = Vec::with_capacity(20); + let mut cursor = Cursor::new(header_bytes); + + // leave space for compressed message length + cursor.seek_relative(8)?; + + // write number of columns because JVM side needs to know how many addresses to allocate + let field_count = schema.fields().len(); + cursor.write_all(&field_count.to_le_bytes())?; + + // write compression codec to header + let codec_header = match &codec { + CompressionCodec::Snappy => b"SNAP", + CompressionCodec::Lz4Frame => b"LZ4_", + CompressionCodec::Zstd(_) => b"ZSTD", + CompressionCodec::None => b"NONE", + }; + cursor.write_all(codec_header)?; + + let header_bytes = cursor.into_inner(); + + Ok(Self { + codec, + header_bytes, + }) + } + + /// Writes given record batch as Arrow IPC bytes into given writer. + /// Returns number of bytes written. 
+ pub fn write_batch( + &self, + batch: &RecordBatch, + output: &mut W, + ipc_time: &Time, + ) -> Result { + if batch.num_rows() == 0 { + return Ok(0); + } + + let mut timer = ipc_time.timer(); + let start_pos = output.stream_position()?; + + // write header + output.write_all(&self.header_bytes)?; + + let output = match &self.codec { + CompressionCodec::None => { + let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + arrow_writer.into_inner()? + } + CompressionCodec::Lz4Frame => { + let mut wtr = lz4_flex::frame::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.finish().map_err(|e| { + DataFusionError::Execution(format!("lz4 compression error: {e}")) + })? + } + + CompressionCodec::Zstd(level) => { + let encoder = zstd::Encoder::new(output, *level)?; + let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + let zstd_encoder = arrow_writer.into_inner()?; + zstd_encoder.finish()? + } + + CompressionCodec::Snappy => { + let mut wtr = snap::write::FrameEncoder::new(output); + let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; + arrow_writer.write(batch)?; + arrow_writer.finish()?; + wtr.into_inner().map_err(|e| { + DataFusionError::Execution(format!("snappy compression error: {e}")) + })? + } + }; + + // fill ipc length + let end_pos = output.stream_position()?; + let ipc_length = end_pos - start_pos - 8; + let max_size = i32::MAX as u64; + if ipc_length > max_size { + return Err(DataFusionError::Execution(format!( + "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. 
\ + Try reducing batch size or increasing compression level" + ))); + } + + // fill ipc length + output.seek(SeekFrom::Start(start_pos))?; + output.write_all(&ipc_length.to_le_bytes())?; + output.seek(SeekFrom::Start(end_pos))?; + + timer.stop(); + + Ok((end_pos - start_pos) as usize) + } +} + +pub fn read_ipc_compressed(bytes: &[u8]) -> Result { + match &bytes[0..4] { + b"SNAP" => { + let decoder = snap::read::FrameDecoder::new(&bytes[4..]); + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"LZ4_" => { + let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"ZSTD" => { + let decoder = zstd::Decoder::new(&bytes[4..])?; + let mut reader = + unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + b"NONE" => { + let mut reader = + unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; + reader.next().unwrap().map_err(|e| e.into()) + } + other => Err(DataFusionError::Execution(format!( + "Failed to decode batch: invalid compression codec: {other:?}" + ))), + } +} + +/// Checksum algorithms for writing IPC bytes. +#[derive(Clone)] +pub(crate) enum Checksum { + /// CRC32 checksum algorithm. + CRC32(Hasher), + /// Adler32 checksum algorithm. + Adler32(Adler32), +} + +impl Checksum { + pub(crate) fn try_new(algo: i32, initial_opt: Option) -> CometResult { + match algo { + 0 => { + let hasher = if let Some(initial) = initial_opt { + Hasher::new_with_initial(initial) + } else { + Hasher::new() + }; + Ok(Checksum::CRC32(hasher)) + } + 1 => { + let hasher = if let Some(initial) = initial_opt { + // Note that Adler32 initial state is not zero. + // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`. 
+ Adler32::from_checksum(initial) + } else { + Adler32::new() + }; + Ok(Checksum::Adler32(hasher)) + } + _ => Err(CometError::Internal( + "Unsupported checksum algorithm".to_string(), + )), + } + } + + pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec>) -> CometResult<()> { + match self { + Checksum::CRC32(hasher) => { + std::io::Seek::seek(cursor, SeekFrom::Start(0))?; + hasher.update(cursor.chunk()); + Ok(()) + } + Checksum::Adler32(hasher) => { + std::io::Seek::seek(cursor, SeekFrom::Start(0))?; + hasher.write(cursor.chunk()); + Ok(()) + } + } + } + + pub(crate) fn finalize(self) -> u32 { + match self { + Checksum::CRC32(hasher) => hasher.finalize(), + Checksum::Adler32(hasher) => hasher.finish(), + } + } +} diff --git a/native/shuffle/src/comet_partitioning.rs b/native/shuffle/src/comet_partitioning.rs new file mode 100644 index 0000000000..c269539a62 --- /dev/null +++ b/native/shuffle/src/comet_partitioning.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use arrow::row::{OwnedRow, RowConverter};
+use datafusion::physical_expr::{LexOrdering, PhysicalExpr};
+use std::sync::Arc;
+
+#[derive(Debug, Clone)]
+pub enum CometPartitioning {
+    SinglePartition,
+    /// Allocate rows based on a hash of one or more expressions and the specified number of
+    /// partitions. Args are 1) the expression to hash on, and 2) the number of partitions.
+    Hash(Vec<Arc<dyn PhysicalExpr>>, usize),
+    /// Allocate rows based on the lexical order of one or more expressions and the specified number of
+    /// partitions. Args are 1) the LexOrdering to use to compare values and split into partitions,
+    /// 2) the number of partitions, 3) the RowConverter used to view incoming RecordBatches as Arrow
+    /// Rows for comparing to 4) OwnedRows that represent the boundaries of each partition, used with
+    /// LexOrdering to bin each value in the RecordBatch to a partition.
+    RangePartitioning(LexOrdering, usize, Arc<RowConverter>, Vec<OwnedRow>),
+    /// Round robin partitioning. Distributes rows across partitions by sorting them by hash
+    /// (computed from columns) and then assigning partitions sequentially. Args are:
+    /// 1) number of partitions, 2) max columns to hash (0 means no limit).
+    RoundRobin(usize, usize),
+}
+
+impl CometPartitioning {
+    pub fn partition_count(&self) -> usize {
+        use CometPartitioning::*;
+        match self {
+            SinglePartition => 1,
+            Hash(_, n) | RangePartitioning(_, n, _, _) | RoundRobin(n, _) => *n,
+        }
+    }
+}
+
+pub(crate) fn pmod(hash: u32, n: usize) -> usize {
+    let hash = hash as i32;
+    let n = n as i32;
+    let r = hash % n;
+    let result = if r < 0 { (r + n) % n } else { r };
+    result as usize
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_pmod() {
+        let i: Vec<u32> = vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb];
+        let result = i.into_iter().map(|i| pmod(i, 200)).collect::<Vec<usize>>();
+
+        // expected partition from Spark with n=200
+        let expected = vec![69, 5, 193, 171, 115];
+        assert_eq!(result, expected);
+    }
+}
diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs
index b248758bc1..7c2fc8403f 100644
--- a/native/shuffle/src/lib.rs
+++ b/native/shuffle/src/lib.rs
@@ -14,3 +14,15 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
+
+pub mod codec;
+pub(crate) mod comet_partitioning;
+pub(crate) mod metrics;
+pub(crate) mod partitioners;
+mod shuffle_writer;
+pub mod spark_unsafe;
+pub(crate) mod writers;
+
+pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter};
+pub use comet_partitioning::CometPartitioning;
+pub use shuffle_writer::ShuffleWriterExec;
diff --git a/native/shuffle/src/metrics.rs b/native/shuffle/src/metrics.rs
new file mode 100644
index 0000000000..1aba4677db
--- /dev/null
+++ b/native/shuffle/src/metrics.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::physical_plan::metrics::{ + BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time, +}; + +pub(crate) struct ShufflePartitionerMetrics { + /// metrics + pub(crate) baseline: BaselineMetrics, + + /// Time to perform repartitioning + pub(crate) repart_time: Time, + + /// Time encoding batches to IPC format + pub(crate) encode_time: Time, + + /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics. + pub(crate) write_time: Time, + + /// Number of input batches + pub(crate) input_batches: Count, + + /// count of spills during the execution of the operator + pub(crate) spill_count: Count, + + /// total spilled bytes during the execution of the operator + pub(crate) spilled_bytes: Count, + + /// The original size of spilled data. Different to `spilled_bytes` because of compression. 
+ pub(crate) data_size: Count, +} + +impl ShufflePartitionerMetrics { + pub(crate) fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { + Self { + baseline: BaselineMetrics::new(metrics, partition), + repart_time: MetricBuilder::new(metrics).subset_time("repart_time", partition), + encode_time: MetricBuilder::new(metrics).subset_time("encode_time", partition), + write_time: MetricBuilder::new(metrics).subset_time("write_time", partition), + input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), + spill_count: MetricBuilder::new(metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), + data_size: MetricBuilder::new(metrics).counter("data_size", partition), + } + } +} diff --git a/native/shuffle/src/partitioners/mod.rs b/native/shuffle/src/partitioners/mod.rs new file mode 100644 index 0000000000..a6d589677e --- /dev/null +++ b/native/shuffle/src/partitioners/mod.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod multi_partition; +mod partitioned_batch_iterator; +mod single_partition; + +use arrow::record_batch::RecordBatch; +use datafusion::common::Result; + +pub(crate) use multi_partition::MultiPartitionShuffleRepartitioner; +pub(crate) use partitioned_batch_iterator::PartitionedBatchIterator; +pub(crate) use single_partition::SinglePartitionShufflePartitioner; + +#[async_trait::async_trait] +pub(crate) trait ShufflePartitioner: Send + Sync { + /// Insert a batch into the partitioner + async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>; + /// Write shuffle data and shuffle index file to disk + fn shuffle_write(&mut self) -> Result<()>; +} diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs new file mode 100644 index 0000000000..c83f6fb9c8 --- /dev/null +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -0,0 +1,640 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::partitioned_batch_iterator::{ + PartitionedBatchIterator, PartitionedBatchesProducer, +}; +use crate::partitioners::ShufflePartitioner; +use crate::writers::{BufBatchWriter, PartitionWriter}; +use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; +use datafusion_comet_common::tracing::{with_trace, with_trace_async}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::SchemaRef; +use datafusion::common::utils::proxy::VecAllocExt; +use datafusion::common::DataFusionError; +use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::physical_plan::metrics::Time; +use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes; +use itertools::Itertools; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::fs::{File, OpenOptions}; +use std::io::{BufReader, BufWriter, Seek, Write}; +use std::sync::Arc; +use tokio::time::Instant; + +#[derive(Default)] +struct ScratchSpace { + /// Hashes for each row in the current batch. + hashes_buf: Vec, + /// Partition ids for each row in the current batch. + partition_ids: Vec, + /// The row indices of the rows in each partition. This array is conceptually divided into + /// partitions, where each partition contains the row indices of the rows in that partition. + /// The length of this array is the same as the number of rows in the batch. + partition_row_indices: Vec, + /// The start indices of partitions in partition_row_indices. partition_starts[K] and + /// partition_starts[K + 1] are the start and end indices of partition K in partition_row_indices. + /// The length of this array is 1 + the number of partitions. 
+ partition_starts: Vec, +} + +impl ScratchSpace { + fn map_partition_ids_to_starts_and_indices( + &mut self, + num_output_partitions: usize, + num_rows: usize, + ) { + let partition_ids = &mut self.partition_ids[..num_rows]; + + // count each partition size, while leaving the last extra element as 0 + let partition_counters = &mut self.partition_starts; + partition_counters.resize(num_output_partitions + 1, 0); + partition_counters.fill(0); + partition_ids + .iter() + .for_each(|partition_id| partition_counters[*partition_id as usize] += 1); + + // accumulate partition counters into partition ends + // e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7] + let partition_ends = partition_counters; + let mut accum = 0; + partition_ends.iter_mut().for_each(|v| { + *v += accum; + accum = *v; + }); + + // calculate partition row indices and partition starts + // e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the following partition_row_indices + // and partition_starts arrays: + // + // partition_row_indices: [6, 1, 2, 3, 4, 5, 0] + // partition_starts: [0, 1, 4, 6, 7] + // + // partition_starts conceptually splits partition_row_indices into smaller slices. + // Each slice partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the + // row indices of the input batch that are partitioned into partition K. For example, + // first partition 0 has one row index [6], partition 1 has row indices [1, 2, 3], etc. 
+ let partition_row_indices = &mut self.partition_row_indices; + partition_row_indices.resize(num_rows, 0); + for (index, partition_id) in partition_ids.iter().enumerate().rev() { + partition_ends[*partition_id as usize] -= 1; + let end = partition_ends[*partition_id as usize]; + partition_row_indices[end as usize] = index as u32; + } + + // after calculating, partition ends become partition starts + } +} + +/// A partitioner that uses a hash function to partition data into multiple partitions +pub(crate) struct MultiPartitionShuffleRepartitioner { + output_data_file: String, + output_index_file: String, + buffered_batches: Vec, + partition_indices: Vec>, + partition_writers: Vec, + shuffle_block_writer: ShuffleBlockWriter, + /// Partitioning scheme to use + partitioning: CometPartitioning, + runtime: Arc, + metrics: ShufflePartitionerMetrics, + /// Reused scratch space for computing partition indices + scratch: ScratchSpace, + /// The configured batch size + batch_size: usize, + /// Reservation for repartitioning + reservation: MemoryReservation, + tracing_enabled: bool, + /// Size of the write buffer in bytes + write_buffer_size: usize, +} + +impl MultiPartitionShuffleRepartitioner { + #[allow(clippy::too_many_arguments)] + pub(crate) fn try_new( + partition: usize, + output_data_file: String, + output_index_file: String, + schema: SchemaRef, + partitioning: CometPartitioning, + metrics: ShufflePartitionerMetrics, + runtime: Arc, + batch_size: usize, + codec: CompressionCodec, + tracing_enabled: bool, + write_buffer_size: usize, + ) -> datafusion::common::Result { + let num_output_partitions = partitioning.partition_count(); + assert_ne!( + num_output_partitions, 1, + "Use SinglePartitionShufflePartitioner for 1 output partition." + ); + + // Vectors in the scratch space will be filled with valid values before being used, this + // initialization code is simply initializing the vectors to the desired size. + // The initial values are not used. 
+ let scratch = ScratchSpace { + hashes_buf: match partitioning { + // Allocate hashes_buf for hash and round robin partitioning. + // Round robin hashes all columns to achieve even, deterministic distribution. + CometPartitioning::Hash(_, _) | CometPartitioning::RoundRobin(_, _) => { + vec![0; batch_size] + } + _ => vec![], + }, + partition_ids: vec![0; batch_size], + partition_row_indices: vec![0; batch_size], + partition_starts: vec![0; num_output_partitions + 1], + }; + + let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; + + let partition_writers = (0..num_output_partitions) + .map(|_| PartitionWriter::try_new(shuffle_block_writer.clone())) + .collect::>>()?; + + let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{partition}]")) + .with_can_spill(true) + .register(&runtime.memory_pool); + + Ok(Self { + output_data_file, + output_index_file, + buffered_batches: vec![], + partition_indices: vec![vec![]; num_output_partitions], + partition_writers, + shuffle_block_writer, + partitioning, + runtime, + metrics, + scratch, + batch_size, + reservation, + tracing_enabled, + write_buffer_size, + }) + } + + /// Shuffles rows in input batch into corresponding partition buffer. + /// This function first calculates hashes for rows and then takes rows in same + /// partition as a record batch which is appended into partition buffer. + /// This should not be called directly. Use `insert_batch` instead. + async fn partitioning_batch(&mut self, input: RecordBatch) -> datafusion::common::Result<()> { + if input.num_rows() == 0 { + // skip empty batch + return Ok(()); + } + + if input.num_rows() > self.batch_size { + return Err(DataFusionError::Internal( + "Input batch size exceeds configured batch size. Call `insert_batch` instead." 
+ .to_string(), + )); + } + + // Update data size metric + self.metrics.data_size.add(input.get_array_memory_size()); + + // NOTE: in shuffle writer exec, the output_rows metrics represents the + // number of rows those are written to output data file. + self.metrics.baseline.record_output(input.num_rows()); + + match &self.partitioning { + CometPartitioning::Hash(exprs, num_output_partitions) => { + let mut scratch = std::mem::take(&mut self.scratch); + let (partition_starts, partition_row_indices): (&Vec, &Vec) = { + let mut timer = self.metrics.repart_time.timer(); + + // Evaluate partition expressions to get rows to apply partitioning scheme. + let arrays = exprs + .iter() + .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows())) + .collect::>>()?; + + let num_rows = arrays[0].len(); + + // Use identical seed as Spark hash partitioning. + let hashes_buf = &mut scratch.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + + // Generate partition ids for every row. + { + // Hash arrays and compute partition ids based on number of partitions. + let partition_ids = &mut scratch.partition_ids[..num_rows]; + create_murmur3_hashes(&arrays, hashes_buf)? + .iter() + .enumerate() + .for_each(|(idx, hash)| { + partition_ids[idx] = + comet_partitioning::pmod(*hash, *num_output_partitions) as u32; + }); + } + + // We now have partition ids for every input row, map that to partition starts + // and partition indices to eventually right these rows to partition buffers. 
+ scratch + .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); + + timer.stop(); + Ok::<(&Vec, &Vec), DataFusionError>(( + &scratch.partition_starts, + &scratch.partition_row_indices, + )) + }?; + + self.buffer_partitioned_batch_may_spill( + input, + partition_row_indices, + partition_starts, + ) + .await?; + self.scratch = scratch; + } + CometPartitioning::RangePartitioning( + lex_ordering, + num_output_partitions, + row_converter, + bounds, + ) => { + let mut scratch = std::mem::take(&mut self.scratch); + let (partition_starts, partition_row_indices): (&Vec, &Vec) = { + let mut timer = self.metrics.repart_time.timer(); + + // Evaluate partition expressions for values to apply partitioning scheme on. + let arrays = lex_ordering + .iter() + .map(|expr| expr.expr.evaluate(&input)?.into_array(input.num_rows())) + .collect::>>()?; + + let num_rows = arrays[0].len(); + + // Generate partition ids for every row, first by converting the partition + // arrays to Rows, and then doing binary search for each Row against the + // bounds Rows. + { + let row_batch = row_converter.convert_columns(arrays.as_slice())?; + let partition_ids = &mut scratch.partition_ids[..num_rows]; + + row_batch.iter().enumerate().for_each(|(row_idx, row)| { + partition_ids[row_idx] = bounds + .as_slice() + .partition_point(|bound| bound.row() <= row) + as u32 + }); + } + + // We now have partition ids for every input row, map that to partition starts + // and partition indices to eventually right these rows to partition buffers. 
+ scratch + .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); + + timer.stop(); + Ok::<(&Vec, &Vec), DataFusionError>(( + &scratch.partition_starts, + &scratch.partition_row_indices, + )) + }?; + + self.buffer_partitioned_batch_may_spill( + input, + partition_row_indices, + partition_starts, + ) + .await?; + self.scratch = scratch; + } + CometPartitioning::RoundRobin(num_output_partitions, max_hash_columns) => { + // Comet implements "round robin" as hash partitioning on columns. + // This achieves the same goal as Spark's round robin (even distribution + // without semantic grouping) while being deterministic for fault tolerance. + // + // Note: This produces different partition assignments than Spark's round robin, + // which sorts by UnsafeRow binary representation before assigning partitions. + // However, both approaches provide even distribution and determinism. + let mut scratch = std::mem::take(&mut self.scratch); + let (partition_starts, partition_row_indices): (&Vec, &Vec) = { + let mut timer = self.metrics.repart_time.timer(); + + let num_rows = input.num_rows(); + + // Collect columns for hashing, respecting max_hash_columns limit + // max_hash_columns of 0 means no limit (hash all columns) + // Negative values are normalized to 0 in the planner + let num_columns_to_hash = if *max_hash_columns == 0 { + input.num_columns() + } else { + (*max_hash_columns).min(input.num_columns()) + }; + let columns_to_hash: Vec = (0..num_columns_to_hash) + .map(|i| Arc::clone(input.column(i))) + .collect(); + + // Use identical seed as Spark hash partitioning. 
+ let hashes_buf = &mut scratch.hashes_buf[..num_rows]; + hashes_buf.fill(42_u32); + + // Compute hash for selected columns + create_murmur3_hashes(&columns_to_hash, hashes_buf)?; + + // Assign partition IDs based on hash (same as hash partitioning) + let partition_ids = &mut scratch.partition_ids[..num_rows]; + hashes_buf.iter().enumerate().for_each(|(idx, hash)| { + partition_ids[idx] = + comet_partitioning::pmod(*hash, *num_output_partitions) as u32; + }); + + // We now have partition ids for every input row, map that to partition starts + // and partition indices to eventually write these rows to partition buffers. + scratch + .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); + + timer.stop(); + Ok::<(&Vec, &Vec), DataFusionError>(( + &scratch.partition_starts, + &scratch.partition_row_indices, + )) + }?; + + self.buffer_partitioned_batch_may_spill( + input, + partition_row_indices, + partition_starts, + ) + .await?; + self.scratch = scratch; + } + other => { + // this should be unreachable as long as the validation logic + // in the constructor is kept up-to-date + return Err(DataFusionError::NotImplemented(format!( + "Unsupported shuffle partitioning scheme {other:?}" + ))); + } + } + Ok(()) + } + + async fn buffer_partitioned_batch_may_spill( + &mut self, + input: RecordBatch, + partition_row_indices: &[u32], + partition_starts: &[u32], + ) -> datafusion::common::Result<()> { + let mut mem_growth: usize = input.get_array_memory_size(); + let buffered_partition_idx = self.buffered_batches.len() as u32; + self.buffered_batches.push(input); + + // partition_starts conceptually slices partition_row_indices into smaller slices, + // each slice contains the indices of rows in input that will go into the corresponding + // partition. The following loop iterates over the slices and put the row indices into + // the indices array of the corresponding partition. 
+ for (partition_id, (&start, &end)) in partition_starts + .iter() + .tuple_windows() + .enumerate() + .filter(|(_, (start, end))| start < end) + { + let row_indices = &partition_row_indices[start as usize..end as usize]; + + // Put row indices for the current partition into the indices array of that partition. + // This indices array will be used for calling interleave_record_batch to produce + // shuffled batches. + let indices = &mut self.partition_indices[partition_id]; + let before_size = indices.allocated_size(); + indices.reserve(row_indices.len()); + for row_idx in row_indices { + indices.push((buffered_partition_idx, *row_idx)); + } + let after_size = indices.allocated_size(); + mem_growth += after_size.saturating_sub(before_size); + } + + if self.reservation.try_grow(mem_growth).is_err() { + self.spill()?; + } + + Ok(()) + } + + fn shuffle_write_partition( + partition_iter: &mut PartitionedBatchIterator, + shuffle_block_writer: &mut ShuffleBlockWriter, + output_data: &mut BufWriter, + encode_time: &Time, + write_time: &Time, + write_buffer_size: usize, + batch_size: usize, + ) -> datafusion::common::Result<()> { + let mut buf_batch_writer = BufBatchWriter::new( + shuffle_block_writer, + output_data, + write_buffer_size, + batch_size, + ); + for batch in partition_iter { + let batch = batch?; + buf_batch_writer.write(&batch, encode_time, write_time)?; + } + buf_batch_writer.flush(encode_time, write_time)?; + Ok(()) + } + + fn used(&self) -> usize { + self.reservation.size() + } + + fn spilled_bytes(&self) -> usize { + self.metrics.spilled_bytes.value() + } + + fn spill_count(&self) -> usize { + self.metrics.spill_count.value() + } + + fn data_size(&self) -> usize { + self.metrics.data_size.value() + } + + /// This function transfers the ownership of the buffered batches and partition indices from the + /// ShuffleRepartitioner to a new PartitionedBatches struct. The returned PartitionedBatches struct + /// can be used to produce shuffled batches. 
+ fn partitioned_batches(&mut self) -> PartitionedBatchesProducer { + let num_output_partitions = self.partition_indices.len(); + let buffered_batches = std::mem::take(&mut self.buffered_batches); + // let indices = std::mem::take(&mut self.partition_indices); + let indices = std::mem::replace( + &mut self.partition_indices, + vec![vec![]; num_output_partitions], + ); + PartitionedBatchesProducer::new(buffered_batches, indices, self.batch_size) + } + + pub(crate) fn spill(&mut self) -> datafusion::common::Result<()> { + log::info!( + "ShuffleRepartitioner spilling shuffle data of {} to disk while inserting ({} time(s) so far)", + self.used(), + self.spill_count() + ); + + // we could always get a chance to free some memory as long as we are holding some + if self.buffered_batches.is_empty() { + return Ok(()); + } + + with_trace("shuffle_spill", self.tracing_enabled, || { + let num_output_partitions = self.partition_writers.len(); + let mut partitioned_batches = self.partitioned_batches(); + let mut spilled_bytes = 0; + + for partition_id in 0..num_output_partitions { + let partition_writer = &mut self.partition_writers[partition_id]; + let mut iter = partitioned_batches.produce(partition_id); + spilled_bytes += partition_writer.spill( + &mut iter, + &self.runtime, + &self.metrics, + self.write_buffer_size, + self.batch_size, + )?; + } + + self.reservation.free(); + self.metrics.spill_count.add(1); + self.metrics.spilled_bytes.add(spilled_bytes); + Ok(()) + }) + } + + #[cfg(test)] + pub(crate) fn partition_writers(&self) -> &[PartitionWriter] { + &self.partition_writers + } +} + +#[async_trait::async_trait] +impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { + /// Shuffles rows in input batch into corresponding partition buffer. + /// This function will slice input batch according to configured batch size and then + /// shuffle rows into corresponding partition buffer. 
+ async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> { + with_trace_async("shuffle_insert_batch", self.tracing_enabled, || async { + let start_time = Instant::now(); + let mut start = 0; + while start < batch.num_rows() { + let end = (start + self.batch_size).min(batch.num_rows()); + let batch = batch.slice(start, end - start); + self.partitioning_batch(batch).await?; + start = end; + } + self.metrics.input_batches.add(1); + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + Ok(()) + }) + .await + } + + /// Writes buffered shuffled record batches into Arrow IPC bytes. + fn shuffle_write(&mut self) -> datafusion::common::Result<()> { + with_trace("shuffle_write", self.tracing_enabled, || { + let start_time = Instant::now(); + + let mut partitioned_batches = self.partitioned_batches(); + let num_output_partitions = self.partition_indices.len(); + let mut offsets = vec![0; num_output_partitions + 1]; + + let data_file = self.output_data_file.clone(); + let index_file = self.output_index_file.clone(); + + let output_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(data_file) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; + + let mut output_data = BufWriter::new(output_data); + + #[allow(clippy::needless_range_loop)] + for i in 0..num_output_partitions { + offsets[i] = output_data.stream_position()?; + + // if we wrote a spill file for this partition then copy the + // contents into the shuffle file + if let Some(spill_path) = self.partition_writers[i].path() { + let mut spill_file = BufReader::new(File::open(spill_path)?); + let mut write_timer = self.metrics.write_time.timer(); + std::io::copy(&mut spill_file, &mut output_data)?; + write_timer.stop(); + } + + // Write in memory batches to output data file + let mut partition_iter = partitioned_batches.produce(i); + Self::shuffle_write_partition( + &mut partition_iter, + &mut 
self.shuffle_block_writer, + &mut output_data, + &self.metrics.encode_time, + &self.metrics.write_time, + self.write_buffer_size, + self.batch_size, + )?; + } + + let mut write_timer = self.metrics.write_time.timer(); + output_data.flush()?; + write_timer.stop(); + + // add one extra offset at last to ease partition length computation + offsets[num_output_partitions] = output_data.stream_position()?; + + let mut write_timer = self.metrics.write_time.timer(); + let mut output_index = + BufWriter::new(File::create(index_file).map_err(|e| { + DataFusionError::Execution(format!("shuffle write error: {e:?}")) + })?); + for offset in offsets { + output_index.write_all(&(offset as i64).to_le_bytes()[..])?; + } + output_index.flush()?; + write_timer.stop(); + + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + + Ok(()) + }) + } +} + +impl Debug for MultiPartitionShuffleRepartitioner { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("ShuffleRepartitioner") + .field("memory_used", &self.used()) + .field("spilled_bytes", &self.spilled_bytes()) + .field("spilled_count", &self.spill_count()) + .field("data_size", &self.data_size()) + .finish() + } +} diff --git a/native/shuffle/src/partitioners/partitioned_batch_iterator.rs b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs new file mode 100644 index 0000000000..77010938cd --- /dev/null +++ b/native/shuffle/src/partitioners/partitioned_batch_iterator.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::RecordBatch; +use arrow::compute::interleave_record_batch; +use datafusion::common::DataFusionError; + +/// A helper struct to produce shuffled batches. +/// This struct takes ownership of the buffered batches and partition indices from the +/// ShuffleRepartitioner, and provides an iterator over the batches in the specified partitions. +pub(super) struct PartitionedBatchesProducer { + buffered_batches: Vec, + partition_indices: Vec>, + batch_size: usize, +} + +impl PartitionedBatchesProducer { + pub(super) fn new( + buffered_batches: Vec, + indices: Vec>, + batch_size: usize, + ) -> Self { + Self { + partition_indices: indices, + buffered_batches, + batch_size, + } + } + + pub(super) fn produce(&mut self, partition_id: usize) -> PartitionedBatchIterator<'_> { + PartitionedBatchIterator::new( + &self.partition_indices[partition_id], + &self.buffered_batches, + self.batch_size, + ) + } +} + +pub(crate) struct PartitionedBatchIterator<'a> { + record_batches: Vec<&'a RecordBatch>, + batch_size: usize, + indices: Vec<(usize, usize)>, + pos: usize, +} + +impl<'a> PartitionedBatchIterator<'a> { + fn new( + indices: &'a [(u32, u32)], + buffered_batches: &'a [RecordBatch], + batch_size: usize, + ) -> Self { + if indices.is_empty() { + // Avoid unnecessary allocations when the partition is empty + return Self { + record_batches: vec![], + batch_size, + indices: vec![], + pos: 0, + }; + } + let record_batches = buffered_batches.iter().collect::>(); + let current_indices = indices + .iter() + .map(|(i_batch, i_row)| (*i_batch as usize, 
*i_row as usize)) + .collect::>(); + Self { + record_batches, + batch_size, + indices: current_indices, + pos: 0, + } + } +} + +impl Iterator for PartitionedBatchIterator<'_> { + type Item = datafusion::common::Result; + + fn next(&mut self) -> Option { + if self.pos >= self.indices.len() { + return None; + } + + let indices_end = std::cmp::min(self.pos + self.batch_size, self.indices.len()); + let indices = &self.indices[self.pos..indices_end]; + match interleave_record_batch(&self.record_batches, indices) { + Ok(batch) => { + self.pos = indices_end; + Some(Ok(batch)) + } + Err(e) => Some(Err(DataFusionError::ArrowError( + Box::from(e), + Some(DataFusionError::get_back_trace()), + ))), + } + } +} diff --git a/native/shuffle/src/partitioners/single_partition.rs b/native/shuffle/src/partitioners/single_partition.rs new file mode 100644 index 0000000000..5801ef613b --- /dev/null +++ b/native/shuffle/src/partitioners/single_partition.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::ShufflePartitioner; +use crate::writers::BufBatchWriter; +use crate::{CompressionCodec, ShuffleBlockWriter}; +use arrow::array::RecordBatch; +use arrow::datatypes::SchemaRef; +use datafusion::common::DataFusionError; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Write}; +use tokio::time::Instant; + +/// A partitioner that writes all shuffle data to a single file and a single index file +pub(crate) struct SinglePartitionShufflePartitioner { + // output_data_file: File, + output_data_writer: BufBatchWriter, + output_index_path: String, + /// Batches that are smaller than the batch size and to be concatenated + buffered_batches: Vec, + /// Number of rows in the concatenating batches + num_buffered_rows: usize, + /// Metrics for the repartitioner + metrics: ShufflePartitionerMetrics, + /// The configured batch size + batch_size: usize, +} + +impl SinglePartitionShufflePartitioner { + pub(crate) fn try_new( + output_data_path: String, + output_index_path: String, + schema: SchemaRef, + metrics: ShufflePartitionerMetrics, + batch_size: usize, + codec: CompressionCodec, + write_buffer_size: usize, + ) -> datafusion::common::Result { + let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; + + let output_data_file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(output_data_path)?; + + let output_data_writer = BufBatchWriter::new( + shuffle_block_writer, + output_data_file, + write_buffer_size, + batch_size, + ); + + Ok(Self { + output_data_writer, + output_index_path, + buffered_batches: vec![], + num_buffered_rows: 0, + metrics, + batch_size, + }) + } + + /// Add a batch to the buffer of the partitioner, these buffered batches will be concatenated + /// and written to the output data file when the number of rows in the buffer reaches the batch size. 
+ fn add_buffered_batch(&mut self, batch: RecordBatch) { + self.num_buffered_rows += batch.num_rows(); + self.buffered_batches.push(batch); + } + + /// Consumes buffered batches and return a concatenated batch if successful + fn concat_buffered_batches(&mut self) -> datafusion::common::Result> { + if self.buffered_batches.is_empty() { + Ok(None) + } else if self.buffered_batches.len() == 1 { + let batch = self.buffered_batches.remove(0); + self.num_buffered_rows = 0; + Ok(Some(batch)) + } else { + let schema = &self.buffered_batches[0].schema(); + match arrow::compute::concat_batches(schema, self.buffered_batches.iter()) { + Ok(concatenated) => { + self.buffered_batches.clear(); + self.num_buffered_rows = 0; + Ok(Some(concatenated)) + } + Err(e) => Err(DataFusionError::ArrowError( + Box::from(e), + Some(DataFusionError::get_back_trace()), + )), + } + } + } +} + +#[async_trait::async_trait] +impl ShufflePartitioner for SinglePartitionShufflePartitioner { + async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> { + let start_time = Instant::now(); + let num_rows = batch.num_rows(); + + if num_rows > 0 { + self.metrics.data_size.add(batch.get_array_memory_size()); + self.metrics.baseline.record_output(num_rows); + + if num_rows >= self.batch_size || num_rows + self.num_buffered_rows > self.batch_size { + let concatenated_batch = self.concat_buffered_batches()?; + + // Write the concatenated buffered batch + if let Some(batch) = concatenated_batch { + self.output_data_writer.write( + &batch, + &self.metrics.encode_time, + &self.metrics.write_time, + )?; + } + + if num_rows >= self.batch_size { + // Write the new batch + self.output_data_writer.write( + &batch, + &self.metrics.encode_time, + &self.metrics.write_time, + )?; + } else { + // Add the new batch to the buffer + self.add_buffered_batch(batch); + } + } else { + self.add_buffered_batch(batch); + } + } + + self.metrics.input_batches.add(1); + self.metrics + .baseline + 
.elapsed_compute() + .add_duration(start_time.elapsed()); + Ok(()) + } + + fn shuffle_write(&mut self) -> datafusion::common::Result<()> { + let start_time = Instant::now(); + let concatenated_batch = self.concat_buffered_batches()?; + + // Write the concatenated buffered batch + if let Some(batch) = concatenated_batch { + self.output_data_writer.write( + &batch, + &self.metrics.encode_time, + &self.metrics.write_time, + )?; + } + self.output_data_writer + .flush(&self.metrics.encode_time, &self.metrics.write_time)?; + + // Write index file. It should only contain 2 entries: 0 and the total number of bytes written + let index_file = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(self.output_index_path.clone()) + .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; + let mut index_buf_writer = BufWriter::new(index_file); + let data_file_length = self.output_data_writer.writer_stream_position()?; + for offset in [0, data_file_length] { + index_buf_writer.write_all(&(offset as i64).to_le_bytes()[..])?; + } + index_buf_writer.flush()?; + + self.metrics + .baseline + .elapsed_compute() + .add_duration(start_time.elapsed()); + Ok(()) + } +} diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs new file mode 100644 index 0000000000..4b3f08a826 --- /dev/null +++ b/native/shuffle/src/shuffle_writer.rs @@ -0,0 +1,696 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the External shuffle repartition plan. + +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::{ + MultiPartitionShuffleRepartitioner, ShufflePartitioner, SinglePartitionShufflePartitioner, +}; +use crate::{CometPartitioning, CompressionCodec}; +use datafusion_comet_common::tracing::with_trace_async; +use async_trait::async_trait; +use datafusion::common::exec_datafusion_err; +use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::EmptyRecordBatchStream; +use datafusion::{ + arrow::{datatypes::SchemaRef, error::ArrowError}, + error::Result, + execution::context::TaskContext, + physical_plan::{ + metrics::{ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, + Statistics, + }, +}; +use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use std::{ + any::Any, + fmt, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +/// The shuffle writer operator maps each input partition to M output partitions based on a +/// partitioning scheme. No guarantees are made about the order of the resulting partitions. 
+#[derive(Debug)] +pub struct ShuffleWriterExec { + /// Input execution plan + input: Arc, + /// Partitioning scheme to use + partitioning: CometPartitioning, + /// Output data file path + output_data_file: String, + /// Output index file path + output_index_file: String, + /// Metrics + metrics: ExecutionPlanMetricsSet, + /// Cache for expensive-to-compute plan properties + cache: PlanProperties, + /// The compression codec to use when compressing shuffle blocks + codec: CompressionCodec, + tracing_enabled: bool, + /// Size of the write buffer in bytes + write_buffer_size: usize, +} + +impl ShuffleWriterExec { + /// Create a new ShuffleWriterExec + #[allow(clippy::too_many_arguments)] + pub fn try_new( + input: Arc, + partitioning: CometPartitioning, + codec: CompressionCodec, + output_data_file: String, + output_index_file: String, + tracing_enabled: bool, + write_buffer_size: usize, + ) -> Result { + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&input.schema())), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(ShuffleWriterExec { + input, + partitioning, + metrics: ExecutionPlanMetricsSet::new(), + output_data_file, + output_index_file, + cache, + codec, + tracing_enabled, + write_buffer_size, + }) + } +} + +impl DisplayAs for ShuffleWriterExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "ShuffleWriterExec: partitioning={:?}, compression={:?}", + self.partitioning, self.codec + ) + } + DisplayFormatType::TreeRender => unimplemented!(), + } + } +} + +#[async_trait] +impl ExecutionPlan for ShuffleWriterExec { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ShuffleWriterExec" + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn 
statistics(&self) -> Result { + self.input.partition_statistics(None) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + self.input.schema() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match children.len() { + 1 => Ok(Arc::new(ShuffleWriterExec::try_new( + Arc::clone(&children[0]), + self.partitioning.clone(), + self.codec.clone(), + self.output_data_file.clone(), + self.output_index_file.clone(), + self.tracing_enabled, + self.write_buffer_size, + )?)), + _ => panic!("ShuffleWriterExec wrong number of children"), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let input = self.input.execute(partition, Arc::clone(&context))?; + let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0); + + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once( + external_shuffle( + input, + partition, + self.output_data_file.clone(), + self.output_index_file.clone(), + self.partitioning.clone(), + metrics, + context, + self.codec.clone(), + self.tracing_enabled, + self.write_buffer_size, + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e))), + ) + .try_flatten(), + ))) + } +} + +#[allow(clippy::too_many_arguments)] +async fn external_shuffle( + mut input: SendableRecordBatchStream, + partition: usize, + output_data_file: String, + output_index_file: String, + partitioning: CometPartitioning, + metrics: ShufflePartitionerMetrics, + context: Arc, + codec: CompressionCodec, + tracing_enabled: bool, + write_buffer_size: usize, +) -> Result { + with_trace_async("external_shuffle", tracing_enabled, || async { + let schema = input.schema(); + + let mut repartitioner: Box = match &partitioning { + any if any.partition_count() == 1 => { + Box::new(SinglePartitionShufflePartitioner::try_new( + output_data_file, + 
output_index_file, + Arc::clone(&schema), + metrics, + context.session_config().batch_size(), + codec, + write_buffer_size, + )?) + } + _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( + partition, + output_data_file, + output_index_file, + Arc::clone(&schema), + partitioning, + metrics, + context.runtime_env(), + context.session_config().batch_size(), + codec, + tracing_enabled, + write_buffer_size, + )?), + }; + + while let Some(batch) = input.next().await { + // Await the repartitioner to insert the batch and shuffle the rows + // into the corresponding partition buffer. + // Otherwise, pull the next batch from the input stream might overwrite the + // current batch in the repartitioner. + repartitioner + .insert_batch(batch?) + .await + .map_err(|err| exec_datafusion_err!("Error inserting batch: {err}"))?; + } + + repartitioner + .shuffle_write() + .map_err(|err| exec_datafusion_err!("Error in shuffle write: {err}"))?; + + // shuffle writer always has empty output + Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) as SendableRecordBatchStream) + }) + .await +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{read_ipc_compressed, ShuffleBlockWriter}; + use arrow::array::{Array, StringArray, StringBuilder}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow::row::{RowConverter, SortField}; + use datafusion::datasource::memory::MemorySourceConfig; + use datafusion::datasource::source::DataSourceExec; + use datafusion::execution::config::SessionConfig; + use datafusion::execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; + use datafusion::physical_expr::expressions::{col, Column}; + use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; + use datafusion::physical_plan::common::collect; + use datafusion::physical_plan::metrics::Time; + use datafusion::prelude::SessionContext; + use itertools::Itertools; + use std::io::Cursor; + use tokio::runtime::Runtime; + + #[test] + 
#[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn roundtrip_ipc() { + let batch = create_batch(8192); + for codec in &[ + CompressionCodec::None, + CompressionCodec::Zstd(1), + CompressionCodec::Snappy, + CompressionCodec::Lz4Frame, + ] { + let mut output = vec![]; + let mut cursor = Cursor::new(&mut output); + let writer = + ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone()).unwrap(); + let length = writer + .write_batch(&batch, &mut cursor, &Time::default()) + .unwrap(); + assert_eq!(length, output.len()); + + let ipc_without_length_prefix = &output[16..]; + let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); + assert_eq!(batch, batch2); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_single_partition_shuffle_writer() { + shuffle_write_test(1000, 100, 1, None); + shuffle_write_test(10000, 10, 1, None); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_insert_larger_batch() { + shuffle_write_test(10000, 1, 16, None); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_insert_smaller_batch() { + shuffle_write_test(1000, 1, 16, None); + shuffle_write_test(1000, 10, 16, None); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_large_number_of_partitions() { + shuffle_write_test(10000, 10, 200, Some(10 * 1024 * 1024)); + shuffle_write_test(10000, 10, 2000, Some(10 * 1024 * 1024)); + } + + #[test] + #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` + fn test_large_number_of_partitions_spilling() { + shuffle_write_test(10000, 100, 200, Some(10 * 1024 * 1024)); + } + + #[tokio::test] + async fn shuffle_partitioner_memory() { + let batch = create_batch(900); + assert_eq!(8316, batch.get_array_memory_size()); // Not stable across Arrow versions 
+ + let memory_limit = 512 * 1024; + let num_partitions = 2; + let runtime_env = create_runtime(memory_limit); + let metrics_set = ExecutionPlanMetricsSet::new(); + let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new( + 0, + "/tmp/data.out".to_string(), + "/tmp/index.out".to_string(), + batch.schema(), + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), + ShufflePartitionerMetrics::new(&metrics_set, 0), + runtime_env, + 1024, + CompressionCodec::Lz4Frame, + false, + 1024 * 1024, // write_buffer_size: 1MB default + ) + .unwrap(); + + repartitioner.insert_batch(batch.clone()).await.unwrap(); + + { + let partition_writers = repartitioner.partition_writers(); + assert_eq!(partition_writers.len(), 2); + + assert!(!partition_writers[0].has_spill_file()); + assert!(!partition_writers[1].has_spill_file()); + } + + repartitioner.spill().unwrap(); + + // after spill, there should be spill files + { + let partition_writers = repartitioner.partition_writers(); + assert!(partition_writers[0].has_spill_file()); + assert!(partition_writers[1].has_spill_file()); + } + + // insert another batch after spilling + repartitioner.insert_batch(batch.clone()).await.unwrap(); + } + + fn create_runtime(memory_limit: usize) -> Arc { + Arc::new( + RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .build() + .unwrap(), + ) + } + + fn shuffle_write_test( + batch_size: usize, + num_batches: usize, + num_partitions: usize, + memory_limit: Option, + ) { + let batch = create_batch(batch_size); + + let lex_ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("a", batch.schema().as_ref()).unwrap(), + )]) + .unwrap(); + + let sort_fields: Vec = batch + .columns() + .iter() + .zip(&lex_ordering) + .map(|(array, sort_expr)| { + SortField::new_with_options(array.data_type().clone(), sort_expr.options) + }) + .collect(); + let row_converter = RowConverter::new(sort_fields).unwrap(); + + let owned_rows = if num_partitions == 1 { 
+ vec![] + } else { + // Determine range boundaries based on create_batch implementation. We just divide the + // domain of values in the batch equally to find partition bounds. + let bounds_strings = { + let mut boundaries = Vec::with_capacity(num_partitions - 1); + let step = batch_size as f64 / num_partitions as f64; + + for i in 1..(num_partitions) { + boundaries.push(Some((step * i as f64).round().to_string())); + } + boundaries + }; + let bounds_array: Arc = Arc::new(StringArray::from(bounds_strings)); + let bounds_rows = row_converter + .convert_columns(vec![bounds_array].as_slice()) + .unwrap(); + + let owned_rows_vec = bounds_rows.iter().map(|row| row.owned()).collect_vec(); + owned_rows_vec + }; + + for partitioning in [ + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), + CometPartitioning::RangePartitioning( + lex_ordering, + num_partitions, + Arc::new(row_converter), + owned_rows, + ), + CometPartitioning::RoundRobin(num_partitions, 0), + ] { + let batches = (0..num_batches).map(|_| batch.clone()).collect::>(); + + let partitions = &[batches]; + let exec = ShuffleWriterExec::try_new( + Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(), + ))), + partitioning, + CompressionCodec::Zstd(1), + "/tmp/data.out".to_string(), + "/tmp/index.out".to_string(), + false, + 1024 * 1024, // write_buffer_size: 1MB default + ) + .unwrap(); + + // 10MB memory should be enough for running this test + let config = SessionConfig::new(); + let mut runtime_env_builder = RuntimeEnvBuilder::new(); + runtime_env_builder = match memory_limit { + Some(limit) => runtime_env_builder.with_memory_limit(limit, 1.0), + None => runtime_env_builder, + }; + let runtime_env = Arc::new(runtime_env_builder.build().unwrap()); + let ctx = SessionContext::new_with_config_rt(config, runtime_env); + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + let rt = 
Runtime::new().unwrap(); + rt.block_on(collect(stream)).unwrap(); + } + } + + fn create_batch(batch_size: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..batch_size { + b.append_value(format!("{i}")); + } + let array = b.finish(); + RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() + } + + #[test] + #[cfg_attr(miri, ignore)] + fn test_round_robin_deterministic() { + // Test that round robin partitioning produces identical results when run multiple times + use std::fs; + use std::io::Read; + + let batch_size = 1000; + let num_batches = 10; + let num_partitions = 8; + + let batch = create_batch(batch_size); + let batches = (0..num_batches).map(|_| batch.clone()).collect::>(); + + // Run shuffle twice and compare results + for run in 0..2 { + let data_file = format!("/tmp/rr_data_{}.out", run); + let index_file = format!("/tmp/rr_index_{}.out", run); + + let partitions = std::slice::from_ref(&batches); + let exec = ShuffleWriterExec::try_new( + Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(), + ))), + CometPartitioning::RoundRobin(num_partitions, 0), + CompressionCodec::Zstd(1), + data_file.clone(), + index_file.clone(), + false, + 1024 * 1024, + ) + .unwrap(); + + let config = SessionConfig::new(); + let runtime_env = Arc::new( + RuntimeEnvBuilder::new() + .with_memory_limit(10 * 1024 * 1024, 1.0) + .build() + .unwrap(), + ); + let session_ctx = Arc::new(SessionContext::new_with_config_rt(config, runtime_env)); + let task_ctx = Arc::new(TaskContext::from(session_ctx.as_ref())); + + // Execute the shuffle + futures::executor::block_on(async { + let mut stream = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); + while stream.next().await.is_some() {} + }); + + if run == 1 { + // Compare data files + let mut data0 = Vec::new(); + fs::File::open("/tmp/rr_data_0.out") + 
.unwrap() + .read_to_end(&mut data0) + .unwrap(); + let mut data1 = Vec::new(); + fs::File::open("/tmp/rr_data_1.out") + .unwrap() + .read_to_end(&mut data1) + .unwrap(); + assert_eq!( + data0, data1, + "Round robin shuffle data should be identical across runs" + ); + + // Compare index files + let mut index0 = Vec::new(); + fs::File::open("/tmp/rr_index_0.out") + .unwrap() + .read_to_end(&mut index0) + .unwrap(); + let mut index1 = Vec::new(); + fs::File::open("/tmp/rr_index_1.out") + .unwrap() + .read_to_end(&mut index1) + .unwrap(); + assert_eq!( + index0, index1, + "Round robin shuffle index should be identical across runs" + ); + } + } + + // Clean up + let _ = fs::remove_file("/tmp/rr_data_0.out"); + let _ = fs::remove_file("/tmp/rr_index_0.out"); + let _ = fs::remove_file("/tmp/rr_data_1.out"); + let _ = fs::remove_file("/tmp/rr_index_1.out"); + } + + /// Test that batch coalescing in BufBatchWriter reduces output size by + /// writing fewer, larger IPC blocks instead of many small ones. 
+ #[test] + #[cfg_attr(miri, ignore)] + fn test_batch_coalescing_reduces_size() { + use crate::writers::BufBatchWriter; + use arrow::array::Int32Array; + + // Create a wide schema to amplify per-block schema overhead + let fields: Vec = (0..20) + .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create many small batches (50 rows each) + let small_batches: Vec = (0..100) + .map(|batch_idx| { + let columns: Vec> = (0..20) + .map(|col_idx| { + let values: Vec = (0..50) + .map(|row| batch_idx * 50 + row + col_idx * 1000) + .collect(); + Arc::new(Int32Array::from(values)) as Arc + }) + .collect(); + RecordBatch::try_new(Arc::clone(&schema), columns).unwrap() + }) + .collect(); + + let codec = CompressionCodec::Lz4Frame; + let encode_time = Time::default(); + let write_time = Time::default(); + + // Write with coalescing (batch_size=8192) + let mut coalesced_output = Vec::new(); + { + let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); + let mut buf_writer = BufBatchWriter::new( + &mut writer, + Cursor::new(&mut coalesced_output), + 1024 * 1024, + 8192, + ); + for batch in &small_batches { + buf_writer.write(batch, &encode_time, &write_time).unwrap(); + } + buf_writer.flush(&encode_time, &write_time).unwrap(); + } + + // Write without coalescing (batch_size=1) + let mut uncoalesced_output = Vec::new(); + { + let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); + let mut buf_writer = BufBatchWriter::new( + &mut writer, + Cursor::new(&mut uncoalesced_output), + 1024 * 1024, + 1, + ); + for batch in &small_batches { + buf_writer.write(batch, &encode_time, &write_time).unwrap(); + } + buf_writer.flush(&encode_time, &write_time).unwrap(); + } + + // Coalesced output should be smaller due to fewer IPC schema blocks + assert!( + coalesced_output.len() < uncoalesced_output.len(), + "Coalesced output ({} bytes) should be smaller 
than uncoalesced ({} bytes)", + coalesced_output.len(), + uncoalesced_output.len() + ); + + // Verify both roundtrip correctly by reading all IPC blocks + let coalesced_rows = read_all_ipc_blocks(&coalesced_output); + let uncoalesced_rows = read_all_ipc_blocks(&uncoalesced_output); + assert_eq!( + coalesced_rows, 5000, + "Coalesced should contain all 5000 rows" + ); + assert_eq!( + uncoalesced_rows, 5000, + "Uncoalesced should contain all 5000 rows" + ); + } + + /// Read all IPC blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, + /// returning the total number of rows. + fn read_all_ipc_blocks(data: &[u8]) -> usize { + let mut offset = 0; + let mut total_rows = 0; + while offset < data.len() { + // First 8 bytes are the IPC length (little-endian u64) + let ipc_length = + u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; + // Skip the 8-byte length prefix; the next 8 bytes are field_count + codec header + let block_start = offset + 8; + let block_end = block_start + ipc_length; + // read_ipc_compressed expects data starting after the 16-byte header + // (i.e., after length + field_count), at the codec tag + let ipc_data = &data[block_start + 8..block_end]; + let batch = read_ipc_compressed(ipc_data).unwrap(); + total_rows += batch.num_rows(); + offset = block_end; + } + total_rows + } +} diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs new file mode 100644 index 0000000000..d3cdf0dd04 --- /dev/null +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -0,0 +1,485 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_comet_jni_bridge::errors::CometError; +use crate::spark_unsafe::{ + map::append_map_elements, + row::{ + append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject, + SparkUnsafeRow, + }, +}; +use arrow::array::{ + builder::{ + ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, + Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, + ListBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, + }, + MapBuilder, +}; +use arrow::datatypes::{DataType, TimeUnit}; + +/// Generates bulk append methods for primitive types in SparkUnsafeArray. +/// +/// # Safety invariants for all generated methods: +/// - `element_offset` points to contiguous element data of length `num_elements` +/// - `null_bitset_ptr()` returns a pointer to `ceil(num_elements/64)` i64 words +/// - These invariants are guaranteed by the SparkUnsafeArray layout from the JVM +macro_rules! 
impl_append_to_builder { + ($method_name:ident, $builder_type:ty, $element_type:ty) => { + pub(crate) fn $method_name(&self, builder: &mut $builder_type) { + let num_elements = self.num_elements; + if num_elements == 0 { + return; + } + + if NULLABLE { + let mut ptr = self.element_offset as *const $element_type; + let null_words = self.null_bitset_ptr(); + debug_assert!(!null_words.is_null(), "null_bitset_ptr is null"); + debug_assert!(!ptr.is_null(), "element_offset pointer is null"); + for idx in 0..num_elements { + // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements + let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; + + if is_null { + builder.append_null(); + } else { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { ptr.read_unaligned() }); + } + // SAFETY: ptr stays within bounds, iterating num_elements times + ptr = unsafe { ptr.add(1) }; + } + } else { + // SAFETY: element_offset points to contiguous data of length num_elements + debug_assert!(self.element_offset != 0, "element_offset is null"); + let ptr = self.element_offset as *const $element_type; + // Use bulk copy when data is properly aligned, fall back to + // per-element unaligned reads otherwise + if (ptr as usize).is_multiple_of(std::mem::align_of::<$element_type>()) { + let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; + builder.append_slice(slice); + } else { + let mut ptr = ptr; + for _ in 0..num_elements { + builder.append_value(unsafe { ptr.read_unaligned() }); + ptr = unsafe { ptr.add(1) }; + } + } + } + } + }; +} + +pub struct SparkUnsafeArray { + row_addr: i64, + num_elements: usize, + element_offset: i64, +} + +impl SparkUnsafeObject for SparkUnsafeArray { + #[inline] + fn get_row_addr(&self) -> i64 { + self.row_addr + } + + #[inline] + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { + (self.element_offset + (index * element_size) as i64) as *const u8 + } + + // 
SparkUnsafeArray base address may be unaligned when nested within a row's variable-length + // region, so we must use ptr::read_unaligned() for all typed accesses. + impl_primitive_accessors!(read_unaligned); +} + +impl SparkUnsafeArray { + /// Creates a `SparkUnsafeArray` which points to the given address and size in bytes. + pub fn new(addr: i64) -> Self { + // SAFETY: addr points to valid Spark UnsafeArray data from the JVM. + // The first 8 bytes contain the element count as a little-endian i64. + debug_assert!(addr != 0, "SparkUnsafeArray::new: null address"); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; + let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); + + if num_elements < 0 { + panic!("Negative number of elements: {num_elements}"); + } + + if num_elements > i32::MAX as i64 { + panic!("Number of elements should <= i32::MAX: {num_elements}"); + } + + Self { + row_addr: addr, + num_elements: num_elements as usize, + element_offset: addr + Self::get_header_portion_in_bytes(num_elements), + } + } + + pub(crate) fn get_num_elements(&self) -> usize { + self.num_elements + } + + /// Returns the size of array header in bytes. + #[inline] + const fn get_header_portion_in_bytes(num_fields: i64) -> i64 { + 8 + ((num_fields + 63) / 64) * 8 + } + + /// Returns true if the null bit at the given index of the array is set. + #[inline] + pub(crate) fn is_null_at(&self, index: usize) -> bool { + // SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts + // at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures + // index < num_elements, so word_offset is within the bitset region. 
+ debug_assert!( + index < self.num_elements, + "is_null_at: index {index} >= num_elements {}", + self.num_elements + ); + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; + let word: i64 = word_offset.read_unaligned(); + (word & mask) != 0 + } + } + + /// Returns the null bitset pointer (starts at row_addr + 8). + #[inline] + fn null_bitset_ptr(&self) -> *const i64 { + (self.row_addr + 8) as *const i64 + } + + /// Checks whether the null bit at `idx` is set in the given null bitset pointer. + /// + /// # Safety + /// `null_words` must point to at least `ceil((idx+1)/64)` i64 words. + #[inline] + unsafe fn is_null_in_bitset(null_words: *const i64, idx: usize) -> bool { + let word_idx = idx >> 6; + let bit_idx = idx & 0x3f; + (null_words.add(word_idx).read_unaligned() & (1i64 << bit_idx)) != 0 + } + + impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32); + impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64); + impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16); + impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8); + impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32); + impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64); + + /// Bulk append boolean values to builder. + /// Booleans are stored as 1 byte each in SparkUnsafeArray, requiring special handling. 
+ pub(crate) fn append_booleans_to_builder( + &self, + builder: &mut BooleanBuilder, + ) { + let num_elements = self.num_elements; + if num_elements == 0 { + return; + } + + let mut ptr = self.element_offset as *const u8; + debug_assert!( + !ptr.is_null(), + "append_booleans: element_offset pointer is null" + ); + + if NULLABLE { + let null_words = self.null_bitset_ptr(); + debug_assert!( + !null_words.is_null(), + "append_booleans: null_bitset_ptr is null" + ); + for idx in 0..num_elements { + // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements + let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; + + if is_null { + builder.append_null(); + } else { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { *ptr != 0 }); + } + // SAFETY: ptr stays within bounds, iterating num_elements times + ptr = unsafe { ptr.add(1) }; + } + } else { + for _ in 0..num_elements { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { *ptr != 0 }); + ptr = unsafe { ptr.add(1) }; + } + } + } + + /// Bulk append timestamp values to builder (stored as i64 microseconds). 
+ pub(crate) fn append_timestamps_to_builder( + &self, + builder: &mut TimestampMicrosecondBuilder, + ) { + let num_elements = self.num_elements; + if num_elements == 0 { + return; + } + + if NULLABLE { + let mut ptr = self.element_offset as *const i64; + let null_words = self.null_bitset_ptr(); + debug_assert!( + !null_words.is_null(), + "append_timestamps: null_bitset_ptr is null" + ); + debug_assert!( + !ptr.is_null(), + "append_timestamps: element_offset pointer is null" + ); + for idx in 0..num_elements { + // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements + let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; + + if is_null { + builder.append_null(); + } else { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { ptr.read_unaligned() }); + } + // SAFETY: ptr stays within bounds, iterating num_elements times + ptr = unsafe { ptr.add(1) }; + } + } else { + // SAFETY: element_offset points to contiguous i64 data of length num_elements + debug_assert!( + self.element_offset != 0, + "append_timestamps: element_offset is null" + ); + let ptr = self.element_offset as *const i64; + if (ptr as usize).is_multiple_of(std::mem::align_of::()) { + let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; + builder.append_slice(slice); + } else { + let mut ptr = ptr; + for _ in 0..num_elements { + builder.append_value(unsafe { ptr.read_unaligned() }); + ptr = unsafe { ptr.add(1) }; + } + } + } + } + + /// Bulk append date values to builder (stored as i32 days since epoch). 
+ pub(crate) fn append_dates_to_builder( + &self, + builder: &mut Date32Builder, + ) { + let num_elements = self.num_elements; + if num_elements == 0 { + return; + } + + if NULLABLE { + let mut ptr = self.element_offset as *const i32; + let null_words = self.null_bitset_ptr(); + debug_assert!( + !null_words.is_null(), + "append_dates: null_bitset_ptr is null" + ); + debug_assert!( + !ptr.is_null(), + "append_dates: element_offset pointer is null" + ); + for idx in 0..num_elements { + // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements + let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; + + if is_null { + builder.append_null(); + } else { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { ptr.read_unaligned() }); + } + // SAFETY: ptr stays within bounds, iterating num_elements times + ptr = unsafe { ptr.add(1) }; + } + } else { + // SAFETY: element_offset points to contiguous i32 data of length num_elements + debug_assert!( + self.element_offset != 0, + "append_dates: element_offset is null" + ); + let ptr = self.element_offset as *const i32; + if (ptr as usize).is_multiple_of(std::mem::align_of::()) { + let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; + builder.append_slice(slice); + } else { + let mut ptr = ptr; + for _ in 0..num_elements { + builder.append_value(unsafe { ptr.read_unaligned() }); + ptr = unsafe { ptr.add(1) }; + } + } + } + } +} + +pub fn append_to_builder( + data_type: &DataType, + builder: &mut dyn ArrayBuilder, + array: &SparkUnsafeArray, +) -> Result<(), CometError> { + macro_rules! 
add_values { + ($builder_type:ty, $add_value:expr, $add_null:expr) => { + let builder = downcast_builder_ref!($builder_type, builder); + for idx in 0..array.get_num_elements() { + if NULLABLE && array.is_null_at(idx) { + $add_null(builder); + } else { + $add_value(builder, array, idx); + } + } + }; + } + + match data_type { + DataType::Boolean => { + let builder = downcast_builder_ref!(BooleanBuilder, builder); + array.append_booleans_to_builder::(builder); + } + DataType::Int8 => { + let builder = downcast_builder_ref!(Int8Builder, builder); + array.append_bytes_to_builder::(builder); + } + DataType::Int16 => { + let builder = downcast_builder_ref!(Int16Builder, builder); + array.append_shorts_to_builder::(builder); + } + DataType::Int32 => { + let builder = downcast_builder_ref!(Int32Builder, builder); + array.append_ints_to_builder::(builder); + } + DataType::Int64 => { + let builder = downcast_builder_ref!(Int64Builder, builder); + array.append_longs_to_builder::(builder); + } + DataType::Float32 => { + let builder = downcast_builder_ref!(Float32Builder, builder); + array.append_floats_to_builder::(builder); + } + DataType::Float64 => { + let builder = downcast_builder_ref!(Float64Builder, builder); + array.append_doubles_to_builder::(builder); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let builder = downcast_builder_ref!(TimestampMicrosecondBuilder, builder); + array.append_timestamps_to_builder::(builder); + } + DataType::Date32 => { + let builder = downcast_builder_ref!(Date32Builder, builder); + array.append_dates_to_builder::(builder); + } + DataType::Binary => { + add_values!( + BinaryBuilder, + |builder: &mut BinaryBuilder, values: &SparkUnsafeArray, idx: usize| builder + .append_value(values.get_binary(idx)), + |builder: &mut BinaryBuilder| builder.append_null() + ); + } + DataType::Utf8 => { + add_values!( + StringBuilder, + |builder: &mut StringBuilder, values: &SparkUnsafeArray, idx: usize| builder + 
.append_value(values.get_string(idx)), + |builder: &mut StringBuilder| builder.append_null() + ); + } + DataType::List(field) => { + let builder = downcast_builder_ref!(ListBuilder>, builder); + for idx in 0..array.get_num_elements() { + if NULLABLE && array.is_null_at(idx) { + builder.append_null(); + } else { + let nested_array = array.get_array(idx); + append_list_element(field.data_type(), builder, &nested_array)?; + }; + } + } + DataType::Struct(fields) => { + let builder = downcast_builder_ref!(StructBuilder, builder); + for idx in 0..array.get_num_elements() { + let nested_row = if NULLABLE && array.is_null_at(idx) { + builder.append_null(); + SparkUnsafeRow::default() + } else { + builder.append(true); + array.get_struct(idx, fields.len()) + }; + + for (field_idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), builder, &nested_row, field_idx)?; + } + } + } + DataType::Decimal128(p, _) => { + add_values!( + Decimal128Builder, + |builder: &mut Decimal128Builder, values: &SparkUnsafeArray, idx: usize| builder + .append_value(values.get_decimal(idx, *p)), + |builder: &mut Decimal128Builder| builder.append_null() + ); + } + DataType::Map(field, _) => { + let builder = downcast_builder_ref!( + MapBuilder, Box>, + builder + ); + for idx in 0..array.get_num_elements() { + if NULLABLE && array.is_null_at(idx) { + builder.append(false)?; + } else { + let nested_map = array.get_map(idx); + append_map_elements(field, builder, &nested_map)?; + }; + } + } + _ => { + return Err(CometError::Internal(format!( + "Unsupported map data type: {:?}", + data_type + ))) + } + } + + Ok(()) +} + +/// Appending the given list stored in `SparkUnsafeArray` into `ListBuilder`. +/// `element_dt` is the data type of the list element. `list_builder` is the list builder. +/// `list` is the list stored in `SparkUnsafeArray`. 
+pub fn append_list_element( + element_dt: &DataType, + list_builder: &mut ListBuilder>, + list: &SparkUnsafeArray, +) -> Result<(), CometError> { + append_to_builder::(element_dt, list_builder.values(), list)?; + list_builder.append(true); + + Ok(()) +} diff --git a/native/shuffle/src/spark_unsafe/map.rs b/native/shuffle/src/spark_unsafe/map.rs new file mode 100644 index 0000000000..efc3069a6e --- /dev/null +++ b/native/shuffle/src/spark_unsafe/map.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_comet_jni_bridge::errors::CometError; +use crate::spark_unsafe::list::{append_to_builder, SparkUnsafeArray}; +use arrow::array::builder::{ArrayBuilder, MapBuilder, MapFieldNames}; +use arrow::datatypes::{DataType, FieldRef}; + +pub struct SparkUnsafeMap { + pub(crate) keys: SparkUnsafeArray, + pub(crate) values: SparkUnsafeArray, +} + +impl SparkUnsafeMap { + /// Creates a `SparkUnsafeMap` which points to the given address and size in bytes. + pub(crate) fn new(addr: i64, size: i32) -> Self { + // SAFETY: addr points to valid Spark UnsafeMap data from the JVM. + // The first 8 bytes contain the key array size as a little-endian i64. 
+ debug_assert!(addr != 0, "SparkUnsafeMap::new: null address"); + debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}"); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; + let key_array_size = i64::from_le_bytes(slice.try_into().unwrap()); + + if key_array_size < 0 { + panic!("Negative key size in bytes of map: {key_array_size}"); + } + + if key_array_size > i32::MAX as i64 { + panic!("Number of key size in bytes should <= i32::MAX: {key_array_size}"); + } + + let value_array_size = size - key_array_size as i32 - 8; + if value_array_size < 0 { + panic!("Negative value size in bytes of map: {value_array_size}"); + } + + let keys = SparkUnsafeArray::new(addr + 8); + let values = SparkUnsafeArray::new(addr + 8 + key_array_size); + + if keys.get_num_elements() != values.get_num_elements() { + panic!( + "Number of elements of keys and values should be the same: {} vs {}", + keys.get_num_elements(), + values.get_num_elements() + ); + } + + Self { keys, values } + } +} + +/// Appending the given map stored in `SparkUnsafeMap` into `MapBuilder`. +/// `field` includes data types of the map element. `map_builder` is the map builder. +/// `map` is the map stored in `SparkUnsafeMap`. 
+pub fn append_map_elements( + field: &FieldRef, + map_builder: &mut MapBuilder, Box>, + map: &SparkUnsafeMap, +) -> Result<(), CometError> { + let (key_field, value_field, _) = get_map_key_value_fields(field)?; + + let keys = &map.keys; + let values = &map.values; + + append_to_builder::(key_field.data_type(), map_builder.keys(), keys)?; + + append_to_builder::(value_field.data_type(), map_builder.values(), values)?; + + map_builder.append(true)?; + + Ok(()) +} + +#[allow(clippy::field_reassign_with_default)] +pub fn get_map_key_value_fields( + field: &FieldRef, +) -> Result<(&FieldRef, &FieldRef, MapFieldNames), CometError> { + let mut map_fieldnames = MapFieldNames::default(); + map_fieldnames.entry = field.name().to_string(); + + let (key_field, value_field) = match field.data_type() { + DataType::Struct(fields) => { + if fields.len() != 2 { + return Err(CometError::Internal(format!( + "Map field should have 2 fields, but got {}", + fields.len() + ))); + } + + let key = &fields[0]; + let value = &fields[1]; + + map_fieldnames.key = key.name().to_string(); + map_fieldnames.value = value.name().to_string(); + + (key, value) + } + _ => { + return Err(CometError::Internal(format!( + "Map field should be a struct, but got {:?}", + field.data_type() + ))); + } + }; + + Ok((key_field, value_field, map_fieldnames)) +} diff --git a/native/shuffle/src/spark_unsafe/mod.rs b/native/shuffle/src/spark_unsafe/mod.rs new file mode 100644 index 0000000000..6390a0f231 --- /dev/null +++ b/native/shuffle/src/spark_unsafe/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod list; +mod map; +pub mod row; diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs new file mode 100644 index 0000000000..13a80998db --- /dev/null +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -0,0 +1,1696 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utils for supporting native sort-based columnar shuffle. 
+ +use datafusion_comet_jni_bridge::errors::CometError; +use datafusion_comet_common::bytes_to_i128; +use crate::codec::{Checksum, ShuffleBlockWriter}; +use crate::spark_unsafe::{ + list::{append_list_element, SparkUnsafeArray}, + map::{append_map_elements, get_map_key_value_fields, SparkUnsafeMap}, +}; +use arrow::array::{ + builder::{ + ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, + Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, + Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder, + StructBuilder, TimestampMicrosecondBuilder, + }, + types::Int32Type, + Array, ArrayRef, RecordBatch, RecordBatchOptions, +}; +use arrow::compute::cast; +use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::error::ArrowError; +use datafusion::physical_plan::metrics::Time; +use jni::sys::{jint, jlong}; +use std::{ + fs::OpenOptions, + io::{Cursor, Write}, + str::from_utf8, + sync::Arc, +}; + +const MAX_LONG_DIGITS: u8 = 18; +const NESTED_TYPE_BUILDER_CAPACITY: usize = 100; + +/// A common trait for Spark Unsafe classes that can be used to access the underlying data, +/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to +/// access the underlying data with index. +/// +/// # Safety +/// +/// Implementations must ensure that: +/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory +/// - `get_element_offset()` returns a valid pointer within the row/array data region +/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format +/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership) +/// +/// All accessor methods (get_boolean, get_int, etc.) 
use unsafe pointer operations but are +/// safe to call as long as: +/// - The index is within bounds (caller's responsibility) +/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data +/// +/// # Alignment +/// +/// Primitive accessor methods are implemented separately for each type because they have +/// different alignment guarantees: +/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8, +/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`. +/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's +/// variable-length region, so accessors use `ptr::read_unaligned()`. +pub trait SparkUnsafeObject { + /// Returns the address of the row. + fn get_row_addr(&self) -> i64; + + /// Returns the offset of the element at the given index. + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; + + fn get_boolean(&self, index: usize) -> bool; + fn get_byte(&self, index: usize) -> i8; + fn get_short(&self, index: usize) -> i16; + fn get_int(&self, index: usize) -> i32; + fn get_long(&self, index: usize) -> i64; + fn get_float(&self, index: usize) -> f32; + fn get_double(&self, index: usize) -> f64; + fn get_date(&self, index: usize) -> i32; + fn get_timestamp(&self, index: usize) -> i64; + + /// Returns the offset and length of the element at the given index. + #[inline] + fn get_offset_and_len(&self, index: usize) -> (i32, i32) { + let offset_and_size = self.get_long(index); + let offset = (offset_and_size >> 32) as i32; + let len = offset_and_size as i32; + (offset, len) + } + + /// Returns string value at the given index of the object. + fn get_string(&self, index: usize) -> &str { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid UTF-8 string data within the variable-length region. 
+ // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_string: null address at index {index}"); + debug_assert!( + len >= 0, + "get_string: negative length {len} at index {index}" + ); + let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; + + from_utf8(slice).unwrap() + } + + /// Returns binary value at the given index of the object. + fn get_binary(&self, index: usize) -> &[u8] { + let (offset, len) = self.get_offset_and_len(index); + let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid binary data within the variable-length region. + // Offset and length are read from the fixed-length portion of the row/array. + debug_assert!(addr != 0, "get_binary: null address at index {index}"); + debug_assert!( + len >= 0, + "get_binary: negative length {len} at index {index}" + ); + unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } + } + + /// Returns decimal value at the given index of the object. + fn get_decimal(&self, index: usize, precision: u8) -> i128 { + if precision <= MAX_LONG_DIGITS { + self.get_long(index) as i128 + } else { + let slice = self.get_binary(index); + bytes_to_i128(slice) + } + } + + /// Returns struct value at the given index of the object. + fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow { + let (offset, len) = self.get_offset_and_len(index); + let mut row = SparkUnsafeRow::new_with_num_fields(num_fields); + row.point_to(self.get_row_addr() + offset as i64, len); + + row + } + + /// Returns array value at the given index of the object. 
+ fn get_array(&self, index: usize) -> SparkUnsafeArray { + let (offset, _) = self.get_offset_and_len(index); + SparkUnsafeArray::new(self.get_row_addr() + offset as i64) + } + + fn get_map(&self, index: usize) -> SparkUnsafeMap { + let (offset, len) = self.get_offset_and_len(index); + SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len) + } +} + +/// Generates primitive accessor implementations for `SparkUnsafeObject`. +/// +/// Uses `$read_method` to read typed values from raw pointers: +/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned) +/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray) +macro_rules! impl_primitive_accessors { + ($read_method:ident) => { + #[inline] + fn get_boolean(&self, index: usize) -> bool { + let addr = self.get_element_offset(index, 1); + debug_assert!( + !addr.is_null(), + "get_boolean: null pointer at index {index}" + ); + // SAFETY: addr points to valid element data within the row/array region. + unsafe { *addr != 0 } + } + + #[inline] + fn get_byte(&self, index: usize) -> i8 { + let addr = self.get_element_offset(index, 1); + debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); + // SAFETY: addr points to valid element data (1 byte) within the row/array region. + unsafe { *(addr as *const i8) } + } + + #[inline] + fn get_short(&self, index: usize) -> i16 { + let addr = self.get_element_offset(index, 2) as *const i16; + debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); + // SAFETY: addr points to valid element data (2 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_int(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4) as *const i32; + debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. 
+ unsafe { addr.$read_method() } + } + + #[inline] + fn get_long(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8) as *const i64; + debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_float(&self, index: usize) -> f32 { + let addr = self.get_element_offset(index, 4) as *const f32; + debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_double(&self, index: usize) -> f64 { + let addr = self.get_element_offset(index, 8) as *const f64; + debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_date(&self, index: usize) -> i32 { + let addr = self.get_element_offset(index, 4) as *const i32; + debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. + unsafe { addr.$read_method() } + } + + #[inline] + fn get_timestamp(&self, index: usize) -> i64 { + let addr = self.get_element_offset(index, 8) as *const i64; + debug_assert!( + !addr.is_null(), + "get_timestamp: null pointer at index {index}" + ); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. 
+ unsafe { addr.$read_method() } + } + }; +} +pub(crate) use impl_primitive_accessors; + +pub struct SparkUnsafeRow { + row_addr: i64, + row_size: i32, + row_bitset_width: i64, +} + +impl SparkUnsafeObject for SparkUnsafeRow { + fn get_row_addr(&self) -> i64 { + self.row_addr + } + + fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { + let offset = self.row_bitset_width + (index * 8) as i64; + debug_assert!( + self.row_size >= 0 && offset + element_size as i64 <= self.row_size as i64, + "get_element_offset: access at offset {offset} with size {element_size} \ + exceeds row_size {} for index {index}", + self.row_size + ); + (self.row_addr + offset) as *const u8 + } + + // SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte + // aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is + // 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses. + impl_primitive_accessors!(read); +} + +impl Default for SparkUnsafeRow { + fn default() -> Self { + Self { + row_addr: -1, + row_size: -1, + row_bitset_width: -1, + } + } +} + +impl SparkUnsafeRow { + fn new(schema: &[DataType]) -> Self { + Self { + row_addr: -1, + row_size: -1, + row_bitset_width: Self::get_row_bitset_width(schema.len()) as i64, + } + } + + /// Returns true if the row is a null row. + pub fn is_null_row(&self) -> bool { + self.row_addr == -1 && self.row_size == -1 && self.row_bitset_width == -1 + } + + /// Calculate the width of the bitset for the row in bytes. + /// The logic is from Spark `UnsafeRow.calculateBitSetWidthInBytes`. + #[inline] + pub const fn get_row_bitset_width(num_fields: usize) -> usize { + num_fields.div_ceil(64) * 8 + } + + pub fn new_with_num_fields(num_fields: usize) -> Self { + Self { + row_addr: -1, + row_size: -1, + row_bitset_width: Self::get_row_bitset_width(num_fields) as i64, + } + } + + /// Points the row to the given slice. 
+ pub fn point_to_slice(&mut self, slice: &[u8]) { + self.row_addr = slice.as_ptr() as i64; + self.row_size = slice.len() as i32; + } + + /// Points the row to the given address with specified row size. + fn point_to(&mut self, row_addr: i64, row_size: i32) { + self.row_addr = row_addr; + self.row_size = row_size; + } + + pub fn get_row_size(&self) -> i32 { + self.row_size + } + + /// Returns true if the null bit at the given index of the row is set. + #[inline] + pub(crate) fn is_null_at(&self, index: usize) -> bool { + // SAFETY: row_addr points to valid Spark UnsafeRow data with at least + // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. + // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. + // The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k, + // so word_offset is always 8-byte aligned — we can use aligned ptr::read(). + debug_assert!(self.row_addr != -1, "is_null_at: row not initialized"); + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; + let word: i64 = word_offset.read(); + (word & mask) != 0 + } + } + + /// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null). + pub fn set_not_null_at(&mut self, index: usize) { + // SAFETY: row_addr points to valid Spark UnsafeRow data with at least + // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. + // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. + // Writing is safe because we have mutable access and the memory is owned by the JVM. + // The bitset is always 8-byte aligned — we can use aligned ptr::read()/write(). 
+ debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized"); + unsafe { + let mask: i64 = 1i64 << (index & 0x3f); + let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; + let word: i64 = word_offset.read(); + word_offset.write(word & !mask); + } + } +} + +macro_rules! downcast_builder_ref { + ($builder_type:ty, $builder:expr) => {{ + let actual_type_id = $builder.as_any().type_id(); + $builder + .as_any_mut() + .downcast_mut::<$builder_type>() + .ok_or_else(|| { + CometError::Internal(format!( + "Failed to downcast builder: expected {}, got {:?}", + stringify!($builder_type), + actual_type_id + )) + })? + }}; +} + +macro_rules! get_field_builder { + ($struct_builder:expr, $builder_type:ty, $idx:expr) => { + $struct_builder + .field_builder::<$builder_type>($idx) + .ok_or_else(|| { + CometError::Internal(format!( + "Failed to get field builder at index {}: expected {}", + $idx, + stringify!($builder_type) + )) + })? + }; +} + +// Expose the macro for other modules. +use crate::CompressionCodec; +pub(crate) use downcast_builder_ref; + +/// Appends field of row to the given struct builder. `dt` is the data type of the field. +/// `struct_builder` is the struct builder of the row. `row` is the row that contains the field. +/// `idx` is the index of the field in the row. The caller is responsible for ensuring that the +/// `struct_builder.append` is called before/after calling this function to append the null buffer +/// of the struct array. +#[allow(clippy::redundant_closure_call)] +pub(super) fn append_field( + dt: &DataType, + struct_builder: &mut StructBuilder, + row: &SparkUnsafeRow, + idx: usize, +) -> Result<(), CometError> { + /// A macro for generating code of appending value into field builder of Arrow struct builder. + macro_rules! 
append_field_to_builder { + ($builder_type:ty, $accessor:expr) => {{ + let field_builder = get_field_builder!(struct_builder, $builder_type, idx); + + if row.is_null_row() { + // The row is null. + field_builder.append_null(); + } else { + let is_null = row.is_null_at(idx); + + if is_null { + // The field in the row is null. + // Append a null value to the field builder. + field_builder.append_null(); + } else { + $accessor(field_builder); + } + } + }}; + } + + match dt { + DataType::Boolean => { + append_field_to_builder!(BooleanBuilder, |builder: &mut BooleanBuilder| builder + .append_value(row.get_boolean(idx))); + } + DataType::Int8 => { + append_field_to_builder!(Int8Builder, |builder: &mut Int8Builder| builder + .append_value(row.get_byte(idx))); + } + DataType::Int16 => { + append_field_to_builder!(Int16Builder, |builder: &mut Int16Builder| builder + .append_value(row.get_short(idx))); + } + DataType::Int32 => { + append_field_to_builder!(Int32Builder, |builder: &mut Int32Builder| builder + .append_value(row.get_int(idx))); + } + DataType::Int64 => { + append_field_to_builder!(Int64Builder, |builder: &mut Int64Builder| builder + .append_value(row.get_long(idx))); + } + DataType::Float32 => { + append_field_to_builder!(Float32Builder, |builder: &mut Float32Builder| builder + .append_value(row.get_float(idx))); + } + DataType::Float64 => { + append_field_to_builder!(Float64Builder, |builder: &mut Float64Builder| builder + .append_value(row.get_double(idx))); + } + DataType::Date32 => { + append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder + .append_value(row.get_date(idx))); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_field_to_builder!( + TimestampMicrosecondBuilder, + |builder: &mut TimestampMicrosecondBuilder| builder + .append_value(row.get_timestamp(idx)) + ); + } + DataType::Binary => { + append_field_to_builder!(BinaryBuilder, |builder: &mut BinaryBuilder| builder + .append_value(row.get_binary(idx))); + } 
+ DataType::Utf8 => { + append_field_to_builder!(StringBuilder, |builder: &mut StringBuilder| builder + .append_value(row.get_string(idx))); + } + DataType::Decimal128(p, _) => { + append_field_to_builder!(Decimal128Builder, |builder: &mut Decimal128Builder| builder + .append_value(row.get_decimal(idx, *p))); + } + DataType::Struct(fields) => { + // Appending value into struct field builder of Arrow struct builder. + let field_builder = get_field_builder!(struct_builder, StructBuilder, idx); + + let nested_row = if row.is_null_row() || row.is_null_at(idx) { + // The row is null, or the field in the row is null, i.e., a null nested row. + // Append a null value to the row builder. + field_builder.append_null(); + SparkUnsafeRow::default() + } else { + field_builder.append(true); + row.get_struct(idx, fields.len()) + }; + + for (field_idx, field) in fields.into_iter().enumerate() { + append_field(field.data_type(), field_builder, &nested_row, field_idx)?; + } + } + DataType::Map(field, _) => { + let field_builder = get_field_builder!( + struct_builder, + MapBuilder, Box>, + idx + ); + + if row.is_null_row() { + // The row is null. + field_builder.append(false)?; + } else { + let is_null = row.is_null_at(idx); + + if is_null { + // The field in the row is null. + // Append a null value to the map builder. + field_builder.append(false)?; + } else { + append_map_elements(field, field_builder, &row.get_map(idx))?; + } + } + } + DataType::List(field) => { + let field_builder = + get_field_builder!(struct_builder, ListBuilder>, idx); + + if row.is_null_row() { + // The row is null. + field_builder.append_null(); + } else { + let is_null = row.is_null_at(idx); + + if is_null { + // The field in the row is null. + // Append a null value to the list builder. + field_builder.append_null(); + } else { + append_list_element(field.data_type(), field_builder, &row.get_array(idx))? 
+ } + } + } + _ => { + unreachable!("Unsupported data type of struct field: {:?}", dt) + } + } + + Ok(()) +} + +/// Appends nested struct fields to the struct builder using field-major order. +/// This is a helper function for processing nested struct fields recursively. +/// +/// Unlike `append_struct_fields_field_major`, this function takes slices of row addresses, +/// sizes, and null flags directly, without needing to navigate from a parent row. +#[allow(clippy::redundant_closure_call)] +fn append_nested_struct_fields_field_major( + row_addresses: &[jlong], + row_sizes: &[jint], + struct_is_null: &[bool], + struct_builder: &mut StructBuilder, + fields: &arrow::datatypes::Fields, +) -> Result<(), CometError> { + let num_rows = row_addresses.len(); + let mut row = SparkUnsafeRow::new_with_num_fields(fields.len()); + + // Helper macro for processing primitive fields + macro_rules! process_field { + ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ + let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + // Struct is null, field is also null + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at($field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value($get_value(&row, $field_idx)); + } + } + } + }}; + } + + // Process each field across all rows + for (field_idx, field) in fields.iter().enumerate() { + match field.data_type() { + DataType::Boolean => { + process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_boolean(idx)); + } + DataType::Int8 => { + process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_byte(idx)); + } + DataType::Int16 => { + process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_short(idx)); + } + DataType::Int32 => { + 
process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_int(idx)); + } + DataType::Int64 => { + process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_long(idx)); + } + DataType::Float32 => { + process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_float(idx)); + } + DataType::Float64 => { + process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_double(idx)); + } + DataType::Date32 => { + process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_date(idx)); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_field!( + TimestampMicrosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) + ); + } + DataType::Binary => { + let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_binary(field_idx)); + } + } + } + } + DataType::Utf8 => { + let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_string(field_idx)); + } + } + } + } + DataType::Decimal128(p, _) => { + let p = *p; + let field_builder = + get_field_builder!(struct_builder, Decimal128Builder, field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = 
row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_decimal(field_idx, p)); + } + } + } + } + DataType::Struct(nested_fields) => { + let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); + + // Collect nested struct addresses and sizes in one pass, building validity + let mut nested_addresses: Vec = Vec::with_capacity(num_rows); + let mut nested_sizes: Vec = Vec::with_capacity(num_rows); + let mut nested_is_null: Vec = Vec::with_capacity(num_rows); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + // Parent struct is null, nested struct is also null + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + nested_builder.append(true); + nested_is_null.push(false); + // Get nested struct address and size + let nested_row = row.get_struct(field_idx, nested_fields.len()); + nested_addresses.push(nested_row.get_row_addr()); + nested_sizes.push(nested_row.get_row_size()); + } + } + } + + // Recursively process nested struct fields in field-major order + append_nested_struct_fields_field_major( + &nested_addresses, + &nested_sizes, + &nested_is_null, + nested_builder, + nested_fields, + )?; + } + // For list and map, fall back to append_field since they have variable-length elements + dt @ (DataType::List(_) | DataType::Map(_, _)) => { + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + let null_row = SparkUnsafeRow::default(); + append_field(dt, struct_builder, &null_row, field_idx)?; + } else { + let row_addr = 
row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + append_field(dt, struct_builder, &row, field_idx)?; + } + } + } + _ => { + unreachable!( + "Unsupported data type of struct field: {:?}", + field.data_type() + ) + } + } + } + + Ok(()) +} + +/// Reads row address and size from JVM-provided pointer arrays and points the row to that data. +/// +/// # Safety +/// Caller must ensure row_addresses_ptr and row_sizes_ptr are valid for index i. +/// This is guaranteed when called from append_columns with indices in [row_start, row_end). +macro_rules! read_row_at { + ($row:expr, $row_addresses_ptr:expr, $row_sizes_ptr:expr, $i:expr) => {{ + // SAFETY: Caller guarantees pointers are valid for this index (see macro doc) + debug_assert!( + !$row_addresses_ptr.is_null(), + "read_row_at: null row_addresses_ptr" + ); + debug_assert!(!$row_sizes_ptr.is_null(), "read_row_at: null row_sizes_ptr"); + let row_addr = unsafe { *$row_addresses_ptr.add($i) }; + let row_size = unsafe { *$row_sizes_ptr.add($i) }; + $row.point_to(row_addr, row_size); + }}; +} + +/// Appends a batch of list values to the list builder with a single type dispatch. +/// This moves type dispatch from O(rows) to O(1), significantly improving performance +/// for large batches. +#[allow(clippy::too_many_arguments)] +fn append_list_column_batch( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &[DataType], + column_idx: usize, + element_type: &DataType, + list_builder: &mut ListBuilder>, +) -> Result<(), CometError> { + let mut row = SparkUnsafeRow::new(schema); + + // Helper macro for primitive element types - gets builder fresh each iteration + // to avoid borrow conflicts with list_builder.append() + macro_rules! 
process_primitive_lists { + ($builder_type:ty, $append_fn:ident) => {{ + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + list_builder.append_null(); + } else { + let array = row.get_array(column_idx); + // Get values builder fresh each iteration to avoid borrow conflict + let values_builder = list_builder + .values() + .as_any_mut() + .downcast_mut::<$builder_type>() + .expect(stringify!($builder_type)); + array.$append_fn::(values_builder); + list_builder.append(true); + } + } + }}; + } + + match element_type { + DataType::Boolean => { + process_primitive_lists!(BooleanBuilder, append_booleans_to_builder); + } + DataType::Int8 => { + process_primitive_lists!(Int8Builder, append_bytes_to_builder); + } + DataType::Int16 => { + process_primitive_lists!(Int16Builder, append_shorts_to_builder); + } + DataType::Int32 => { + process_primitive_lists!(Int32Builder, append_ints_to_builder); + } + DataType::Int64 => { + process_primitive_lists!(Int64Builder, append_longs_to_builder); + } + DataType::Float32 => { + process_primitive_lists!(Float32Builder, append_floats_to_builder); + } + DataType::Float64 => { + process_primitive_lists!(Float64Builder, append_doubles_to_builder); + } + DataType::Date32 => { + process_primitive_lists!(Date32Builder, append_dates_to_builder); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_primitive_lists!(TimestampMicrosecondBuilder, append_timestamps_to_builder); + } + // For complex element types, fall back to per-row dispatch + _ => { + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + list_builder.append_null(); + } else { + append_list_element(element_type, list_builder, &row.get_array(column_idx))?; + } + } + } + } + + Ok(()) +} + +/// Appends a batch of map values to the map builder with a single type dispatch. 
+/// This moves type dispatch from O(rows × 2) to O(2), improving performance for maps. +#[allow(clippy::too_many_arguments)] +fn append_map_column_batch( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &[DataType], + column_idx: usize, + field: &arrow::datatypes::FieldRef, + map_builder: &mut MapBuilder, Box>, +) -> Result<(), CometError> { + let mut row = SparkUnsafeRow::new(schema); + let (key_field, value_field, _) = get_map_key_value_fields(field)?; + let key_type = key_field.data_type(); + let value_type = value_field.data_type(); + + // Helper macro for processing maps with primitive key/value types + // Uses scoped borrows to avoid borrow checker conflicts + macro_rules! process_primitive_maps { + ($key_builder:ty, $key_append:ident, $val_builder:ty, $val_append:ident) => {{ + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + map_builder.append(false)?; + } else { + let map = row.get_map(column_idx); + // Process keys in a scope so borrow ends + { + let keys_builder = map_builder + .keys() + .as_any_mut() + .downcast_mut::<$key_builder>() + .expect(stringify!($key_builder)); + map.keys.$key_append::(keys_builder); + } + // Process values in a scope so borrow ends + { + let values_builder = map_builder + .values() + .as_any_mut() + .downcast_mut::<$val_builder>() + .expect(stringify!($val_builder)); + map.values.$val_append::(values_builder); + } + map_builder.append(true)?; + } + } + }}; + } + + // Optimize common map type combinations + match (key_type, value_type) { + // Map + (DataType::Int64, DataType::Int64) => { + process_primitive_maps!( + Int64Builder, + append_longs_to_builder, + Int64Builder, + append_longs_to_builder + ); + } + // Map + (DataType::Int64, DataType::Float64) => { + process_primitive_maps!( + Int64Builder, + append_longs_to_builder, + Float64Builder, + append_doubles_to_builder + ); + } + // Map + 
(DataType::Int32, DataType::Int32) => { + process_primitive_maps!( + Int32Builder, + append_ints_to_builder, + Int32Builder, + append_ints_to_builder + ); + } + // Map + (DataType::Int32, DataType::Int64) => { + process_primitive_maps!( + Int32Builder, + append_ints_to_builder, + Int64Builder, + append_longs_to_builder + ); + } + // For other types, fall back to per-row dispatch + _ => { + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + map_builder.append(false)?; + } else { + append_map_elements(field, map_builder, &row.get_map(column_idx))?; + } + } + } + } + + Ok(()) +} + +/// Appends struct fields to the struct builder using field-major order. +/// This processes one field at a time across all rows, which moves type dispatch +/// outside the row loop (O(fields) dispatches instead of O(rows × fields)). +#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] +fn append_struct_fields_field_major( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + parent_row: &mut SparkUnsafeRow, + column_idx: usize, + struct_builder: &mut StructBuilder, + fields: &arrow::datatypes::Fields, +) -> Result<(), CometError> { + let num_rows = row_end - row_start; + let num_fields = fields.len(); + + // First pass: Build struct validity and collect which structs are null + // We use a Vec for simplicity; could use a bitset for better memory + let mut struct_is_null = Vec::with_capacity(num_rows); + + for i in row_start..row_end { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + + let is_null = parent_row.is_null_at(column_idx); + struct_is_null.push(is_null); + + if is_null { + struct_builder.append_null(); + } else { + struct_builder.append(true); + } + } + + // Helper macro for processing primitive fields + macro_rules! 
process_field { + ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ + let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + // Struct is null, field is also null + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at($field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value($get_value(&nested_row, $field_idx)); + } + } + } + }}; + } + + // Second pass: Process each field across all rows + for (field_idx, field) in fields.iter().enumerate() { + match field.data_type() { + DataType::Boolean => { + process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_boolean(idx)); + } + DataType::Int8 => { + process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_byte(idx)); + } + DataType::Int16 => { + process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_short(idx)); + } + DataType::Int32 => { + process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_int(idx)); + } + DataType::Int64 => { + process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_long(idx)); + } + DataType::Float32 => { + process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_float(idx)); + } + DataType::Float64 => { + process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_double(idx)); + } + DataType::Date32 => { + process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_date(idx)); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_field!( + TimestampMicrosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) + ); + } + DataType::Binary => { + let field_builder = 
get_field_builder!(struct_builder, BinaryBuilder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_binary(field_idx)); + } + } + } + } + DataType::Utf8 => { + let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_string(field_idx)); + } + } + } + } + DataType::Decimal128(p, _) => { + let p = *p; + let field_builder = + get_field_builder!(struct_builder, Decimal128Builder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_decimal(field_idx, p)); + } + } + } + } + // For nested structs, apply field-major processing recursively + DataType::Struct(nested_fields) => { + let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); + + // Collect nested struct addresses and sizes in one pass, building validity + let mut nested_addresses: Vec = Vec::with_capacity(num_rows); + let mut nested_sizes: Vec = Vec::with_capacity(num_rows); + let mut 
nested_is_null: Vec = Vec::with_capacity(num_rows); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + // Parent struct is null, nested struct is also null + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let parent_struct = parent_row.get_struct(column_idx, num_fields); + + if parent_struct.is_null_at(field_idx) { + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + nested_builder.append(true); + nested_is_null.push(false); + // Get nested struct address and size + let nested_row = + parent_struct.get_struct(field_idx, nested_fields.len()); + nested_addresses.push(nested_row.get_row_addr()); + nested_sizes.push(nested_row.get_row_size()); + } + } + } + + // Recursively process nested struct fields in field-major order + append_nested_struct_fields_field_major( + &nested_addresses, + &nested_sizes, + &nested_is_null, + nested_builder, + nested_fields, + )?; + } + // For list and map, fall back to append_field since they have variable-length elements + dt @ (DataType::List(_) | DataType::Map(_, _)) => { + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + let null_row = SparkUnsafeRow::default(); + append_field(dt, struct_builder, &null_row, field_idx)?; + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + append_field(dt, struct_builder, &nested_row, field_idx)?; + } + } + } + _ => { + unreachable!( + "Unsupported data type of struct field: {:?}", + field.data_type() + ) + } + } + } + + Ok(()) +} + +/// Appends column of top rows to the given array builder. 
+/// +/// # Safety +/// +/// The caller must ensure: +/// - `row_addresses_ptr` points to an array of at least `row_end` jlong values +/// - `row_sizes_ptr` points to an array of at least `row_end` jint values +/// - Each address in `row_addresses_ptr[row_start..row_end]` points to valid Spark UnsafeRow data +/// - The memory remains valid for the duration of this function call +/// +/// These invariants are guaranteed when called from JNI with arrays provided by the JVM. +#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] +fn append_columns( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &[DataType], + column_idx: usize, + builder: &mut Box, + prefer_dictionary_ratio: f64, +) -> Result<(), CometError> { + /// A macro for generating code of appending values into Arrow array builders. + macro_rules! append_column_to_builder { + ($builder_type:ty, $accessor:expr) => {{ + let element_builder = builder + .as_any_mut() + .downcast_mut::<$builder_type>() + .expect(stringify!($builder_type)); + let mut row = SparkUnsafeRow::new(schema); + + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + let is_null = row.is_null_at(column_idx); + + if is_null { + // The element value is null. + // Append a null value to the element builder. 
+ element_builder.append_null(); + } else { + $accessor(element_builder, &row, column_idx); + } + } + }}; + } + + let dt = &schema[column_idx]; + + match dt { + DataType::Boolean => { + append_column_to_builder!( + BooleanBuilder, + |builder: &mut BooleanBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_boolean(idx)) + ); + } + DataType::Int8 => { + append_column_to_builder!( + Int8Builder, + |builder: &mut Int8Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_byte(idx)) + ); + } + DataType::Int16 => { + append_column_to_builder!( + Int16Builder, + |builder: &mut Int16Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_short(idx)) + ); + } + DataType::Int32 => { + append_column_to_builder!( + Int32Builder, + |builder: &mut Int32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_int(idx)) + ); + } + DataType::Int64 => { + append_column_to_builder!( + Int64Builder, + |builder: &mut Int64Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_long(idx)) + ); + } + DataType::Float32 => { + append_column_to_builder!( + Float32Builder, + |builder: &mut Float32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_float(idx)) + ); + } + DataType::Float64 => { + append_column_to_builder!( + Float64Builder, + |builder: &mut Float64Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_double(idx)) + ); + } + DataType::Decimal128(p, _) => { + append_column_to_builder!( + Decimal128Builder, + |builder: &mut Decimal128Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_decimal(idx, *p)) + ); + } + DataType::Utf8 => { + if prefer_dictionary_ratio > 1.0 { + append_column_to_builder!( + StringDictionaryBuilder, + |builder: &mut StringDictionaryBuilder, + row: &SparkUnsafeRow, + idx| builder.append_value(row.get_string(idx)) + ); + } else { + append_column_to_builder!( + StringBuilder, + |builder: &mut StringBuilder, row: &SparkUnsafeRow, idx| builder + 
.append_value(row.get_string(idx)) + ); + } + } + DataType::Binary => { + if prefer_dictionary_ratio > 1.0 { + append_column_to_builder!( + BinaryDictionaryBuilder, + |builder: &mut BinaryDictionaryBuilder, + row: &SparkUnsafeRow, + idx| builder.append_value(row.get_binary(idx)) + ); + } else { + append_column_to_builder!( + BinaryBuilder, + |builder: &mut BinaryBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_binary(idx)) + ); + } + } + DataType::Date32 => { + append_column_to_builder!( + Date32Builder, + |builder: &mut Date32Builder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_date(idx)) + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + append_column_to_builder!( + TimestampMicrosecondBuilder, + |builder: &mut TimestampMicrosecondBuilder, row: &SparkUnsafeRow, idx| builder + .append_value(row.get_timestamp(idx)) + ); + } + DataType::Map(field, _) => { + let map_builder = downcast_builder_ref!( + MapBuilder, Box>, + builder + ); + // Use batched processing for better performance + append_map_column_batch( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + schema, + column_idx, + field, + map_builder, + )?; + } + DataType::List(field) => { + let list_builder = downcast_builder_ref!(ListBuilder>, builder); + // Use batched processing for better performance + append_list_column_batch( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + schema, + column_idx, + field.data_type(), + list_builder, + )?; + } + DataType::Struct(fields) => { + let struct_builder = builder + .as_any_mut() + .downcast_mut::() + .expect("StructBuilder"); + let mut row = SparkUnsafeRow::new(schema); + + // Use field-major processing to avoid per-row type dispatch + append_struct_fields_field_major( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + &mut row, + column_idx, + struct_builder, + fields, + )?; + } + _ => { + unreachable!("Unsupported data type of column: {:?}", dt) + } + } + + Ok(()) +} + +fn 
make_builders( + dt: &DataType, + row_num: usize, + prefer_dictionary_ratio: f64, +) -> Result, CometError> { + let builder: Box = match dt { + DataType::Boolean => Box::new(BooleanBuilder::with_capacity(row_num)), + DataType::Int8 => Box::new(Int8Builder::with_capacity(row_num)), + DataType::Int16 => Box::new(Int16Builder::with_capacity(row_num)), + DataType::Int32 => Box::new(Int32Builder::with_capacity(row_num)), + DataType::Int64 => Box::new(Int64Builder::with_capacity(row_num)), + DataType::Float32 => Box::new(Float32Builder::with_capacity(row_num)), + DataType::Float64 => Box::new(Float64Builder::with_capacity(row_num)), + DataType::Decimal128(_, _) => { + Box::new(Decimal128Builder::with_capacity(row_num).with_data_type(dt.clone())) + } + DataType::Utf8 => { + if prefer_dictionary_ratio > 1.0 { + Box::new(StringDictionaryBuilder::::with_capacity( + row_num / 2, + row_num, + 1024, + )) + } else { + Box::new(StringBuilder::with_capacity(row_num, 1024)) + } + } + DataType::Binary => { + if prefer_dictionary_ratio > 1.0 { + Box::new(BinaryDictionaryBuilder::::with_capacity( + row_num / 2, + row_num, + 1024, + )) + } else { + Box::new(BinaryBuilder::with_capacity(row_num, 1024)) + } + } + DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)), + DataType::Timestamp(TimeUnit::Microsecond, _) => { + Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) + } + DataType::Map(field, _) => { + let (key_field, value_field, map_field_names) = get_map_key_value_fields(field)?; + let key_dt = key_field.data_type(); + let value_dt = value_field.data_type(); + let key_builder = make_builders(key_dt, NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; + let value_builder = make_builders(value_dt, NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; + + Box::new( + MapBuilder::new(Some(map_field_names), key_builder, value_builder) + .with_values_field(Arc::clone(value_field)), + ) + } + DataType::List(field) => { + // Disable dictionary encoding for array 
element + let value_builder = + make_builders(field.data_type(), NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; + + // Needed to overwrite default ListBuilder creation having the incoming field schema to be driving + let value_field = Arc::clone(field); + + Box::new(ListBuilder::new(value_builder).with_field(value_field)) + } + DataType::Struct(fields) => { + let field_builders = fields + .iter() + // Disable dictionary encoding for struct fields + .map(|field| make_builders(field.data_type(), row_num, 1.0)) + .collect::, _>>()?; + + Box::new(StructBuilder::new(fields.clone(), field_builders)) + } + _ => return Err(CometError::Internal(format!("Unsupported type: {dt:?}"))), + }; + + Ok(builder) +} + +/// Processes a sorted row partition and writes the result to the given output path. +#[allow(clippy::too_many_arguments)] +pub fn process_sorted_row_partition( + row_num: usize, + batch_size: usize, + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + schema: &[DataType], + output_path: String, + prefer_dictionary_ratio: f64, + checksum_enabled: bool, + checksum_algo: i32, + // This is the checksum value passed in from Spark side, and is getting updated for + // each shuffle partition Spark processes. It is called "initial" here to indicate + // this is the initial checksum for this method, as it also gets updated iteratively + // inside the loop within the method across batches. + initial_checksum: Option, + codec: &CompressionCodec, +) -> Result<(i64, Option), CometError> { + // The current row number we are reading + let mut current_row = 0; + // Total number of bytes written + let mut written = 0; + // The current checksum value. This is updated incrementally in the following loop. + let mut current_checksum = if checksum_enabled { + Some(Checksum::try_new(checksum_algo, initial_checksum)?) + } else { + None + }; + + // Create builders once and reuse them across batches. + // After finish() is called, builders are reset and can be reused. 
+ let mut data_builders: Vec> = vec![]; + schema.iter().try_for_each(|dt| { + make_builders(dt, batch_size, prefer_dictionary_ratio) + .map(|builder| data_builders.push(builder))?; + Ok::<(), CometError>(()) + })?; + + // Open the output file once and reuse it across batches + let mut output_data = OpenOptions::new() + .create(true) + .append(true) + .open(&output_path)?; + + // Reusable buffer for serialized batch data + let mut frozen: Vec = Vec::new(); + + while current_row < row_num { + let n = std::cmp::min(batch_size, row_num - current_row); + + // Appends rows to the array builders. + // For each column, iterating over rows and appending values to corresponding array + // builder. + for (idx, builder) in data_builders.iter_mut().enumerate() { + append_columns( + row_addresses_ptr, + row_sizes_ptr, + current_row, + current_row + n, + schema, + idx, + builder, + prefer_dictionary_ratio, + )?; + } + + // Writes a record batch generated from the array builders to the output file. + // Note: builder_to_array calls finish() which resets the builder, making it reusable for the next batch. 
+ let array_refs: Result, _> = data_builders + .iter_mut() + .zip(schema.iter()) + .map(|(builder, datatype)| builder_to_array(builder, datatype, prefer_dictionary_ratio)) + .collect(); + let batch = make_batch(array_refs?, n)?; + + frozen.clear(); + let mut cursor = Cursor::new(&mut frozen); + + // we do not collect metrics in Native_writeSortedFileNative + let ipc_time = Time::default(); + let block_writer = ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone())?; + written += block_writer.write_batch(&batch, &mut cursor, &ipc_time)?; + + if let Some(checksum) = &mut current_checksum { + checksum.update(&mut cursor)?; + } + + output_data.write_all(&frozen)?; + current_row += n; + } + + Ok((written as i64, current_checksum.map(|c| c.finalize()))) +} + +fn builder_to_array( + builder: &mut Box, + datatype: &DataType, + prefer_dictionary_ratio: f64, +) -> Result { + match datatype { + // We don't have redundant dictionary values which are not referenced by any key. + // So the reasonable ratio must be larger than 1.0. + DataType::Utf8 if prefer_dictionary_ratio > 1.0 => { + let builder = builder + .as_any_mut() + .downcast_mut::>() + .expect("StringDictionaryBuilder"); + + let dict_array = builder.finish(); + let num_keys = dict_array.keys().len(); + let num_values = dict_array.values().len(); + + if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { + // The number of keys in the dictionary is less than a ratio of the number of + // values. The dictionary is efficient, so we return it directly. + Ok(Arc::new(dict_array)) + } else { + // If the dictionary is not efficient, we convert it to a plain string array. + Ok(cast(&dict_array, &DataType::Utf8)?) 
+ } + } + DataType::Binary if prefer_dictionary_ratio > 1.0 => { + let builder = builder + .as_any_mut() + .downcast_mut::>() + .expect("BinaryDictionaryBuilder"); + + let dict_array = builder.finish(); + let num_keys = dict_array.keys().len(); + let num_values = dict_array.values().len(); + + if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { + // The number of keys in the dictionary is less than a ratio of the number of + // values. The dictionary is efficient, so we return it directly. + Ok(Arc::new(dict_array)) + } else { + // If the dictionary is not efficient, we convert it to a plain string array. + Ok(cast(&dict_array, &DataType::Binary)?) + } + } + _ => Ok(builder.finish()), + } +} + +fn make_batch(arrays: Vec, row_count: usize) -> Result { + let fields = arrays + .iter() + .enumerate() + .map(|(i, array)| Field::new(format!("c{i}"), array.data_type().clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + let options = RecordBatchOptions::new().with_row_count(Option::from(row_count)); + RecordBatch::try_new_with_options(schema, arrays, &options) +} + +#[cfg(test)] +mod test { + use arrow::datatypes::Fields; + + use super::*; + + #[test] + fn test_append_null_row_to_struct_builder() { + let data_type = DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Boolean, true), + ])); + let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]); + let mut struct_builder = StructBuilder::from_fields(fields, 1); + let row = SparkUnsafeRow::default(); + append_field(&data_type, &mut struct_builder, &row, 0).expect("append field"); + struct_builder.append_null(); + let struct_array = struct_builder.finish(); + assert_eq!(struct_array.len(), 1); + assert!(struct_array.is_null(0)); + } + + #[test] + fn test_append_null_struct_field_to_struct_builder() { + let data_type = DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Boolean, true), + 
Field::new("b", DataType::Boolean, true), + ])); + let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]); + let mut struct_builder = StructBuilder::from_fields(fields, 1); + let mut row = SparkUnsafeRow::new_with_num_fields(1); + // 8 bytes null bitset + 8 bytes field value = 16 bytes + // Set bit 0 in the null bitset to mark field 0 as null + // Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned) + #[repr(align(8))] + struct Aligned([u8; 16]); + let mut data = Aligned([0u8; 16]); + data.0[0] = 1; + row.point_to_slice(&data.0); + append_field(&data_type, &mut struct_builder, &row, 0).expect("append field"); + struct_builder.append_null(); + let struct_array = struct_builder.finish(); + assert_eq!(struct_array.len(), 1); + assert!(struct_array.is_null(0)); + } +} diff --git a/native/shuffle/src/writers/buf_batch_writer.rs b/native/shuffle/src/writers/buf_batch_writer.rs new file mode 100644 index 0000000000..6344a8e5f2 --- /dev/null +++ b/native/shuffle/src/writers/buf_batch_writer.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::ShuffleBlockWriter; +use arrow::array::RecordBatch; +use arrow::compute::kernels::coalesce::BatchCoalescer; +use datafusion::physical_plan::metrics::Time; +use std::borrow::Borrow; +use std::io::{Cursor, Seek, SeekFrom, Write}; + +/// Write batches to writer while using a buffer to avoid frequent system calls. +/// The record batches were first written by ShuffleBlockWriter into an internal buffer. +/// Once the buffer exceeds the max size, the buffer will be flushed to the writer. +/// +/// Small batches are coalesced using Arrow's [`BatchCoalescer`] before serialization, +/// producing exactly `batch_size`-row output batches to reduce per-block IPC schema overhead. +/// The coalescer is lazily initialized on the first write. +pub(crate) struct BufBatchWriter, W: Write> { + shuffle_block_writer: S, + writer: W, + buffer: Vec, + buffer_max_size: usize, + /// Coalesces small batches into target_batch_size before serialization. + /// Lazily initialized on first write to capture the schema. + coalescer: Option, + /// Target batch size for coalescing + batch_size: usize, +} + +impl, W: Write> BufBatchWriter { + pub(crate) fn new( + shuffle_block_writer: S, + writer: W, + buffer_max_size: usize, + batch_size: usize, + ) -> Self { + Self { + shuffle_block_writer, + writer, + buffer: vec![], + buffer_max_size, + coalescer: None, + batch_size, + } + } + + pub(crate) fn write( + &mut self, + batch: &RecordBatch, + encode_time: &Time, + write_time: &Time, + ) -> datafusion::common::Result { + let coalescer = self + .coalescer + .get_or_insert_with(|| BatchCoalescer::new(batch.schema(), self.batch_size)); + coalescer.push_batch(batch.clone())?; + + // Drain completed batches into a local vec so the coalescer borrow ends + // before we call write_batch_to_buffer (which borrows &mut self). 
+ let mut completed = Vec::new(); + while let Some(batch) = coalescer.next_completed_batch() { + completed.push(batch); + } + + let mut bytes_written = 0; + for batch in &completed { + bytes_written += self.write_batch_to_buffer(batch, encode_time, write_time)?; + } + Ok(bytes_written) + } + + /// Serialize a single batch into the byte buffer, flushing to the writer if needed. + fn write_batch_to_buffer( + &mut self, + batch: &RecordBatch, + encode_time: &Time, + write_time: &Time, + ) -> datafusion::common::Result { + let mut cursor = Cursor::new(&mut self.buffer); + cursor.seek(SeekFrom::End(0))?; + let bytes_written = + self.shuffle_block_writer + .borrow() + .write_batch(batch, &mut cursor, encode_time)?; + let pos = cursor.position(); + if pos >= self.buffer_max_size as u64 { + let mut write_timer = write_time.timer(); + self.writer.write_all(&self.buffer)?; + write_timer.stop(); + self.buffer.clear(); + } + Ok(bytes_written) + } + + pub(crate) fn flush( + &mut self, + encode_time: &Time, + write_time: &Time, + ) -> datafusion::common::Result<()> { + // Finish any remaining buffered rows in the coalescer + let mut remaining = Vec::new(); + if let Some(coalescer) = &mut self.coalescer { + coalescer.finish_buffered_batch()?; + while let Some(batch) = coalescer.next_completed_batch() { + remaining.push(batch); + } + } + for batch in &remaining { + self.write_batch_to_buffer(batch, encode_time, write_time)?; + } + + // Flush the byte buffer to the underlying writer + let mut write_timer = write_time.timer(); + if !self.buffer.is_empty() { + self.writer.write_all(&self.buffer)?; + } + self.writer.flush()?; + write_timer.stop(); + self.buffer.clear(); + Ok(()) + } +} + +impl, W: Write + Seek> BufBatchWriter { + pub(crate) fn writer_stream_position(&mut self) -> datafusion::common::Result { + self.writer.stream_position().map_err(Into::into) + } +} diff --git a/native/shuffle/src/writers/mod.rs b/native/shuffle/src/writers/mod.rs new file mode 100644 index 
0000000000..b58989e46c --- /dev/null +++ b/native/shuffle/src/writers/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod buf_batch_writer; +mod partition_writer; + +pub(crate) use buf_batch_writer::BufBatchWriter; +pub(crate) use partition_writer::PartitionWriter; diff --git a/native/shuffle/src/writers/partition_writer.rs b/native/shuffle/src/writers/partition_writer.rs new file mode 100644 index 0000000000..48017871db --- /dev/null +++ b/native/shuffle/src/writers/partition_writer.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::metrics::ShufflePartitionerMetrics; +use crate::partitioners::PartitionedBatchIterator; +use crate::writers::buf_batch_writer::BufBatchWriter; +use crate::ShuffleBlockWriter; +use datafusion::common::DataFusionError; +use datafusion::execution::disk_manager::RefCountedTempFile; +use datafusion::execution::runtime_env::RuntimeEnv; +use std::fs::{File, OpenOptions}; + +struct SpillFile { + temp_file: RefCountedTempFile, + file: File, +} + +pub(crate) struct PartitionWriter { + /// Spill file for intermediate shuffle output for this partition. Each spill event + /// will append to this file and the contents will be copied to the shuffle file at + /// the end of processing. 
+ spill_file: Option, + /// Writer that performs encoding and compression + shuffle_block_writer: ShuffleBlockWriter, +} + +impl PartitionWriter { + pub(crate) fn try_new( + shuffle_block_writer: ShuffleBlockWriter, + ) -> datafusion::common::Result { + Ok(Self { + spill_file: None, + shuffle_block_writer, + }) + } + + fn ensure_spill_file_created( + &mut self, + runtime: &RuntimeEnv, + ) -> datafusion::common::Result<()> { + if self.spill_file.is_none() { + // Spill file is not yet created, create it + let spill_file = runtime + .disk_manager + .create_tmp_file("shuffle writer spill")?; + let spill_data = OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(spill_file.path()) + .map_err(|e| { + DataFusionError::Execution(format!("Error occurred while spilling {e}")) + })?; + self.spill_file = Some(SpillFile { + temp_file: spill_file, + file: spill_data, + }); + } + Ok(()) + } + + pub(crate) fn spill( + &mut self, + iter: &mut PartitionedBatchIterator, + runtime: &RuntimeEnv, + metrics: &ShufflePartitionerMetrics, + write_buffer_size: usize, + batch_size: usize, + ) -> datafusion::common::Result { + if let Some(batch) = iter.next() { + self.ensure_spill_file_created(runtime)?; + + let total_bytes_written = { + let mut buf_batch_writer = BufBatchWriter::new( + &mut self.shuffle_block_writer, + &mut self.spill_file.as_mut().unwrap().file, + write_buffer_size, + batch_size, + ); + let mut bytes_written = + buf_batch_writer.write(&batch?, &metrics.encode_time, &metrics.write_time)?; + for batch in iter { + let batch = batch?; + bytes_written += buf_batch_writer.write( + &batch, + &metrics.encode_time, + &metrics.write_time, + )?; + } + buf_batch_writer.flush(&metrics.encode_time, &metrics.write_time)?; + bytes_written + }; + + Ok(total_bytes_written) + } else { + Ok(0) + } + } + + pub(crate) fn path(&self) -> Option<&std::path::Path> { + self.spill_file + .as_ref() + .map(|spill_file| spill_file.temp_file.path()) + } + + #[cfg(test)] + pub(crate) 
fn has_spill_file(&self) -> bool { + self.spill_file.is_some() + } +} From b75f8566bd5a49bec930256e128a0e9a72ef980e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:27:30 -0600 Subject: [PATCH 5/9] feat: update core to depend on shuffle crate Replace the local shuffle module in core with a re-export of the new datafusion-comet-shuffle crate, removing the now-redundant source files. --- native/Cargo.lock | 1 + native/core/Cargo.toml | 1 + native/core/src/execution/mod.rs | 2 +- native/core/src/execution/shuffle/codec.rs | 239 --- .../execution/shuffle/comet_partitioning.rs | 71 - native/core/src/execution/shuffle/metrics.rs | 61 - native/core/src/execution/shuffle/mod.rs | 28 - .../src/execution/shuffle/partitioners/mod.rs | 35 - .../shuffle/partitioners/multi_partition.rs | 642 ------- .../partitioned_batch_iterator.rs | 110 -- .../shuffle/partitioners/single_partition.rs | 192 -- .../src/execution/shuffle/shuffle_writer.rs | 696 ------- .../execution/shuffle/spark_unsafe/list.rs | 487 ----- .../src/execution/shuffle/spark_unsafe/map.rs | 123 -- .../src/execution/shuffle/spark_unsafe/mod.rs | 20 - .../src/execution/shuffle/spark_unsafe/row.rs | 1702 ----------------- .../shuffle/writers/buf_batch_writer.rs | 142 -- .../core/src/execution/shuffle/writers/mod.rs | 22 - .../shuffle/writers/partition_writer.rs | 124 -- 19 files changed, 3 insertions(+), 4695 deletions(-) delete mode 100644 native/core/src/execution/shuffle/codec.rs delete mode 100644 native/core/src/execution/shuffle/comet_partitioning.rs delete mode 100644 native/core/src/execution/shuffle/metrics.rs delete mode 100644 native/core/src/execution/shuffle/mod.rs delete mode 100644 native/core/src/execution/shuffle/partitioners/mod.rs delete mode 100644 native/core/src/execution/shuffle/partitioners/multi_partition.rs delete mode 100644 native/core/src/execution/shuffle/partitioners/partitioned_batch_iterator.rs delete mode 100644 
native/core/src/execution/shuffle/partitioners/single_partition.rs delete mode 100644 native/core/src/execution/shuffle/shuffle_writer.rs delete mode 100644 native/core/src/execution/shuffle/spark_unsafe/list.rs delete mode 100644 native/core/src/execution/shuffle/spark_unsafe/map.rs delete mode 100644 native/core/src/execution/shuffle/spark_unsafe/mod.rs delete mode 100644 native/core/src/execution/shuffle/spark_unsafe/row.rs delete mode 100644 native/core/src/execution/shuffle/writers/buf_batch_writer.rs delete mode 100644 native/core/src/execution/shuffle/writers/mod.rs delete mode 100644 native/core/src/execution/shuffle/writers/partition_writer.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 2df7677088..60798c8458 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1841,6 +1841,7 @@ dependencies = [ "datafusion-comet-jni-bridge", "datafusion-comet-objectstore-hdfs", "datafusion-comet-proto", + "datafusion-comet-shuffle", "datafusion-comet-spark-expr", "datafusion-datasource", "datafusion-functions-nested", diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 1da8bed207..256bf39e20 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -68,6 +68,7 @@ datafusion-comet-common = { workspace = true } datafusion-comet-spark-expr = { workspace = true } datafusion-comet-jni-bridge = { workspace = true } datafusion-comet-proto = { workspace = true } +datafusion-comet-shuffle = { workspace = true } object_store = { workspace = true } url = { workspace = true } aws-config = { workspace = true } diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs index 85fc672461..f556fce41c 100644 --- a/native/core/src/execution/mod.rs +++ b/native/core/src/execution/mod.rs @@ -23,7 +23,7 @@ pub(crate) mod metrics; pub mod operators; pub(crate) mod planner; pub mod serde; -pub mod shuffle; +pub use datafusion_comet_shuffle as shuffle; pub(crate) mod sort; pub(crate) mod spark_plan; pub use 
datafusion_comet_spark_expr::timezone; diff --git a/native/core/src/execution/shuffle/codec.rs b/native/core/src/execution/shuffle/codec.rs deleted file mode 100644 index 33e6989d4c..0000000000 --- a/native/core/src/execution/shuffle/codec.rs +++ /dev/null @@ -1,239 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::errors::{CometError, CometResult}; -use arrow::array::RecordBatch; -use arrow::datatypes::Schema; -use arrow::ipc::reader::StreamReader; -use arrow::ipc::writer::StreamWriter; -use bytes::Buf; -use crc32fast::Hasher; -use datafusion::common::DataFusionError; -use datafusion::error::Result; -use datafusion::physical_plan::metrics::Time; -use simd_adler32::Adler32; -use std::io::{Cursor, Seek, SeekFrom, Write}; - -#[derive(Debug, Clone)] -pub enum CompressionCodec { - None, - Lz4Frame, - Zstd(i32), - Snappy, -} - -#[derive(Clone)] -pub struct ShuffleBlockWriter { - codec: CompressionCodec, - header_bytes: Vec, -} - -impl ShuffleBlockWriter { - pub fn try_new(schema: &Schema, codec: CompressionCodec) -> Result { - let header_bytes = Vec::with_capacity(20); - let mut cursor = Cursor::new(header_bytes); - - // leave space for compressed message length - cursor.seek_relative(8)?; - - // write number of columns because JVM side needs to know how many addresses to allocate - let field_count = schema.fields().len(); - cursor.write_all(&field_count.to_le_bytes())?; - - // write compression codec to header - let codec_header = match &codec { - CompressionCodec::Snappy => b"SNAP", - CompressionCodec::Lz4Frame => b"LZ4_", - CompressionCodec::Zstd(_) => b"ZSTD", - CompressionCodec::None => b"NONE", - }; - cursor.write_all(codec_header)?; - - let header_bytes = cursor.into_inner(); - - Ok(Self { - codec, - header_bytes, - }) - } - - /// Writes given record batch as Arrow IPC bytes into given writer. - /// Returns number of bytes written. 
- pub fn write_batch( - &self, - batch: &RecordBatch, - output: &mut W, - ipc_time: &Time, - ) -> Result { - if batch.num_rows() == 0 { - return Ok(0); - } - - let mut timer = ipc_time.timer(); - let start_pos = output.stream_position()?; - - // write header - output.write_all(&self.header_bytes)?; - - let output = match &self.codec { - CompressionCodec::None => { - let mut arrow_writer = StreamWriter::try_new(output, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - arrow_writer.into_inner()? - } - CompressionCodec::Lz4Frame => { - let mut wtr = lz4_flex::frame::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.finish().map_err(|e| { - DataFusionError::Execution(format!("lz4 compression error: {e}")) - })? - } - - CompressionCodec::Zstd(level) => { - let encoder = zstd::Encoder::new(output, *level)?; - let mut arrow_writer = StreamWriter::try_new(encoder, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - let zstd_encoder = arrow_writer.into_inner()?; - zstd_encoder.finish()? - } - - CompressionCodec::Snappy => { - let mut wtr = snap::write::FrameEncoder::new(output); - let mut arrow_writer = StreamWriter::try_new(&mut wtr, &batch.schema())?; - arrow_writer.write(batch)?; - arrow_writer.finish()?; - wtr.into_inner().map_err(|e| { - DataFusionError::Execution(format!("snappy compression error: {e}")) - })? - } - }; - - // fill ipc length - let end_pos = output.stream_position()?; - let ipc_length = end_pos - start_pos - 8; - let max_size = i32::MAX as u64; - if ipc_length > max_size { - return Err(DataFusionError::Execution(format!( - "Shuffle block size {ipc_length} exceeds maximum size of {max_size}. 
\ - Try reducing batch size or increasing compression level" - ))); - } - - // fill ipc length - output.seek(SeekFrom::Start(start_pos))?; - output.write_all(&ipc_length.to_le_bytes())?; - output.seek(SeekFrom::Start(end_pos))?; - - timer.stop(); - - Ok((end_pos - start_pos) as usize) - } -} - -pub fn read_ipc_compressed(bytes: &[u8]) -> Result { - match &bytes[0..4] { - b"SNAP" => { - let decoder = snap::read::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"LZ4_" => { - let decoder = lz4_flex::frame::FrameDecoder::new(&bytes[4..]); - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"ZSTD" => { - let decoder = zstd::Decoder::new(&bytes[4..])?; - let mut reader = - unsafe { StreamReader::try_new(decoder, None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - b"NONE" => { - let mut reader = - unsafe { StreamReader::try_new(&bytes[4..], None)?.with_skip_validation(true) }; - reader.next().unwrap().map_err(|e| e.into()) - } - other => Err(DataFusionError::Execution(format!( - "Failed to decode batch: invalid compression codec: {other:?}" - ))), - } -} - -/// Checksum algorithms for writing IPC bytes. -#[derive(Clone)] -pub(crate) enum Checksum { - /// CRC32 checksum algorithm. - CRC32(Hasher), - /// Adler32 checksum algorithm. - Adler32(Adler32), -} - -impl Checksum { - pub(crate) fn try_new(algo: i32, initial_opt: Option) -> CometResult { - match algo { - 0 => { - let hasher = if let Some(initial) = initial_opt { - Hasher::new_with_initial(initial) - } else { - Hasher::new() - }; - Ok(Checksum::CRC32(hasher)) - } - 1 => { - let hasher = if let Some(initial) = initial_opt { - // Note that Adler32 initial state is not zero. - // i.e., `Adler32::from_checksum(0)` is not the same as `Adler32::new()`. 
- Adler32::from_checksum(initial) - } else { - Adler32::new() - }; - Ok(Checksum::Adler32(hasher)) - } - _ => Err(CometError::Internal( - "Unsupported checksum algorithm".to_string(), - )), - } - } - - pub(crate) fn update(&mut self, cursor: &mut Cursor<&mut Vec>) -> CometResult<()> { - match self { - Checksum::CRC32(hasher) => { - std::io::Seek::seek(cursor, SeekFrom::Start(0))?; - hasher.update(cursor.chunk()); - Ok(()) - } - Checksum::Adler32(hasher) => { - std::io::Seek::seek(cursor, SeekFrom::Start(0))?; - hasher.write(cursor.chunk()); - Ok(()) - } - } - } - - pub(crate) fn finalize(self) -> u32 { - match self { - Checksum::CRC32(hasher) => hasher.finalize(), - Checksum::Adler32(hasher) => hasher.finish(), - } - } -} diff --git a/native/core/src/execution/shuffle/comet_partitioning.rs b/native/core/src/execution/shuffle/comet_partitioning.rs deleted file mode 100644 index b8d68cd21e..0000000000 --- a/native/core/src/execution/shuffle/comet_partitioning.rs +++ /dev/null @@ -1,71 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use arrow::row::{OwnedRow, RowConverter}; -use datafusion::physical_expr::{LexOrdering, PhysicalExpr}; -use std::sync::Arc; - -#[derive(Debug, Clone)] -pub enum CometPartitioning { - SinglePartition, - /// Allocate rows based on a hash of one of more expressions and the specified number of - /// partitions. Args are 1) the expression to hash on, and 2) the number of partitions. - Hash(Vec>, usize), - /// Allocate rows based on the lexical order of one of more expressions and the specified number of - /// partitions. Args are 1) the LexOrdering to use to compare values and split into partitions, - /// 2) the number of partitions, 3) the RowConverter used to view incoming RecordBatches as Arrow - /// Rows for comparing to 4) OwnedRows that represent the boundaries of each partition, used with - /// LexOrdering to bin each value in the RecordBatch to a partition. - RangePartitioning(LexOrdering, usize, Arc, Vec), - /// Round robin partitioning. Distributes rows across partitions by sorting them by hash - /// (computed from columns) and then assigning partitions sequentially. Args are: - /// 1) number of partitions, 2) max columns to hash (0 means no limit). 
- RoundRobin(usize, usize), -} - -impl CometPartitioning { - pub fn partition_count(&self) -> usize { - use CometPartitioning::*; - match self { - SinglePartition => 1, - Hash(_, n) | RangePartitioning(_, n, _, _) | RoundRobin(n, _) => *n, - } - } -} - -pub(super) fn pmod(hash: u32, n: usize) -> usize { - let hash = hash as i32; - let n = n as i32; - let r = hash % n; - let result = if r < 0 { (r + n) % n } else { r }; - result as usize -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_pmod() { - let i: Vec = vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb]; - let result = i.into_iter().map(|i| pmod(i, 200)).collect::>(); - - // expected partition from Spark with n=200 - let expected = vec![69, 5, 193, 171, 115]; - assert_eq!(result, expected); - } -} diff --git a/native/core/src/execution/shuffle/metrics.rs b/native/core/src/execution/shuffle/metrics.rs deleted file mode 100644 index 33b51c3cd8..0000000000 --- a/native/core/src/execution/shuffle/metrics.rs +++ /dev/null @@ -1,61 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use datafusion::physical_plan::metrics::{ - BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, Time, -}; - -pub(super) struct ShufflePartitionerMetrics { - /// metrics - pub(super) baseline: BaselineMetrics, - - /// Time to perform repartitioning - pub(super) repart_time: Time, - - /// Time encoding batches to IPC format - pub(super) encode_time: Time, - - /// Time spent writing to disk. Maps to "shuffleWriteTime" in Spark SQL Metrics. - pub(super) write_time: Time, - - /// Number of input batches - pub(super) input_batches: Count, - - /// count of spills during the execution of the operator - pub(super) spill_count: Count, - - /// total spilled bytes during the execution of the operator - pub(super) spilled_bytes: Count, - - /// The original size of spilled data. Different to `spilled_bytes` because of compression. - pub(super) data_size: Count, -} - -impl ShufflePartitionerMetrics { - pub(super) fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { - Self { - baseline: BaselineMetrics::new(metrics, partition), - repart_time: MetricBuilder::new(metrics).subset_time("repart_time", partition), - encode_time: MetricBuilder::new(metrics).subset_time("encode_time", partition), - write_time: MetricBuilder::new(metrics).subset_time("write_time", partition), - input_batches: MetricBuilder::new(metrics).counter("input_batches", partition), - spill_count: MetricBuilder::new(metrics).spill_count(partition), - spilled_bytes: MetricBuilder::new(metrics).spilled_bytes(partition), - data_size: MetricBuilder::new(metrics).counter("data_size", partition), - } - } -} diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs deleted file mode 100644 index 6018cff50f..0000000000 --- a/native/core/src/execution/shuffle/mod.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub(crate) mod codec; -mod comet_partitioning; -mod metrics; -mod partitioners; -mod shuffle_writer; -pub mod spark_unsafe; -mod writers; - -pub use codec::{read_ipc_compressed, CompressionCodec, ShuffleBlockWriter}; -pub use comet_partitioning::CometPartitioning; -pub use shuffle_writer::ShuffleWriterExec; diff --git a/native/core/src/execution/shuffle/partitioners/mod.rs b/native/core/src/execution/shuffle/partitioners/mod.rs deleted file mode 100644 index b9058f66f4..0000000000 --- a/native/core/src/execution/shuffle/partitioners/mod.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -mod multi_partition; -mod partitioned_batch_iterator; -mod single_partition; - -use arrow::record_batch::RecordBatch; -use datafusion::common::Result; - -pub(super) use multi_partition::MultiPartitionShuffleRepartitioner; -pub(super) use partitioned_batch_iterator::PartitionedBatchIterator; -pub(super) use single_partition::SinglePartitionShufflePartitioner; - -#[async_trait::async_trait] -pub(super) trait ShufflePartitioner: Send + Sync { - /// Insert a batch into the partitioner - async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()>; - /// Write shuffle data and shuffle index file to disk - fn shuffle_write(&mut self) -> Result<()>; -} diff --git a/native/core/src/execution/shuffle/partitioners/multi_partition.rs b/native/core/src/execution/shuffle/partitioners/multi_partition.rs deleted file mode 100644 index 9c366ad462..0000000000 --- a/native/core/src/execution/shuffle/partitioners/multi_partition.rs +++ /dev/null @@ -1,642 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::execution::shuffle::metrics::ShufflePartitionerMetrics; -use crate::execution::shuffle::partitioners::partitioned_batch_iterator::{ - PartitionedBatchIterator, PartitionedBatchesProducer, -}; -use crate::execution::shuffle::partitioners::ShufflePartitioner; -use crate::execution::shuffle::writers::{BufBatchWriter, PartitionWriter}; -use crate::execution::shuffle::{ - comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter, -}; -use crate::execution::tracing::{with_trace, with_trace_async}; -use arrow::array::{ArrayRef, RecordBatch}; -use arrow::datatypes::SchemaRef; -use datafusion::common::utils::proxy::VecAllocExt; -use datafusion::common::DataFusionError; -use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; -use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::physical_plan::metrics::Time; -use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes; -use itertools::Itertools; -use std::fmt; -use std::fmt::{Debug, Formatter}; -use std::fs::{File, OpenOptions}; -use std::io::{BufReader, BufWriter, Seek, Write}; -use std::sync::Arc; -use tokio::time::Instant; - -#[derive(Default)] -struct ScratchSpace { - /// Hashes for each row in the current batch. - hashes_buf: Vec, - /// Partition ids for each row in the current batch. - partition_ids: Vec, - /// The row indices of the rows in each partition. This array is conceptually divided into - /// partitions, where each partition contains the row indices of the rows in that partition. - /// The length of this array is the same as the number of rows in the batch. - partition_row_indices: Vec, - /// The start indices of partitions in partition_row_indices. partition_starts[K] and - /// partition_starts[K + 1] are the start and end indices of partition K in partition_row_indices. - /// The length of this array is 1 + the number of partitions. 
- partition_starts: Vec, -} - -impl ScratchSpace { - fn map_partition_ids_to_starts_and_indices( - &mut self, - num_output_partitions: usize, - num_rows: usize, - ) { - let partition_ids = &mut self.partition_ids[..num_rows]; - - // count each partition size, while leaving the last extra element as 0 - let partition_counters = &mut self.partition_starts; - partition_counters.resize(num_output_partitions + 1, 0); - partition_counters.fill(0); - partition_ids - .iter() - .for_each(|partition_id| partition_counters[*partition_id as usize] += 1); - - // accumulate partition counters into partition ends - // e.g. partition counter: [1, 3, 2, 1, 0] => [1, 4, 6, 7, 7] - let partition_ends = partition_counters; - let mut accum = 0; - partition_ends.iter_mut().for_each(|v| { - *v += accum; - accum = *v; - }); - - // calculate partition row indices and partition starts - // e.g. partition ids: [3, 1, 1, 1, 2, 2, 0] will produce the following partition_row_indices - // and partition_starts arrays: - // - // partition_row_indices: [6, 1, 2, 3, 4, 5, 0] - // partition_starts: [0, 1, 4, 6, 7] - // - // partition_starts conceptually splits partition_row_indices into smaller slices. - // Each slice partition_row_indices[partition_starts[K]..partition_starts[K + 1]] contains the - // row indices of the input batch that are partitioned into partition K. For example, - // first partition 0 has one row index [6], partition 1 has row indices [1, 2, 3], etc. 
- let partition_row_indices = &mut self.partition_row_indices; - partition_row_indices.resize(num_rows, 0); - for (index, partition_id) in partition_ids.iter().enumerate().rev() { - partition_ends[*partition_id as usize] -= 1; - let end = partition_ends[*partition_id as usize]; - partition_row_indices[end as usize] = index as u32; - } - - // after calculating, partition ends become partition starts - } -} - -/// A partitioner that uses a hash function to partition data into multiple partitions -pub(crate) struct MultiPartitionShuffleRepartitioner { - output_data_file: String, - output_index_file: String, - buffered_batches: Vec, - partition_indices: Vec>, - partition_writers: Vec, - shuffle_block_writer: ShuffleBlockWriter, - /// Partitioning scheme to use - partitioning: CometPartitioning, - runtime: Arc, - metrics: ShufflePartitionerMetrics, - /// Reused scratch space for computing partition indices - scratch: ScratchSpace, - /// The configured batch size - batch_size: usize, - /// Reservation for repartitioning - reservation: MemoryReservation, - tracing_enabled: bool, - /// Size of the write buffer in bytes - write_buffer_size: usize, -} - -impl MultiPartitionShuffleRepartitioner { - #[allow(clippy::too_many_arguments)] - pub(crate) fn try_new( - partition: usize, - output_data_file: String, - output_index_file: String, - schema: SchemaRef, - partitioning: CometPartitioning, - metrics: ShufflePartitionerMetrics, - runtime: Arc, - batch_size: usize, - codec: CompressionCodec, - tracing_enabled: bool, - write_buffer_size: usize, - ) -> datafusion::common::Result { - let num_output_partitions = partitioning.partition_count(); - assert_ne!( - num_output_partitions, 1, - "Use SinglePartitionShufflePartitioner for 1 output partition." - ); - - // Vectors in the scratch space will be filled with valid values before being used, this - // initialization code is simply initializing the vectors to the desired size. - // The initial values are not used. 
- let scratch = ScratchSpace { - hashes_buf: match partitioning { - // Allocate hashes_buf for hash and round robin partitioning. - // Round robin hashes all columns to achieve even, deterministic distribution. - CometPartitioning::Hash(_, _) | CometPartitioning::RoundRobin(_, _) => { - vec![0; batch_size] - } - _ => vec![], - }, - partition_ids: vec![0; batch_size], - partition_row_indices: vec![0; batch_size], - partition_starts: vec![0; num_output_partitions + 1], - }; - - let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; - - let partition_writers = (0..num_output_partitions) - .map(|_| PartitionWriter::try_new(shuffle_block_writer.clone())) - .collect::>>()?; - - let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{partition}]")) - .with_can_spill(true) - .register(&runtime.memory_pool); - - Ok(Self { - output_data_file, - output_index_file, - buffered_batches: vec![], - partition_indices: vec![vec![]; num_output_partitions], - partition_writers, - shuffle_block_writer, - partitioning, - runtime, - metrics, - scratch, - batch_size, - reservation, - tracing_enabled, - write_buffer_size, - }) - } - - /// Shuffles rows in input batch into corresponding partition buffer. - /// This function first calculates hashes for rows and then takes rows in same - /// partition as a record batch which is appended into partition buffer. - /// This should not be called directly. Use `insert_batch` instead. - async fn partitioning_batch(&mut self, input: RecordBatch) -> datafusion::common::Result<()> { - if input.num_rows() == 0 { - // skip empty batch - return Ok(()); - } - - if input.num_rows() > self.batch_size { - return Err(DataFusionError::Internal( - "Input batch size exceeds configured batch size. Call `insert_batch` instead." 
- .to_string(), - )); - } - - // Update data size metric - self.metrics.data_size.add(input.get_array_memory_size()); - - // NOTE: in shuffle writer exec, the output_rows metrics represents the - // number of rows those are written to output data file. - self.metrics.baseline.record_output(input.num_rows()); - - match &self.partitioning { - CometPartitioning::Hash(exprs, num_output_partitions) => { - let mut scratch = std::mem::take(&mut self.scratch); - let (partition_starts, partition_row_indices): (&Vec, &Vec) = { - let mut timer = self.metrics.repart_time.timer(); - - // Evaluate partition expressions to get rows to apply partitioning scheme. - let arrays = exprs - .iter() - .map(|expr| expr.evaluate(&input)?.into_array(input.num_rows())) - .collect::>>()?; - - let num_rows = arrays[0].len(); - - // Use identical seed as Spark hash partitioning. - let hashes_buf = &mut scratch.hashes_buf[..num_rows]; - hashes_buf.fill(42_u32); - - // Generate partition ids for every row. - { - // Hash arrays and compute partition ids based on number of partitions. - let partition_ids = &mut scratch.partition_ids[..num_rows]; - create_murmur3_hashes(&arrays, hashes_buf)? - .iter() - .enumerate() - .for_each(|(idx, hash)| { - partition_ids[idx] = - comet_partitioning::pmod(*hash, *num_output_partitions) as u32; - }); - } - - // We now have partition ids for every input row, map that to partition starts - // and partition indices to eventually right these rows to partition buffers. 
- scratch - .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); - - timer.stop(); - Ok::<(&Vec, &Vec), DataFusionError>(( - &scratch.partition_starts, - &scratch.partition_row_indices, - )) - }?; - - self.buffer_partitioned_batch_may_spill( - input, - partition_row_indices, - partition_starts, - ) - .await?; - self.scratch = scratch; - } - CometPartitioning::RangePartitioning( - lex_ordering, - num_output_partitions, - row_converter, - bounds, - ) => { - let mut scratch = std::mem::take(&mut self.scratch); - let (partition_starts, partition_row_indices): (&Vec, &Vec) = { - let mut timer = self.metrics.repart_time.timer(); - - // Evaluate partition expressions for values to apply partitioning scheme on. - let arrays = lex_ordering - .iter() - .map(|expr| expr.expr.evaluate(&input)?.into_array(input.num_rows())) - .collect::>>()?; - - let num_rows = arrays[0].len(); - - // Generate partition ids for every row, first by converting the partition - // arrays to Rows, and then doing binary search for each Row against the - // bounds Rows. - { - let row_batch = row_converter.convert_columns(arrays.as_slice())?; - let partition_ids = &mut scratch.partition_ids[..num_rows]; - - row_batch.iter().enumerate().for_each(|(row_idx, row)| { - partition_ids[row_idx] = bounds - .as_slice() - .partition_point(|bound| bound.row() <= row) - as u32 - }); - } - - // We now have partition ids for every input row, map that to partition starts - // and partition indices to eventually right these rows to partition buffers. 
- scratch - .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); - - timer.stop(); - Ok::<(&Vec, &Vec), DataFusionError>(( - &scratch.partition_starts, - &scratch.partition_row_indices, - )) - }?; - - self.buffer_partitioned_batch_may_spill( - input, - partition_row_indices, - partition_starts, - ) - .await?; - self.scratch = scratch; - } - CometPartitioning::RoundRobin(num_output_partitions, max_hash_columns) => { - // Comet implements "round robin" as hash partitioning on columns. - // This achieves the same goal as Spark's round robin (even distribution - // without semantic grouping) while being deterministic for fault tolerance. - // - // Note: This produces different partition assignments than Spark's round robin, - // which sorts by UnsafeRow binary representation before assigning partitions. - // However, both approaches provide even distribution and determinism. - let mut scratch = std::mem::take(&mut self.scratch); - let (partition_starts, partition_row_indices): (&Vec, &Vec) = { - let mut timer = self.metrics.repart_time.timer(); - - let num_rows = input.num_rows(); - - // Collect columns for hashing, respecting max_hash_columns limit - // max_hash_columns of 0 means no limit (hash all columns) - // Negative values are normalized to 0 in the planner - let num_columns_to_hash = if *max_hash_columns == 0 { - input.num_columns() - } else { - (*max_hash_columns).min(input.num_columns()) - }; - let columns_to_hash: Vec = (0..num_columns_to_hash) - .map(|i| Arc::clone(input.column(i))) - .collect(); - - // Use identical seed as Spark hash partitioning. 
- let hashes_buf = &mut scratch.hashes_buf[..num_rows]; - hashes_buf.fill(42_u32); - - // Compute hash for selected columns - create_murmur3_hashes(&columns_to_hash, hashes_buf)?; - - // Assign partition IDs based on hash (same as hash partitioning) - let partition_ids = &mut scratch.partition_ids[..num_rows]; - hashes_buf.iter().enumerate().for_each(|(idx, hash)| { - partition_ids[idx] = - comet_partitioning::pmod(*hash, *num_output_partitions) as u32; - }); - - // We now have partition ids for every input row, map that to partition starts - // and partition indices to eventually write these rows to partition buffers. - scratch - .map_partition_ids_to_starts_and_indices(*num_output_partitions, num_rows); - - timer.stop(); - Ok::<(&Vec, &Vec), DataFusionError>(( - &scratch.partition_starts, - &scratch.partition_row_indices, - )) - }?; - - self.buffer_partitioned_batch_may_spill( - input, - partition_row_indices, - partition_starts, - ) - .await?; - self.scratch = scratch; - } - other => { - // this should be unreachable as long as the validation logic - // in the constructor is kept up-to-date - return Err(DataFusionError::NotImplemented(format!( - "Unsupported shuffle partitioning scheme {other:?}" - ))); - } - } - Ok(()) - } - - async fn buffer_partitioned_batch_may_spill( - &mut self, - input: RecordBatch, - partition_row_indices: &[u32], - partition_starts: &[u32], - ) -> datafusion::common::Result<()> { - let mut mem_growth: usize = input.get_array_memory_size(); - let buffered_partition_idx = self.buffered_batches.len() as u32; - self.buffered_batches.push(input); - - // partition_starts conceptually slices partition_row_indices into smaller slices, - // each slice contains the indices of rows in input that will go into the corresponding - // partition. The following loop iterates over the slices and put the row indices into - // the indices array of the corresponding partition. 
- for (partition_id, (&start, &end)) in partition_starts - .iter() - .tuple_windows() - .enumerate() - .filter(|(_, (start, end))| start < end) - { - let row_indices = &partition_row_indices[start as usize..end as usize]; - - // Put row indices for the current partition into the indices array of that partition. - // This indices array will be used for calling interleave_record_batch to produce - // shuffled batches. - let indices = &mut self.partition_indices[partition_id]; - let before_size = indices.allocated_size(); - indices.reserve(row_indices.len()); - for row_idx in row_indices { - indices.push((buffered_partition_idx, *row_idx)); - } - let after_size = indices.allocated_size(); - mem_growth += after_size.saturating_sub(before_size); - } - - if self.reservation.try_grow(mem_growth).is_err() { - self.spill()?; - } - - Ok(()) - } - - fn shuffle_write_partition( - partition_iter: &mut PartitionedBatchIterator, - shuffle_block_writer: &mut ShuffleBlockWriter, - output_data: &mut BufWriter, - encode_time: &Time, - write_time: &Time, - write_buffer_size: usize, - batch_size: usize, - ) -> datafusion::common::Result<()> { - let mut buf_batch_writer = BufBatchWriter::new( - shuffle_block_writer, - output_data, - write_buffer_size, - batch_size, - ); - for batch in partition_iter { - let batch = batch?; - buf_batch_writer.write(&batch, encode_time, write_time)?; - } - buf_batch_writer.flush(encode_time, write_time)?; - Ok(()) - } - - fn used(&self) -> usize { - self.reservation.size() - } - - fn spilled_bytes(&self) -> usize { - self.metrics.spilled_bytes.value() - } - - fn spill_count(&self) -> usize { - self.metrics.spill_count.value() - } - - fn data_size(&self) -> usize { - self.metrics.data_size.value() - } - - /// This function transfers the ownership of the buffered batches and partition indices from the - /// ShuffleRepartitioner to a new PartitionedBatches struct. The returned PartitionedBatches struct - /// can be used to produce shuffled batches. 
- fn partitioned_batches(&mut self) -> PartitionedBatchesProducer { - let num_output_partitions = self.partition_indices.len(); - let buffered_batches = std::mem::take(&mut self.buffered_batches); - // let indices = std::mem::take(&mut self.partition_indices); - let indices = std::mem::replace( - &mut self.partition_indices, - vec![vec![]; num_output_partitions], - ); - PartitionedBatchesProducer::new(buffered_batches, indices, self.batch_size) - } - - pub(crate) fn spill(&mut self) -> datafusion::common::Result<()> { - log::info!( - "ShuffleRepartitioner spilling shuffle data of {} to disk while inserting ({} time(s) so far)", - self.used(), - self.spill_count() - ); - - // we could always get a chance to free some memory as long as we are holding some - if self.buffered_batches.is_empty() { - return Ok(()); - } - - with_trace("shuffle_spill", self.tracing_enabled, || { - let num_output_partitions = self.partition_writers.len(); - let mut partitioned_batches = self.partitioned_batches(); - let mut spilled_bytes = 0; - - for partition_id in 0..num_output_partitions { - let partition_writer = &mut self.partition_writers[partition_id]; - let mut iter = partitioned_batches.produce(partition_id); - spilled_bytes += partition_writer.spill( - &mut iter, - &self.runtime, - &self.metrics, - self.write_buffer_size, - self.batch_size, - )?; - } - - self.reservation.free(); - self.metrics.spill_count.add(1); - self.metrics.spilled_bytes.add(spilled_bytes); - Ok(()) - }) - } - - #[cfg(test)] - pub(crate) fn partition_writers(&self) -> &[PartitionWriter] { - &self.partition_writers - } -} - -#[async_trait::async_trait] -impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { - /// Shuffles rows in input batch into corresponding partition buffer. - /// This function will slice input batch according to configured batch size and then - /// shuffle rows into corresponding partition buffer. 
- async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> { - with_trace_async("shuffle_insert_batch", self.tracing_enabled, || async { - let start_time = Instant::now(); - let mut start = 0; - while start < batch.num_rows() { - let end = (start + self.batch_size).min(batch.num_rows()); - let batch = batch.slice(start, end - start); - self.partitioning_batch(batch).await?; - start = end; - } - self.metrics.input_batches.add(1); - self.metrics - .baseline - .elapsed_compute() - .add_duration(start_time.elapsed()); - Ok(()) - }) - .await - } - - /// Writes buffered shuffled record batches into Arrow IPC bytes. - fn shuffle_write(&mut self) -> datafusion::common::Result<()> { - with_trace("shuffle_write", self.tracing_enabled, || { - let start_time = Instant::now(); - - let mut partitioned_batches = self.partitioned_batches(); - let num_output_partitions = self.partition_indices.len(); - let mut offsets = vec![0; num_output_partitions + 1]; - - let data_file = self.output_data_file.clone(); - let index_file = self.output_index_file.clone(); - - let output_data = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(data_file) - .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; - - let mut output_data = BufWriter::new(output_data); - - #[allow(clippy::needless_range_loop)] - for i in 0..num_output_partitions { - offsets[i] = output_data.stream_position()?; - - // if we wrote a spill file for this partition then copy the - // contents into the shuffle file - if let Some(spill_path) = self.partition_writers[i].path() { - let mut spill_file = BufReader::new(File::open(spill_path)?); - let mut write_timer = self.metrics.write_time.timer(); - std::io::copy(&mut spill_file, &mut output_data)?; - write_timer.stop(); - } - - // Write in memory batches to output data file - let mut partition_iter = partitioned_batches.produce(i); - Self::shuffle_write_partition( - &mut partition_iter, - &mut 
self.shuffle_block_writer, - &mut output_data, - &self.metrics.encode_time, - &self.metrics.write_time, - self.write_buffer_size, - self.batch_size, - )?; - } - - let mut write_timer = self.metrics.write_time.timer(); - output_data.flush()?; - write_timer.stop(); - - // add one extra offset at last to ease partition length computation - offsets[num_output_partitions] = output_data.stream_position()?; - - let mut write_timer = self.metrics.write_time.timer(); - let mut output_index = - BufWriter::new(File::create(index_file).map_err(|e| { - DataFusionError::Execution(format!("shuffle write error: {e:?}")) - })?); - for offset in offsets { - output_index.write_all(&(offset as i64).to_le_bytes()[..])?; - } - output_index.flush()?; - write_timer.stop(); - - self.metrics - .baseline - .elapsed_compute() - .add_duration(start_time.elapsed()); - - Ok(()) - }) - } -} - -impl Debug for MultiPartitionShuffleRepartitioner { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - f.debug_struct("ShuffleRepartitioner") - .field("memory_used", &self.used()) - .field("spilled_bytes", &self.spilled_bytes()) - .field("spilled_count", &self.spill_count()) - .field("data_size", &self.data_size()) - .finish() - } -} diff --git a/native/core/src/execution/shuffle/partitioners/partitioned_batch_iterator.rs b/native/core/src/execution/shuffle/partitioners/partitioned_batch_iterator.rs deleted file mode 100644 index 77010938cd..0000000000 --- a/native/core/src/execution/shuffle/partitioners/partitioned_batch_iterator.rs +++ /dev/null @@ -1,110 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::RecordBatch; -use arrow::compute::interleave_record_batch; -use datafusion::common::DataFusionError; - -/// A helper struct to produce shuffled batches. -/// This struct takes ownership of the buffered batches and partition indices from the -/// ShuffleRepartitioner, and provides an iterator over the batches in the specified partitions. -pub(super) struct PartitionedBatchesProducer { - buffered_batches: Vec, - partition_indices: Vec>, - batch_size: usize, -} - -impl PartitionedBatchesProducer { - pub(super) fn new( - buffered_batches: Vec, - indices: Vec>, - batch_size: usize, - ) -> Self { - Self { - partition_indices: indices, - buffered_batches, - batch_size, - } - } - - pub(super) fn produce(&mut self, partition_id: usize) -> PartitionedBatchIterator<'_> { - PartitionedBatchIterator::new( - &self.partition_indices[partition_id], - &self.buffered_batches, - self.batch_size, - ) - } -} - -pub(crate) struct PartitionedBatchIterator<'a> { - record_batches: Vec<&'a RecordBatch>, - batch_size: usize, - indices: Vec<(usize, usize)>, - pos: usize, -} - -impl<'a> PartitionedBatchIterator<'a> { - fn new( - indices: &'a [(u32, u32)], - buffered_batches: &'a [RecordBatch], - batch_size: usize, - ) -> Self { - if indices.is_empty() { - // Avoid unnecessary allocations when the partition is empty - return Self { - record_batches: vec![], - batch_size, - indices: vec![], - pos: 0, - }; - } - let record_batches = buffered_batches.iter().collect::>(); - let current_indices = indices - .iter() - .map(|(i_batch, i_row)| (*i_batch as usize, 
*i_row as usize)) - .collect::>(); - Self { - record_batches, - batch_size, - indices: current_indices, - pos: 0, - } - } -} - -impl Iterator for PartitionedBatchIterator<'_> { - type Item = datafusion::common::Result; - - fn next(&mut self) -> Option { - if self.pos >= self.indices.len() { - return None; - } - - let indices_end = std::cmp::min(self.pos + self.batch_size, self.indices.len()); - let indices = &self.indices[self.pos..indices_end]; - match interleave_record_batch(&self.record_batches, indices) { - Ok(batch) => { - self.pos = indices_end; - Some(Ok(batch)) - } - Err(e) => Some(Err(DataFusionError::ArrowError( - Box::from(e), - Some(DataFusionError::get_back_trace()), - ))), - } - } -} diff --git a/native/core/src/execution/shuffle/partitioners/single_partition.rs b/native/core/src/execution/shuffle/partitioners/single_partition.rs deleted file mode 100644 index eeca4458cc..0000000000 --- a/native/core/src/execution/shuffle/partitioners/single_partition.rs +++ /dev/null @@ -1,192 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::execution::shuffle::metrics::ShufflePartitionerMetrics; -use crate::execution::shuffle::partitioners::ShufflePartitioner; -use crate::execution::shuffle::writers::BufBatchWriter; -use crate::execution::shuffle::{CompressionCodec, ShuffleBlockWriter}; -use arrow::array::RecordBatch; -use arrow::datatypes::SchemaRef; -use datafusion::common::DataFusionError; -use std::fs::{File, OpenOptions}; -use std::io::{BufWriter, Write}; -use tokio::time::Instant; - -/// A partitioner that writes all shuffle data to a single file and a single index file -pub(crate) struct SinglePartitionShufflePartitioner { - // output_data_file: File, - output_data_writer: BufBatchWriter, - output_index_path: String, - /// Batches that are smaller than the batch size and to be concatenated - buffered_batches: Vec, - /// Number of rows in the concatenating batches - num_buffered_rows: usize, - /// Metrics for the repartitioner - metrics: ShufflePartitionerMetrics, - /// The configured batch size - batch_size: usize, -} - -impl SinglePartitionShufflePartitioner { - pub(crate) fn try_new( - output_data_path: String, - output_index_path: String, - schema: SchemaRef, - metrics: ShufflePartitionerMetrics, - batch_size: usize, - codec: CompressionCodec, - write_buffer_size: usize, - ) -> datafusion::common::Result { - let shuffle_block_writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone())?; - - let output_data_file = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(output_data_path)?; - - let output_data_writer = BufBatchWriter::new( - shuffle_block_writer, - output_data_file, - write_buffer_size, - batch_size, - ); - - Ok(Self { - output_data_writer, - output_index_path, - buffered_batches: vec![], - num_buffered_rows: 0, - metrics, - batch_size, - }) - } - - /// Add a batch to the buffer of the partitioner, these buffered batches will be concatenated - /// and written to the output data file when the number of rows in the buffer reaches the 
batch size. - fn add_buffered_batch(&mut self, batch: RecordBatch) { - self.num_buffered_rows += batch.num_rows(); - self.buffered_batches.push(batch); - } - - /// Consumes buffered batches and return a concatenated batch if successful - fn concat_buffered_batches(&mut self) -> datafusion::common::Result> { - if self.buffered_batches.is_empty() { - Ok(None) - } else if self.buffered_batches.len() == 1 { - let batch = self.buffered_batches.remove(0); - self.num_buffered_rows = 0; - Ok(Some(batch)) - } else { - let schema = &self.buffered_batches[0].schema(); - match arrow::compute::concat_batches(schema, self.buffered_batches.iter()) { - Ok(concatenated) => { - self.buffered_batches.clear(); - self.num_buffered_rows = 0; - Ok(Some(concatenated)) - } - Err(e) => Err(DataFusionError::ArrowError( - Box::from(e), - Some(DataFusionError::get_back_trace()), - )), - } - } - } -} - -#[async_trait::async_trait] -impl ShufflePartitioner for SinglePartitionShufflePartitioner { - async fn insert_batch(&mut self, batch: RecordBatch) -> datafusion::common::Result<()> { - let start_time = Instant::now(); - let num_rows = batch.num_rows(); - - if num_rows > 0 { - self.metrics.data_size.add(batch.get_array_memory_size()); - self.metrics.baseline.record_output(num_rows); - - if num_rows >= self.batch_size || num_rows + self.num_buffered_rows > self.batch_size { - let concatenated_batch = self.concat_buffered_batches()?; - - // Write the concatenated buffered batch - if let Some(batch) = concatenated_batch { - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } - - if num_rows >= self.batch_size { - // Write the new batch - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } else { - // Add the new batch to the buffer - self.add_buffered_batch(batch); - } - } else { - self.add_buffered_batch(batch); - } - } - - self.metrics.input_batches.add(1); - self.metrics - .baseline - 
.elapsed_compute() - .add_duration(start_time.elapsed()); - Ok(()) - } - - fn shuffle_write(&mut self) -> datafusion::common::Result<()> { - let start_time = Instant::now(); - let concatenated_batch = self.concat_buffered_batches()?; - - // Write the concatenated buffered batch - if let Some(batch) = concatenated_batch { - self.output_data_writer.write( - &batch, - &self.metrics.encode_time, - &self.metrics.write_time, - )?; - } - self.output_data_writer - .flush(&self.metrics.encode_time, &self.metrics.write_time)?; - - // Write index file. It should only contain 2 entries: 0 and the total number of bytes written - let index_file = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(self.output_index_path.clone()) - .map_err(|e| DataFusionError::Execution(format!("shuffle write error: {e:?}")))?; - let mut index_buf_writer = BufWriter::new(index_file); - let data_file_length = self.output_data_writer.writer_stream_position()?; - for offset in [0, data_file_length] { - index_buf_writer.write_all(&(offset as i64).to_le_bytes()[..])?; - } - index_buf_writer.flush()?; - - self.metrics - .baseline - .elapsed_compute() - .add_duration(start_time.elapsed()); - Ok(()) - } -} diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs deleted file mode 100644 index fe1bf0fccf..0000000000 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ /dev/null @@ -1,696 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines the External shuffle repartition plan. - -use crate::execution::shuffle::metrics::ShufflePartitionerMetrics; -use crate::execution::shuffle::partitioners::{ - MultiPartitionShuffleRepartitioner, ShufflePartitioner, SinglePartitionShufflePartitioner, -}; -use crate::execution::shuffle::{CometPartitioning, CompressionCodec}; -use crate::execution::tracing::with_trace_async; -use async_trait::async_trait; -use datafusion::common::exec_datafusion_err; -use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; -use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; -use datafusion::physical_plan::EmptyRecordBatchStream; -use datafusion::{ - arrow::{datatypes::SchemaRef, error::ArrowError}, - error::Result, - execution::context::TaskContext, - physical_plan::{ - metrics::{ExecutionPlanMetricsSet, MetricsSet}, - stream::RecordBatchStreamAdapter, - DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, - Statistics, - }, -}; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; -use std::{ - any::Any, - fmt, - fmt::{Debug, Formatter}, - sync::Arc, -}; - -/// The shuffle writer operator maps each input partition to M output partitions based on a -/// partitioning scheme. No guarantees are made about the order of the resulting partitions. 
-#[derive(Debug)] -pub struct ShuffleWriterExec { - /// Input execution plan - input: Arc, - /// Partitioning scheme to use - partitioning: CometPartitioning, - /// Output data file path - output_data_file: String, - /// Output index file path - output_index_file: String, - /// Metrics - metrics: ExecutionPlanMetricsSet, - /// Cache for expensive-to-compute plan properties - cache: PlanProperties, - /// The compression codec to use when compressing shuffle blocks - codec: CompressionCodec, - tracing_enabled: bool, - /// Size of the write buffer in bytes - write_buffer_size: usize, -} - -impl ShuffleWriterExec { - /// Create a new ShuffleWriterExec - #[allow(clippy::too_many_arguments)] - pub fn try_new( - input: Arc, - partitioning: CometPartitioning, - codec: CompressionCodec, - output_data_file: String, - output_index_file: String, - tracing_enabled: bool, - write_buffer_size: usize, - ) -> Result { - let cache = PlanProperties::new( - EquivalenceProperties::new(Arc::clone(&input.schema())), - Partitioning::UnknownPartitioning(1), - EmissionType::Final, - Boundedness::Bounded, - ); - - Ok(ShuffleWriterExec { - input, - partitioning, - metrics: ExecutionPlanMetricsSet::new(), - output_data_file, - output_index_file, - cache, - codec, - tracing_enabled, - write_buffer_size, - }) - } -} - -impl DisplayAs for ShuffleWriterExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "ShuffleWriterExec: partitioning={:?}, compression={:?}", - self.partitioning, self.codec - ) - } - DisplayFormatType::TreeRender => unimplemented!(), - } - } -} - -#[async_trait] -impl ExecutionPlan for ShuffleWriterExec { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "ShuffleWriterExec" - } - - fn metrics(&self) -> Option { - Some(self.metrics.clone_inner()) - } - - fn 
statistics(&self) -> Result { - self.input.partition_statistics(None) - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - self.input.schema() - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - match children.len() { - 1 => Ok(Arc::new(ShuffleWriterExec::try_new( - Arc::clone(&children[0]), - self.partitioning.clone(), - self.codec.clone(), - self.output_data_file.clone(), - self.output_index_file.clone(), - self.tracing_enabled, - self.write_buffer_size, - )?)), - _ => panic!("ShuffleWriterExec wrong number of children"), - } - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - let input = self.input.execute(partition, Arc::clone(&context))?; - let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - futures::stream::once( - external_shuffle( - input, - partition, - self.output_data_file.clone(), - self.output_index_file.clone(), - self.partitioning.clone(), - metrics, - context, - self.codec.clone(), - self.tracing_enabled, - self.write_buffer_size, - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))), - ) - .try_flatten(), - ))) - } -} - -#[allow(clippy::too_many_arguments)] -async fn external_shuffle( - mut input: SendableRecordBatchStream, - partition: usize, - output_data_file: String, - output_index_file: String, - partitioning: CometPartitioning, - metrics: ShufflePartitionerMetrics, - context: Arc, - codec: CompressionCodec, - tracing_enabled: bool, - write_buffer_size: usize, -) -> Result { - with_trace_async("external_shuffle", tracing_enabled, || async { - let schema = input.schema(); - - let mut repartitioner: Box = match &partitioning { - any if any.partition_count() == 1 => { - Box::new(SinglePartitionShufflePartitioner::try_new( - output_data_file, - 
output_index_file, - Arc::clone(&schema), - metrics, - context.session_config().batch_size(), - codec, - write_buffer_size, - )?) - } - _ => Box::new(MultiPartitionShuffleRepartitioner::try_new( - partition, - output_data_file, - output_index_file, - Arc::clone(&schema), - partitioning, - metrics, - context.runtime_env(), - context.session_config().batch_size(), - codec, - tracing_enabled, - write_buffer_size, - )?), - }; - - while let Some(batch) = input.next().await { - // Await the repartitioner to insert the batch and shuffle the rows - // into the corresponding partition buffer. - // Otherwise, pull the next batch from the input stream might overwrite the - // current batch in the repartitioner. - repartitioner - .insert_batch(batch?) - .await - .map_err(|err| exec_datafusion_err!("Error inserting batch: {err}"))?; - } - - repartitioner - .shuffle_write() - .map_err(|err| exec_datafusion_err!("Error in shuffle write: {err}"))?; - - // shuffle writer always has empty output - Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone(&schema))) as SendableRecordBatchStream) - }) - .await -} - -#[cfg(test)] -mod test { - use super::*; - use crate::execution::shuffle::{read_ipc_compressed, ShuffleBlockWriter}; - use arrow::array::{Array, StringArray, StringBuilder}; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use arrow::row::{RowConverter, SortField}; - use datafusion::datasource::memory::MemorySourceConfig; - use datafusion::datasource::source::DataSourceExec; - use datafusion::execution::config::SessionConfig; - use datafusion::execution::runtime_env::{RuntimeEnv, RuntimeEnvBuilder}; - use datafusion::physical_expr::expressions::{col, Column}; - use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; - use datafusion::physical_plan::common::collect; - use datafusion::physical_plan::metrics::Time; - use datafusion::prelude::SessionContext; - use itertools::Itertools; - use std::io::Cursor; - use 
tokio::runtime::Runtime; - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn roundtrip_ipc() { - let batch = create_batch(8192); - for codec in &[ - CompressionCodec::None, - CompressionCodec::Zstd(1), - CompressionCodec::Snappy, - CompressionCodec::Lz4Frame, - ] { - let mut output = vec![]; - let mut cursor = Cursor::new(&mut output); - let writer = - ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone()).unwrap(); - let length = writer - .write_batch(&batch, &mut cursor, &Time::default()) - .unwrap(); - assert_eq!(length, output.len()); - - let ipc_without_length_prefix = &output[16..]; - let batch2 = read_ipc_compressed(ipc_without_length_prefix).unwrap(); - assert_eq!(batch, batch2); - } - } - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn test_single_partition_shuffle_writer() { - shuffle_write_test(1000, 100, 1, None); - shuffle_write_test(10000, 10, 1, None); - } - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn test_insert_larger_batch() { - shuffle_write_test(10000, 1, 16, None); - } - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn test_insert_smaller_batch() { - shuffle_write_test(1000, 1, 16, None); - shuffle_write_test(1000, 10, 16, None); - } - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn test_large_number_of_partitions() { - shuffle_write_test(10000, 10, 200, Some(10 * 1024 * 1024)); - shuffle_write_test(10000, 10, 2000, Some(10 * 1024 * 1024)); - } - - #[test] - #[cfg_attr(miri, ignore)] // miri can't call foreign function `ZSTD_createCCtx` - fn test_large_number_of_partitions_spilling() { - shuffle_write_test(10000, 100, 200, Some(10 * 1024 * 1024)); - } - - #[tokio::test] - async fn shuffle_partitioner_memory() { - let batch = create_batch(900); - assert_eq!(8316, 
batch.get_array_memory_size()); // Not stable across Arrow versions - - let memory_limit = 512 * 1024; - let num_partitions = 2; - let runtime_env = create_runtime(memory_limit); - let metrics_set = ExecutionPlanMetricsSet::new(); - let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new( - 0, - "/tmp/data.out".to_string(), - "/tmp/index.out".to_string(), - batch.schema(), - CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), - ShufflePartitionerMetrics::new(&metrics_set, 0), - runtime_env, - 1024, - CompressionCodec::Lz4Frame, - false, - 1024 * 1024, // write_buffer_size: 1MB default - ) - .unwrap(); - - repartitioner.insert_batch(batch.clone()).await.unwrap(); - - { - let partition_writers = repartitioner.partition_writers(); - assert_eq!(partition_writers.len(), 2); - - assert!(!partition_writers[0].has_spill_file()); - assert!(!partition_writers[1].has_spill_file()); - } - - repartitioner.spill().unwrap(); - - // after spill, there should be spill files - { - let partition_writers = repartitioner.partition_writers(); - assert!(partition_writers[0].has_spill_file()); - assert!(partition_writers[1].has_spill_file()); - } - - // insert another batch after spilling - repartitioner.insert_batch(batch.clone()).await.unwrap(); - } - - fn create_runtime(memory_limit: usize) -> Arc { - Arc::new( - RuntimeEnvBuilder::new() - .with_memory_limit(memory_limit, 1.0) - .build() - .unwrap(), - ) - } - - fn shuffle_write_test( - batch_size: usize, - num_batches: usize, - num_partitions: usize, - memory_limit: Option, - ) { - let batch = create_batch(batch_size); - - let lex_ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( - col("a", batch.schema().as_ref()).unwrap(), - )]) - .unwrap(); - - let sort_fields: Vec = batch - .columns() - .iter() - .zip(&lex_ordering) - .map(|(array, sort_expr)| { - SortField::new_with_options(array.data_type().clone(), sort_expr.options) - }) - .collect(); - let row_converter = 
RowConverter::new(sort_fields).unwrap(); - - let owned_rows = if num_partitions == 1 { - vec![] - } else { - // Determine range boundaries based on create_batch implementation. We just divide the - // domain of values in the batch equally to find partition bounds. - let bounds_strings = { - let mut boundaries = Vec::with_capacity(num_partitions - 1); - let step = batch_size as f64 / num_partitions as f64; - - for i in 1..(num_partitions) { - boundaries.push(Some((step * i as f64).round().to_string())); - } - boundaries - }; - let bounds_array: Arc = Arc::new(StringArray::from(bounds_strings)); - let bounds_rows = row_converter - .convert_columns(vec![bounds_array].as_slice()) - .unwrap(); - - let owned_rows_vec = bounds_rows.iter().map(|row| row.owned()).collect_vec(); - owned_rows_vec - }; - - for partitioning in [ - CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), - CometPartitioning::RangePartitioning( - lex_ordering, - num_partitions, - Arc::new(row_converter), - owned_rows, - ), - CometPartitioning::RoundRobin(num_partitions, 0), - ] { - let batches = (0..num_batches).map(|_| batch.clone()).collect::>(); - - let partitions = &[batches]; - let exec = ShuffleWriterExec::try_new( - Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(), - ))), - partitioning, - CompressionCodec::Zstd(1), - "/tmp/data.out".to_string(), - "/tmp/index.out".to_string(), - false, - 1024 * 1024, // write_buffer_size: 1MB default - ) - .unwrap(); - - // 10MB memory should be enough for running this test - let config = SessionConfig::new(); - let mut runtime_env_builder = RuntimeEnvBuilder::new(); - runtime_env_builder = match memory_limit { - Some(limit) => runtime_env_builder.with_memory_limit(limit, 1.0), - None => runtime_env_builder, - }; - let runtime_env = Arc::new(runtime_env_builder.build().unwrap()); - let ctx = SessionContext::new_with_config_rt(config, runtime_env); - let task_ctx = 
ctx.task_ctx(); - let stream = exec.execute(0, task_ctx).unwrap(); - let rt = Runtime::new().unwrap(); - rt.block_on(collect(stream)).unwrap(); - } - } - - fn create_batch(batch_size: usize) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); - let mut b = StringBuilder::new(); - for i in 0..batch_size { - b.append_value(format!("{i}")); - } - let array = b.finish(); - RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(array)]).unwrap() - } - - #[test] - #[cfg_attr(miri, ignore)] - fn test_round_robin_deterministic() { - // Test that round robin partitioning produces identical results when run multiple times - use std::fs; - use std::io::Read; - - let batch_size = 1000; - let num_batches = 10; - let num_partitions = 8; - - let batch = create_batch(batch_size); - let batches = (0..num_batches).map(|_| batch.clone()).collect::>(); - - // Run shuffle twice and compare results - for run in 0..2 { - let data_file = format!("/tmp/rr_data_{}.out", run); - let index_file = format!("/tmp/rr_index_{}.out", run); - - let partitions = std::slice::from_ref(&batches); - let exec = ShuffleWriterExec::try_new( - Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(partitions, batch.schema(), None).unwrap(), - ))), - CometPartitioning::RoundRobin(num_partitions, 0), - CompressionCodec::Zstd(1), - data_file.clone(), - index_file.clone(), - false, - 1024 * 1024, - ) - .unwrap(); - - let config = SessionConfig::new(); - let runtime_env = Arc::new( - RuntimeEnvBuilder::new() - .with_memory_limit(10 * 1024 * 1024, 1.0) - .build() - .unwrap(), - ); - let session_ctx = Arc::new(SessionContext::new_with_config_rt(config, runtime_env)); - let task_ctx = Arc::new(TaskContext::from(session_ctx.as_ref())); - - // Execute the shuffle - futures::executor::block_on(async { - let mut stream = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); - while stream.next().await.is_some() {} - }); - - if run == 1 { - // Compare data files - 
let mut data0 = Vec::new(); - fs::File::open("/tmp/rr_data_0.out") - .unwrap() - .read_to_end(&mut data0) - .unwrap(); - let mut data1 = Vec::new(); - fs::File::open("/tmp/rr_data_1.out") - .unwrap() - .read_to_end(&mut data1) - .unwrap(); - assert_eq!( - data0, data1, - "Round robin shuffle data should be identical across runs" - ); - - // Compare index files - let mut index0 = Vec::new(); - fs::File::open("/tmp/rr_index_0.out") - .unwrap() - .read_to_end(&mut index0) - .unwrap(); - let mut index1 = Vec::new(); - fs::File::open("/tmp/rr_index_1.out") - .unwrap() - .read_to_end(&mut index1) - .unwrap(); - assert_eq!( - index0, index1, - "Round robin shuffle index should be identical across runs" - ); - } - } - - // Clean up - let _ = fs::remove_file("/tmp/rr_data_0.out"); - let _ = fs::remove_file("/tmp/rr_index_0.out"); - let _ = fs::remove_file("/tmp/rr_data_1.out"); - let _ = fs::remove_file("/tmp/rr_index_1.out"); - } - - /// Test that batch coalescing in BufBatchWriter reduces output size by - /// writing fewer, larger IPC blocks instead of many small ones. 
- #[test] - #[cfg_attr(miri, ignore)] - fn test_batch_coalescing_reduces_size() { - use crate::execution::shuffle::writers::BufBatchWriter; - use arrow::array::Int32Array; - - // Create a wide schema to amplify per-block schema overhead - let fields: Vec = (0..20) - .map(|i| Field::new(format!("col_{i}"), DataType::Int32, false)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create many small batches (50 rows each) - let small_batches: Vec = (0..100) - .map(|batch_idx| { - let columns: Vec> = (0..20) - .map(|col_idx| { - let values: Vec = (0..50) - .map(|row| batch_idx * 50 + row + col_idx * 1000) - .collect(); - Arc::new(Int32Array::from(values)) as Arc - }) - .collect(); - RecordBatch::try_new(Arc::clone(&schema), columns).unwrap() - }) - .collect(); - - let codec = CompressionCodec::Lz4Frame; - let encode_time = Time::default(); - let write_time = Time::default(); - - // Write with coalescing (batch_size=8192) - let mut coalesced_output = Vec::new(); - { - let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); - let mut buf_writer = BufBatchWriter::new( - &mut writer, - Cursor::new(&mut coalesced_output), - 1024 * 1024, - 8192, - ); - for batch in &small_batches { - buf_writer.write(batch, &encode_time, &write_time).unwrap(); - } - buf_writer.flush(&encode_time, &write_time).unwrap(); - } - - // Write without coalescing (batch_size=1) - let mut uncoalesced_output = Vec::new(); - { - let mut writer = ShuffleBlockWriter::try_new(schema.as_ref(), codec.clone()).unwrap(); - let mut buf_writer = BufBatchWriter::new( - &mut writer, - Cursor::new(&mut uncoalesced_output), - 1024 * 1024, - 1, - ); - for batch in &small_batches { - buf_writer.write(batch, &encode_time, &write_time).unwrap(); - } - buf_writer.flush(&encode_time, &write_time).unwrap(); - } - - // Coalesced output should be smaller due to fewer IPC schema blocks - assert!( - coalesced_output.len() < uncoalesced_output.len(), - "Coalesced output ({} bytes) 
should be smaller than uncoalesced ({} bytes)", - coalesced_output.len(), - uncoalesced_output.len() - ); - - // Verify both roundtrip correctly by reading all IPC blocks - let coalesced_rows = read_all_ipc_blocks(&coalesced_output); - let uncoalesced_rows = read_all_ipc_blocks(&uncoalesced_output); - assert_eq!( - coalesced_rows, 5000, - "Coalesced should contain all 5000 rows" - ); - assert_eq!( - uncoalesced_rows, 5000, - "Uncoalesced should contain all 5000 rows" - ); - } - - /// Read all IPC blocks from a byte buffer written by BufBatchWriter/ShuffleBlockWriter, - /// returning the total number of rows. - fn read_all_ipc_blocks(data: &[u8]) -> usize { - let mut offset = 0; - let mut total_rows = 0; - while offset < data.len() { - // First 8 bytes are the IPC length (little-endian u64) - let ipc_length = - u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap()) as usize; - // Skip the 8-byte length prefix; the next 8 bytes are field_count + codec header - let block_start = offset + 8; - let block_end = block_start + ipc_length; - // read_ipc_compressed expects data starting after the 16-byte header - // (i.e., after length + field_count), at the codec tag - let ipc_data = &data[block_start + 8..block_end]; - let batch = read_ipc_compressed(ipc_data).unwrap(); - total_rows += batch.num_rows(); - offset = block_end; - } - total_rows - } -} diff --git a/native/core/src/execution/shuffle/spark_unsafe/list.rs b/native/core/src/execution/shuffle/spark_unsafe/list.rs deleted file mode 100644 index d9c93b1c6e..0000000000 --- a/native/core/src/execution/shuffle/spark_unsafe/list.rs +++ /dev/null @@ -1,487 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::{ - errors::CometError, - execution::shuffle::spark_unsafe::{ - map::append_map_elements, - row::{ - append_field, downcast_builder_ref, impl_primitive_accessors, SparkUnsafeObject, - SparkUnsafeRow, - }, - }, -}; -use arrow::array::{ - builder::{ - ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, - Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, - ListBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, - }, - MapBuilder, -}; -use arrow::datatypes::{DataType, TimeUnit}; - -/// Generates bulk append methods for primitive types in SparkUnsafeArray. -/// -/// # Safety invariants for all generated methods: -/// - `element_offset` points to contiguous element data of length `num_elements` -/// - `null_bitset_ptr()` returns a pointer to `ceil(num_elements/64)` i64 words -/// - These invariants are guaranteed by the SparkUnsafeArray layout from the JVM -macro_rules! 
impl_append_to_builder { - ($method_name:ident, $builder_type:ty, $element_type:ty) => { - pub(crate) fn $method_name(&self, builder: &mut $builder_type) { - let num_elements = self.num_elements; - if num_elements == 0 { - return; - } - - if NULLABLE { - let mut ptr = self.element_offset as *const $element_type; - let null_words = self.null_bitset_ptr(); - debug_assert!(!null_words.is_null(), "null_bitset_ptr is null"); - debug_assert!(!ptr.is_null(), "element_offset pointer is null"); - for idx in 0..num_elements { - // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements - let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; - - if is_null { - builder.append_null(); - } else { - // SAFETY: ptr is within element data bounds - builder.append_value(unsafe { ptr.read_unaligned() }); - } - // SAFETY: ptr stays within bounds, iterating num_elements times - ptr = unsafe { ptr.add(1) }; - } - } else { - // SAFETY: element_offset points to contiguous data of length num_elements - debug_assert!(self.element_offset != 0, "element_offset is null"); - let ptr = self.element_offset as *const $element_type; - // Use bulk copy when data is properly aligned, fall back to - // per-element unaligned reads otherwise - if (ptr as usize).is_multiple_of(std::mem::align_of::<$element_type>()) { - let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; - builder.append_slice(slice); - } else { - let mut ptr = ptr; - for _ in 0..num_elements { - builder.append_value(unsafe { ptr.read_unaligned() }); - ptr = unsafe { ptr.add(1) }; - } - } - } - } - }; -} - -pub struct SparkUnsafeArray { - row_addr: i64, - num_elements: usize, - element_offset: i64, -} - -impl SparkUnsafeObject for SparkUnsafeArray { - #[inline] - fn get_row_addr(&self) -> i64 { - self.row_addr - } - - #[inline] - fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { - (self.element_offset + (index * element_size) as i64) as *const u8 - } - - // 
SparkUnsafeArray base address may be unaligned when nested within a row's variable-length - // region, so we must use ptr::read_unaligned() for all typed accesses. - impl_primitive_accessors!(read_unaligned); -} - -impl SparkUnsafeArray { - /// Creates a `SparkUnsafeArray` which points to the given address and size in bytes. - pub fn new(addr: i64) -> Self { - // SAFETY: addr points to valid Spark UnsafeArray data from the JVM. - // The first 8 bytes contain the element count as a little-endian i64. - debug_assert!(addr != 0, "SparkUnsafeArray::new: null address"); - let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; - let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); - - if num_elements < 0 { - panic!("Negative number of elements: {num_elements}"); - } - - if num_elements > i32::MAX as i64 { - panic!("Number of elements should <= i32::MAX: {num_elements}"); - } - - Self { - row_addr: addr, - num_elements: num_elements as usize, - element_offset: addr + Self::get_header_portion_in_bytes(num_elements), - } - } - - pub(crate) fn get_num_elements(&self) -> usize { - self.num_elements - } - - /// Returns the size of array header in bytes. - #[inline] - const fn get_header_portion_in_bytes(num_fields: i64) -> i64 { - 8 + ((num_fields + 63) / 64) * 8 - } - - /// Returns true if the null bit at the given index of the array is set. - #[inline] - pub(crate) fn is_null_at(&self, index: usize) -> bool { - // SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts - // at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures - // index < num_elements, so word_offset is within the bitset region. 
- debug_assert!( - index < self.num_elements, - "is_null_at: index {index} >= num_elements {}", - self.num_elements - ); - unsafe { - let mask: i64 = 1i64 << (index & 0x3f); - let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; - let word: i64 = word_offset.read_unaligned(); - (word & mask) != 0 - } - } - - /// Returns the null bitset pointer (starts at row_addr + 8). - #[inline] - fn null_bitset_ptr(&self) -> *const i64 { - (self.row_addr + 8) as *const i64 - } - - /// Checks whether the null bit at `idx` is set in the given null bitset pointer. - /// - /// # Safety - /// `null_words` must point to at least `ceil((idx+1)/64)` i64 words. - #[inline] - unsafe fn is_null_in_bitset(null_words: *const i64, idx: usize) -> bool { - let word_idx = idx >> 6; - let bit_idx = idx & 0x3f; - (null_words.add(word_idx).read_unaligned() & (1i64 << bit_idx)) != 0 - } - - impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32); - impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64); - impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16); - impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8); - impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32); - impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64); - - /// Bulk append boolean values to builder. - /// Booleans are stored as 1 byte each in SparkUnsafeArray, requiring special handling. 
- pub(crate) fn append_booleans_to_builder( - &self, - builder: &mut BooleanBuilder, - ) { - let num_elements = self.num_elements; - if num_elements == 0 { - return; - } - - let mut ptr = self.element_offset as *const u8; - debug_assert!( - !ptr.is_null(), - "append_booleans: element_offset pointer is null" - ); - - if NULLABLE { - let null_words = self.null_bitset_ptr(); - debug_assert!( - !null_words.is_null(), - "append_booleans: null_bitset_ptr is null" - ); - for idx in 0..num_elements { - // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements - let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; - - if is_null { - builder.append_null(); - } else { - // SAFETY: ptr is within element data bounds - builder.append_value(unsafe { *ptr != 0 }); - } - // SAFETY: ptr stays within bounds, iterating num_elements times - ptr = unsafe { ptr.add(1) }; - } - } else { - for _ in 0..num_elements { - // SAFETY: ptr is within element data bounds - builder.append_value(unsafe { *ptr != 0 }); - ptr = unsafe { ptr.add(1) }; - } - } - } - - /// Bulk append timestamp values to builder (stored as i64 microseconds). 
- pub(crate) fn append_timestamps_to_builder( - &self, - builder: &mut TimestampMicrosecondBuilder, - ) { - let num_elements = self.num_elements; - if num_elements == 0 { - return; - } - - if NULLABLE { - let mut ptr = self.element_offset as *const i64; - let null_words = self.null_bitset_ptr(); - debug_assert!( - !null_words.is_null(), - "append_timestamps: null_bitset_ptr is null" - ); - debug_assert!( - !ptr.is_null(), - "append_timestamps: element_offset pointer is null" - ); - for idx in 0..num_elements { - // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements - let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; - - if is_null { - builder.append_null(); - } else { - // SAFETY: ptr is within element data bounds - builder.append_value(unsafe { ptr.read_unaligned() }); - } - // SAFETY: ptr stays within bounds, iterating num_elements times - ptr = unsafe { ptr.add(1) }; - } - } else { - // SAFETY: element_offset points to contiguous i64 data of length num_elements - debug_assert!( - self.element_offset != 0, - "append_timestamps: element_offset is null" - ); - let ptr = self.element_offset as *const i64; - if (ptr as usize).is_multiple_of(std::mem::align_of::()) { - let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; - builder.append_slice(slice); - } else { - let mut ptr = ptr; - for _ in 0..num_elements { - builder.append_value(unsafe { ptr.read_unaligned() }); - ptr = unsafe { ptr.add(1) }; - } - } - } - } - - /// Bulk append date values to builder (stored as i32 days since epoch). 
- pub(crate) fn append_dates_to_builder( - &self, - builder: &mut Date32Builder, - ) { - let num_elements = self.num_elements; - if num_elements == 0 { - return; - } - - if NULLABLE { - let mut ptr = self.element_offset as *const i32; - let null_words = self.null_bitset_ptr(); - debug_assert!( - !null_words.is_null(), - "append_dates: null_bitset_ptr is null" - ); - debug_assert!( - !ptr.is_null(), - "append_dates: element_offset pointer is null" - ); - for idx in 0..num_elements { - // SAFETY: null_words has ceil(num_elements/64) words, idx < num_elements - let is_null = unsafe { Self::is_null_in_bitset(null_words, idx) }; - - if is_null { - builder.append_null(); - } else { - // SAFETY: ptr is within element data bounds - builder.append_value(unsafe { ptr.read_unaligned() }); - } - // SAFETY: ptr stays within bounds, iterating num_elements times - ptr = unsafe { ptr.add(1) }; - } - } else { - // SAFETY: element_offset points to contiguous i32 data of length num_elements - debug_assert!( - self.element_offset != 0, - "append_dates: element_offset is null" - ); - let ptr = self.element_offset as *const i32; - if (ptr as usize).is_multiple_of(std::mem::align_of::()) { - let slice = unsafe { std::slice::from_raw_parts(ptr, num_elements) }; - builder.append_slice(slice); - } else { - let mut ptr = ptr; - for _ in 0..num_elements { - builder.append_value(unsafe { ptr.read_unaligned() }); - ptr = unsafe { ptr.add(1) }; - } - } - } - } -} - -pub fn append_to_builder( - data_type: &DataType, - builder: &mut dyn ArrayBuilder, - array: &SparkUnsafeArray, -) -> Result<(), CometError> { - macro_rules! 
add_values { - ($builder_type:ty, $add_value:expr, $add_null:expr) => { - let builder = downcast_builder_ref!($builder_type, builder); - for idx in 0..array.get_num_elements() { - if NULLABLE && array.is_null_at(idx) { - $add_null(builder); - } else { - $add_value(builder, array, idx); - } - } - }; - } - - match data_type { - DataType::Boolean => { - let builder = downcast_builder_ref!(BooleanBuilder, builder); - array.append_booleans_to_builder::(builder); - } - DataType::Int8 => { - let builder = downcast_builder_ref!(Int8Builder, builder); - array.append_bytes_to_builder::(builder); - } - DataType::Int16 => { - let builder = downcast_builder_ref!(Int16Builder, builder); - array.append_shorts_to_builder::(builder); - } - DataType::Int32 => { - let builder = downcast_builder_ref!(Int32Builder, builder); - array.append_ints_to_builder::(builder); - } - DataType::Int64 => { - let builder = downcast_builder_ref!(Int64Builder, builder); - array.append_longs_to_builder::(builder); - } - DataType::Float32 => { - let builder = downcast_builder_ref!(Float32Builder, builder); - array.append_floats_to_builder::(builder); - } - DataType::Float64 => { - let builder = downcast_builder_ref!(Float64Builder, builder); - array.append_doubles_to_builder::(builder); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let builder = downcast_builder_ref!(TimestampMicrosecondBuilder, builder); - array.append_timestamps_to_builder::(builder); - } - DataType::Date32 => { - let builder = downcast_builder_ref!(Date32Builder, builder); - array.append_dates_to_builder::(builder); - } - DataType::Binary => { - add_values!( - BinaryBuilder, - |builder: &mut BinaryBuilder, values: &SparkUnsafeArray, idx: usize| builder - .append_value(values.get_binary(idx)), - |builder: &mut BinaryBuilder| builder.append_null() - ); - } - DataType::Utf8 => { - add_values!( - StringBuilder, - |builder: &mut StringBuilder, values: &SparkUnsafeArray, idx: usize| builder - 
.append_value(values.get_string(idx)), - |builder: &mut StringBuilder| builder.append_null() - ); - } - DataType::List(field) => { - let builder = downcast_builder_ref!(ListBuilder>, builder); - for idx in 0..array.get_num_elements() { - if NULLABLE && array.is_null_at(idx) { - builder.append_null(); - } else { - let nested_array = array.get_array(idx); - append_list_element(field.data_type(), builder, &nested_array)?; - }; - } - } - DataType::Struct(fields) => { - let builder = downcast_builder_ref!(StructBuilder, builder); - for idx in 0..array.get_num_elements() { - let nested_row = if NULLABLE && array.is_null_at(idx) { - builder.append_null(); - SparkUnsafeRow::default() - } else { - builder.append(true); - array.get_struct(idx, fields.len()) - }; - - for (field_idx, field) in fields.into_iter().enumerate() { - append_field(field.data_type(), builder, &nested_row, field_idx)?; - } - } - } - DataType::Decimal128(p, _) => { - add_values!( - Decimal128Builder, - |builder: &mut Decimal128Builder, values: &SparkUnsafeArray, idx: usize| builder - .append_value(values.get_decimal(idx, *p)), - |builder: &mut Decimal128Builder| builder.append_null() - ); - } - DataType::Map(field, _) => { - let builder = downcast_builder_ref!( - MapBuilder, Box>, - builder - ); - for idx in 0..array.get_num_elements() { - if NULLABLE && array.is_null_at(idx) { - builder.append(false)?; - } else { - let nested_map = array.get_map(idx); - append_map_elements(field, builder, &nested_map)?; - }; - } - } - _ => { - return Err(CometError::Internal(format!( - "Unsupported map data type: {:?}", - data_type - ))) - } - } - - Ok(()) -} - -/// Appending the given list stored in `SparkUnsafeArray` into `ListBuilder`. -/// `element_dt` is the data type of the list element. `list_builder` is the list builder. -/// `list` is the list stored in `SparkUnsafeArray`. 
-pub fn append_list_element( - element_dt: &DataType, - list_builder: &mut ListBuilder>, - list: &SparkUnsafeArray, -) -> Result<(), CometError> { - append_to_builder::(element_dt, list_builder.values(), list)?; - list_builder.append(true); - - Ok(()) -} diff --git a/native/core/src/execution/shuffle/spark_unsafe/map.rs b/native/core/src/execution/shuffle/spark_unsafe/map.rs deleted file mode 100644 index 19b67c43dc..0000000000 --- a/native/core/src/execution/shuffle/spark_unsafe/map.rs +++ /dev/null @@ -1,123 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::{ - errors::CometError, - execution::shuffle::spark_unsafe::list::{append_to_builder, SparkUnsafeArray}, -}; -use arrow::array::builder::{ArrayBuilder, MapBuilder, MapFieldNames}; -use arrow::datatypes::{DataType, FieldRef}; - -pub struct SparkUnsafeMap { - pub(crate) keys: SparkUnsafeArray, - pub(crate) values: SparkUnsafeArray, -} - -impl SparkUnsafeMap { - /// Creates a `SparkUnsafeMap` which points to the given address and size in bytes. - pub(crate) fn new(addr: i64, size: i32) -> Self { - // SAFETY: addr points to valid Spark UnsafeMap data from the JVM. - // The first 8 bytes contain the key array size as a little-endian i64. 
- debug_assert!(addr != 0, "SparkUnsafeMap::new: null address"); - debug_assert!(size >= 0, "SparkUnsafeMap::new: negative size {size}"); - let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; - let key_array_size = i64::from_le_bytes(slice.try_into().unwrap()); - - if key_array_size < 0 { - panic!("Negative key size in bytes of map: {key_array_size}"); - } - - if key_array_size > i32::MAX as i64 { - panic!("Number of key size in bytes should <= i32::MAX: {key_array_size}"); - } - - let value_array_size = size - key_array_size as i32 - 8; - if value_array_size < 0 { - panic!("Negative value size in bytes of map: {value_array_size}"); - } - - let keys = SparkUnsafeArray::new(addr + 8); - let values = SparkUnsafeArray::new(addr + 8 + key_array_size); - - if keys.get_num_elements() != values.get_num_elements() { - panic!( - "Number of elements of keys and values should be the same: {} vs {}", - keys.get_num_elements(), - values.get_num_elements() - ); - } - - Self { keys, values } - } -} - -/// Appending the given map stored in `SparkUnsafeMap` into `MapBuilder`. -/// `field` includes data types of the map element. `map_builder` is the map builder. -/// `map` is the map stored in `SparkUnsafeMap`. 
-pub fn append_map_elements( - field: &FieldRef, - map_builder: &mut MapBuilder, Box>, - map: &SparkUnsafeMap, -) -> Result<(), CometError> { - let (key_field, value_field, _) = get_map_key_value_fields(field)?; - - let keys = &map.keys; - let values = &map.values; - - append_to_builder::(key_field.data_type(), map_builder.keys(), keys)?; - - append_to_builder::(value_field.data_type(), map_builder.values(), values)?; - - map_builder.append(true)?; - - Ok(()) -} - -#[allow(clippy::field_reassign_with_default)] -pub fn get_map_key_value_fields( - field: &FieldRef, -) -> Result<(&FieldRef, &FieldRef, MapFieldNames), CometError> { - let mut map_fieldnames = MapFieldNames::default(); - map_fieldnames.entry = field.name().to_string(); - - let (key_field, value_field) = match field.data_type() { - DataType::Struct(fields) => { - if fields.len() != 2 { - return Err(CometError::Internal(format!( - "Map field should have 2 fields, but got {}", - fields.len() - ))); - } - - let key = &fields[0]; - let value = &fields[1]; - - map_fieldnames.key = key.name().to_string(); - map_fieldnames.value = value.name().to_string(); - - (key, value) - } - _ => { - return Err(CometError::Internal(format!( - "Map field should be a struct, but got {:?}", - field.data_type() - ))); - } - }; - - Ok((key_field, value_field, map_fieldnames)) -} diff --git a/native/core/src/execution/shuffle/spark_unsafe/mod.rs b/native/core/src/execution/shuffle/spark_unsafe/mod.rs deleted file mode 100644 index 6390a0f231..0000000000 --- a/native/core/src/execution/shuffle/spark_unsafe/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub mod list; -mod map; -pub mod row; diff --git a/native/core/src/execution/shuffle/spark_unsafe/row.rs b/native/core/src/execution/shuffle/spark_unsafe/row.rs deleted file mode 100644 index 7ebf18d8d0..0000000000 --- a/native/core/src/execution/shuffle/spark_unsafe/row.rs +++ /dev/null @@ -1,1702 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utils for supporting native sort-based columnar shuffle. 
- -use crate::{ - errors::CometError, - execution::{ - shuffle::{ - codec::{Checksum, ShuffleBlockWriter}, - spark_unsafe::{ - list::{append_list_element, SparkUnsafeArray}, - map::{append_map_elements, get_map_key_value_fields, SparkUnsafeMap}, - }, - }, - utils::bytes_to_i128, - }, -}; -use arrow::array::{ - builder::{ - ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, - Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, - Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder, - StructBuilder, TimestampMicrosecondBuilder, - }, - types::Int32Type, - Array, ArrayRef, RecordBatch, RecordBatchOptions, -}; -use arrow::compute::cast; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use arrow::error::ArrowError; -use datafusion::physical_plan::metrics::Time; -use jni::sys::{jint, jlong}; -use std::{ - fs::OpenOptions, - io::{Cursor, Write}, - str::from_utf8, - sync::Arc, -}; - -const MAX_LONG_DIGITS: u8 = 18; -const NESTED_TYPE_BUILDER_CAPACITY: usize = 100; - -/// A common trait for Spark Unsafe classes that can be used to access the underlying data, -/// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to -/// access the underlying data with index. -/// -/// # Safety -/// -/// Implementations must ensure that: -/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory -/// - `get_element_offset()` returns a valid pointer within the row/array data region -/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format -/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership) -/// -/// All accessor methods (get_boolean, get_int, etc.) 
use unsafe pointer operations but are -/// safe to call as long as: -/// - The index is within bounds (caller's responsibility) -/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data -/// -/// # Alignment -/// -/// Primitive accessor methods are implemented separately for each type because they have -/// different alignment guarantees: -/// - `SparkUnsafeRow`: All field offsets are 8-byte aligned (bitset width is a multiple of 8, -/// and each field slot is 8 bytes), so accessors use aligned `ptr::read()`. -/// - `SparkUnsafeArray`: The array base address may be unaligned when nested within a row's -/// variable-length region, so accessors use `ptr::read_unaligned()`. -pub trait SparkUnsafeObject { - /// Returns the address of the row. - fn get_row_addr(&self) -> i64; - - /// Returns the offset of the element at the given index. - fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8; - - fn get_boolean(&self, index: usize) -> bool; - fn get_byte(&self, index: usize) -> i8; - fn get_short(&self, index: usize) -> i16; - fn get_int(&self, index: usize) -> i32; - fn get_long(&self, index: usize) -> i64; - fn get_float(&self, index: usize) -> f32; - fn get_double(&self, index: usize) -> f64; - fn get_date(&self, index: usize) -> i32; - fn get_timestamp(&self, index: usize) -> i64; - - /// Returns the offset and length of the element at the given index. - #[inline] - fn get_offset_and_len(&self, index: usize) -> (i32, i32) { - let offset_and_size = self.get_long(index); - let offset = (offset_and_size >> 32) as i32; - let len = offset_and_size as i32; - (offset, len) - } - - /// Returns string value at the given index of the object. - fn get_string(&self, index: usize) -> &str { - let (offset, len) = self.get_offset_and_len(index); - let addr = self.get_row_addr() + offset as i64; - // SAFETY: addr points to valid UTF-8 string data within the variable-length region. 
- // Offset and length are read from the fixed-length portion of the row/array. - debug_assert!(addr != 0, "get_string: null address at index {index}"); - debug_assert!( - len >= 0, - "get_string: negative length {len} at index {index}" - ); - let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; - - from_utf8(slice).unwrap() - } - - /// Returns binary value at the given index of the object. - fn get_binary(&self, index: usize) -> &[u8] { - let (offset, len) = self.get_offset_and_len(index); - let addr = self.get_row_addr() + offset as i64; - // SAFETY: addr points to valid binary data within the variable-length region. - // Offset and length are read from the fixed-length portion of the row/array. - debug_assert!(addr != 0, "get_binary: null address at index {index}"); - debug_assert!( - len >= 0, - "get_binary: negative length {len} at index {index}" - ); - unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } - } - - /// Returns decimal value at the given index of the object. - fn get_decimal(&self, index: usize, precision: u8) -> i128 { - if precision <= MAX_LONG_DIGITS { - self.get_long(index) as i128 - } else { - let slice = self.get_binary(index); - bytes_to_i128(slice) - } - } - - /// Returns struct value at the given index of the object. - fn get_struct(&self, index: usize, num_fields: usize) -> SparkUnsafeRow { - let (offset, len) = self.get_offset_and_len(index); - let mut row = SparkUnsafeRow::new_with_num_fields(num_fields); - row.point_to(self.get_row_addr() + offset as i64, len); - - row - } - - /// Returns array value at the given index of the object. 
- fn get_array(&self, index: usize) -> SparkUnsafeArray { - let (offset, _) = self.get_offset_and_len(index); - SparkUnsafeArray::new(self.get_row_addr() + offset as i64) - } - - fn get_map(&self, index: usize) -> SparkUnsafeMap { - let (offset, len) = self.get_offset_and_len(index); - SparkUnsafeMap::new(self.get_row_addr() + offset as i64, len) - } -} - -/// Generates primitive accessor implementations for `SparkUnsafeObject`. -/// -/// Uses `$read_method` to read typed values from raw pointers: -/// - `read` for aligned access (SparkUnsafeRow — all offsets are 8-byte aligned) -/// - `read_unaligned` for potentially unaligned access (SparkUnsafeArray) -macro_rules! impl_primitive_accessors { - ($read_method:ident) => { - #[inline] - fn get_boolean(&self, index: usize) -> bool { - let addr = self.get_element_offset(index, 1); - debug_assert!( - !addr.is_null(), - "get_boolean: null pointer at index {index}" - ); - // SAFETY: addr points to valid element data within the row/array region. - unsafe { *addr != 0 } - } - - #[inline] - fn get_byte(&self, index: usize) -> i8 { - let addr = self.get_element_offset(index, 1); - debug_assert!(!addr.is_null(), "get_byte: null pointer at index {index}"); - // SAFETY: addr points to valid element data (1 byte) within the row/array region. - unsafe { *(addr as *const i8) } - } - - #[inline] - fn get_short(&self, index: usize) -> i16 { - let addr = self.get_element_offset(index, 2) as *const i16; - debug_assert!(!addr.is_null(), "get_short: null pointer at index {index}"); - // SAFETY: addr points to valid element data (2 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_int(&self, index: usize) -> i32 { - let addr = self.get_element_offset(index, 4) as *const i32; - debug_assert!(!addr.is_null(), "get_int: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. 
- unsafe { addr.$read_method() } - } - - #[inline] - fn get_long(&self, index: usize) -> i64 { - let addr = self.get_element_offset(index, 8) as *const i64; - debug_assert!(!addr.is_null(), "get_long: null pointer at index {index}"); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_float(&self, index: usize) -> f32 { - let addr = self.get_element_offset(index, 4) as *const f32; - debug_assert!(!addr.is_null(), "get_float: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_double(&self, index: usize) -> f64 { - let addr = self.get_element_offset(index, 8) as *const f64; - debug_assert!(!addr.is_null(), "get_double: null pointer at index {index}"); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_date(&self, index: usize) -> i32 { - let addr = self.get_element_offset(index, 4) as *const i32; - debug_assert!(!addr.is_null(), "get_date: null pointer at index {index}"); - // SAFETY: addr points to valid element data (4 bytes) within the row/array region. - unsafe { addr.$read_method() } - } - - #[inline] - fn get_timestamp(&self, index: usize) -> i64 { - let addr = self.get_element_offset(index, 8) as *const i64; - debug_assert!( - !addr.is_null(), - "get_timestamp: null pointer at index {index}" - ); - // SAFETY: addr points to valid element data (8 bytes) within the row/array region. 
- unsafe { addr.$read_method() } - } - }; -} -pub(crate) use impl_primitive_accessors; - -pub struct SparkUnsafeRow { - row_addr: i64, - row_size: i32, - row_bitset_width: i64, -} - -impl SparkUnsafeObject for SparkUnsafeRow { - fn get_row_addr(&self) -> i64 { - self.row_addr - } - - fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { - let offset = self.row_bitset_width + (index * 8) as i64; - debug_assert!( - self.row_size >= 0 && offset + element_size as i64 <= self.row_size as i64, - "get_element_offset: access at offset {offset} with size {element_size} \ - exceeds row_size {} for index {index}", - self.row_size - ); - (self.row_addr + offset) as *const u8 - } - - // SparkUnsafeRow field offsets are always 8-byte aligned: the base address is 8-byte - // aligned (JVM guarantee), bitset_width is a multiple of 8, and each field slot is - // 8 bytes. This means we can safely use aligned ptr::read() for all typed accesses. - impl_primitive_accessors!(read); -} - -impl Default for SparkUnsafeRow { - fn default() -> Self { - Self { - row_addr: -1, - row_size: -1, - row_bitset_width: -1, - } - } -} - -impl SparkUnsafeRow { - fn new(schema: &[DataType]) -> Self { - Self { - row_addr: -1, - row_size: -1, - row_bitset_width: Self::get_row_bitset_width(schema.len()) as i64, - } - } - - /// Returns true if the row is a null row. - pub fn is_null_row(&self) -> bool { - self.row_addr == -1 && self.row_size == -1 && self.row_bitset_width == -1 - } - - /// Calculate the width of the bitset for the row in bytes. - /// The logic is from Spark `UnsafeRow.calculateBitSetWidthInBytes`. - #[inline] - pub const fn get_row_bitset_width(num_fields: usize) -> usize { - num_fields.div_ceil(64) * 8 - } - - pub fn new_with_num_fields(num_fields: usize) -> Self { - Self { - row_addr: -1, - row_size: -1, - row_bitset_width: Self::get_row_bitset_width(num_fields) as i64, - } - } - - /// Points the row to the given slice. 
- pub fn point_to_slice(&mut self, slice: &[u8]) { - self.row_addr = slice.as_ptr() as i64; - self.row_size = slice.len() as i32; - } - - /// Points the row to the given address with specified row size. - fn point_to(&mut self, row_addr: i64, row_size: i32) { - self.row_addr = row_addr; - self.row_size = row_size; - } - - pub fn get_row_size(&self) -> i32 { - self.row_size - } - - /// Returns true if the null bit at the given index of the row is set. - #[inline] - pub(crate) fn is_null_at(&self, index: usize) -> bool { - // SAFETY: row_addr points to valid Spark UnsafeRow data with at least - // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. - // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. - // The bitset starts at row_addr (8-byte aligned) and each word is at offset 8*k, - // so word_offset is always 8-byte aligned — we can use aligned ptr::read(). - debug_assert!(self.row_addr != -1, "is_null_at: row not initialized"); - unsafe { - let mask: i64 = 1i64 << (index & 0x3f); - let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; - let word: i64 = word_offset.read(); - (word & mask) != 0 - } - } - - /// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null). - pub fn set_not_null_at(&mut self, index: usize) { - // SAFETY: row_addr points to valid Spark UnsafeRow data with at least - // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. - // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. - // Writing is safe because we have mutable access and the memory is owned by the JVM. - // The bitset is always 8-byte aligned — we can use aligned ptr::read()/write(). 
- debug_assert!(self.row_addr != -1, "set_not_null_at: row not initialized"); - unsafe { - let mask: i64 = 1i64 << (index & 0x3f); - let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; - let word: i64 = word_offset.read(); - word_offset.write(word & !mask); - } - } -} - -macro_rules! downcast_builder_ref { - ($builder_type:ty, $builder:expr) => {{ - let actual_type_id = $builder.as_any().type_id(); - $builder - .as_any_mut() - .downcast_mut::<$builder_type>() - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to downcast builder: expected {}, got {:?}", - stringify!($builder_type), - actual_type_id - )) - })? - }}; -} - -macro_rules! get_field_builder { - ($struct_builder:expr, $builder_type:ty, $idx:expr) => { - $struct_builder - .field_builder::<$builder_type>($idx) - .ok_or_else(|| { - CometError::Internal(format!( - "Failed to get field builder at index {}: expected {}", - $idx, - stringify!($builder_type) - )) - })? - }; -} - -// Expose the macro for other modules. -use crate::execution::shuffle::CompressionCodec; -pub(crate) use downcast_builder_ref; - -/// Appends field of row to the given struct builder. `dt` is the data type of the field. -/// `struct_builder` is the struct builder of the row. `row` is the row that contains the field. -/// `idx` is the index of the field in the row. The caller is responsible for ensuring that the -/// `struct_builder.append` is called before/after calling this function to append the null buffer -/// of the struct array. -#[allow(clippy::redundant_closure_call)] -pub(super) fn append_field( - dt: &DataType, - struct_builder: &mut StructBuilder, - row: &SparkUnsafeRow, - idx: usize, -) -> Result<(), CometError> { - /// A macro for generating code of appending value into field builder of Arrow struct builder. - macro_rules! 
append_field_to_builder { - ($builder_type:ty, $accessor:expr) => {{ - let field_builder = get_field_builder!(struct_builder, $builder_type, idx); - - if row.is_null_row() { - // The row is null. - field_builder.append_null(); - } else { - let is_null = row.is_null_at(idx); - - if is_null { - // The field in the row is null. - // Append a null value to the field builder. - field_builder.append_null(); - } else { - $accessor(field_builder); - } - } - }}; - } - - match dt { - DataType::Boolean => { - append_field_to_builder!(BooleanBuilder, |builder: &mut BooleanBuilder| builder - .append_value(row.get_boolean(idx))); - } - DataType::Int8 => { - append_field_to_builder!(Int8Builder, |builder: &mut Int8Builder| builder - .append_value(row.get_byte(idx))); - } - DataType::Int16 => { - append_field_to_builder!(Int16Builder, |builder: &mut Int16Builder| builder - .append_value(row.get_short(idx))); - } - DataType::Int32 => { - append_field_to_builder!(Int32Builder, |builder: &mut Int32Builder| builder - .append_value(row.get_int(idx))); - } - DataType::Int64 => { - append_field_to_builder!(Int64Builder, |builder: &mut Int64Builder| builder - .append_value(row.get_long(idx))); - } - DataType::Float32 => { - append_field_to_builder!(Float32Builder, |builder: &mut Float32Builder| builder - .append_value(row.get_float(idx))); - } - DataType::Float64 => { - append_field_to_builder!(Float64Builder, |builder: &mut Float64Builder| builder - .append_value(row.get_double(idx))); - } - DataType::Date32 => { - append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder - .append_value(row.get_date(idx))); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - append_field_to_builder!( - TimestampMicrosecondBuilder, - |builder: &mut TimestampMicrosecondBuilder| builder - .append_value(row.get_timestamp(idx)) - ); - } - DataType::Binary => { - append_field_to_builder!(BinaryBuilder, |builder: &mut BinaryBuilder| builder - .append_value(row.get_binary(idx))); - } 
- DataType::Utf8 => { - append_field_to_builder!(StringBuilder, |builder: &mut StringBuilder| builder - .append_value(row.get_string(idx))); - } - DataType::Decimal128(p, _) => { - append_field_to_builder!(Decimal128Builder, |builder: &mut Decimal128Builder| builder - .append_value(row.get_decimal(idx, *p))); - } - DataType::Struct(fields) => { - // Appending value into struct field builder of Arrow struct builder. - let field_builder = get_field_builder!(struct_builder, StructBuilder, idx); - - let nested_row = if row.is_null_row() || row.is_null_at(idx) { - // The row is null, or the field in the row is null, i.e., a null nested row. - // Append a null value to the row builder. - field_builder.append_null(); - SparkUnsafeRow::default() - } else { - field_builder.append(true); - row.get_struct(idx, fields.len()) - }; - - for (field_idx, field) in fields.into_iter().enumerate() { - append_field(field.data_type(), field_builder, &nested_row, field_idx)?; - } - } - DataType::Map(field, _) => { - let field_builder = get_field_builder!( - struct_builder, - MapBuilder, Box>, - idx - ); - - if row.is_null_row() { - // The row is null. - field_builder.append(false)?; - } else { - let is_null = row.is_null_at(idx); - - if is_null { - // The field in the row is null. - // Append a null value to the map builder. - field_builder.append(false)?; - } else { - append_map_elements(field, field_builder, &row.get_map(idx))?; - } - } - } - DataType::List(field) => { - let field_builder = - get_field_builder!(struct_builder, ListBuilder>, idx); - - if row.is_null_row() { - // The row is null. - field_builder.append_null(); - } else { - let is_null = row.is_null_at(idx); - - if is_null { - // The field in the row is null. - // Append a null value to the list builder. - field_builder.append_null(); - } else { - append_list_element(field.data_type(), field_builder, &row.get_array(idx))? 
- } - } - } - _ => { - unreachable!("Unsupported data type of struct field: {:?}", dt) - } - } - - Ok(()) -} - -/// Appends nested struct fields to the struct builder using field-major order. -/// This is a helper function for processing nested struct fields recursively. -/// -/// Unlike `append_struct_fields_field_major`, this function takes slices of row addresses, -/// sizes, and null flags directly, without needing to navigate from a parent row. -#[allow(clippy::redundant_closure_call)] -fn append_nested_struct_fields_field_major( - row_addresses: &[jlong], - row_sizes: &[jint], - struct_is_null: &[bool], - struct_builder: &mut StructBuilder, - fields: &arrow::datatypes::Fields, -) -> Result<(), CometError> { - let num_rows = row_addresses.len(); - let mut row = SparkUnsafeRow::new_with_num_fields(fields.len()); - - // Helper macro for processing primitive fields - macro_rules! process_field { - ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ - let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); - - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - // Struct is null, field is also null - field_builder.append_null(); - } else { - let row_addr = row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - - if row.is_null_at($field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value($get_value(&row, $field_idx)); - } - } - } - }}; - } - - // Process each field across all rows - for (field_idx, field) in fields.iter().enumerate() { - match field.data_type() { - DataType::Boolean => { - process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_boolean(idx)); - } - DataType::Int8 => { - process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_byte(idx)); - } - DataType::Int16 => { - process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_short(idx)); - } - DataType::Int32 => { - 
process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_int(idx)); - } - DataType::Int64 => { - process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_long(idx)); - } - DataType::Float32 => { - process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_float(idx)); - } - DataType::Float64 => { - process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_double(idx)); - } - DataType::Date32 => { - process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_date(idx)); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - process_field!( - TimestampMicrosecondBuilder, - field_idx, - |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) - ); - } - DataType::Binary => { - let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx); - - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - let row_addr = row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - - if row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(row.get_binary(field_idx)); - } - } - } - } - DataType::Utf8 => { - let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); - - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - let row_addr = row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - - if row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(row.get_string(field_idx)); - } - } - } - } - DataType::Decimal128(p, _) => { - let p = *p; - let field_builder = - get_field_builder!(struct_builder, Decimal128Builder, field_idx); - - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - let row_addr = 
row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - - if row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(row.get_decimal(field_idx, p)); - } - } - } - } - DataType::Struct(nested_fields) => { - let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); - - // Collect nested struct addresses and sizes in one pass, building validity - let mut nested_addresses: Vec = Vec::with_capacity(num_rows); - let mut nested_sizes: Vec = Vec::with_capacity(num_rows); - let mut nested_is_null: Vec = Vec::with_capacity(num_rows); - - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - // Parent struct is null, nested struct is also null - nested_builder.append_null(); - nested_is_null.push(true); - nested_addresses.push(0); - nested_sizes.push(0); - } else { - let row_addr = row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - - if row.is_null_at(field_idx) { - nested_builder.append_null(); - nested_is_null.push(true); - nested_addresses.push(0); - nested_sizes.push(0); - } else { - nested_builder.append(true); - nested_is_null.push(false); - // Get nested struct address and size - let nested_row = row.get_struct(field_idx, nested_fields.len()); - nested_addresses.push(nested_row.get_row_addr()); - nested_sizes.push(nested_row.get_row_size()); - } - } - } - - // Recursively process nested struct fields in field-major order - append_nested_struct_fields_field_major( - &nested_addresses, - &nested_sizes, - &nested_is_null, - nested_builder, - nested_fields, - )?; - } - // For list and map, fall back to append_field since they have variable-length elements - dt @ (DataType::List(_) | DataType::Map(_, _)) => { - for row_idx in 0..num_rows { - if struct_is_null[row_idx] { - let null_row = SparkUnsafeRow::default(); - append_field(dt, struct_builder, &null_row, field_idx)?; - } else { - let row_addr = 
row_addresses[row_idx]; - let row_size = row_sizes[row_idx]; - row.point_to(row_addr, row_size); - append_field(dt, struct_builder, &row, field_idx)?; - } - } - } - _ => { - unreachable!( - "Unsupported data type of struct field: {:?}", - field.data_type() - ) - } - } - } - - Ok(()) -} - -/// Reads row address and size from JVM-provided pointer arrays and points the row to that data. -/// -/// # Safety -/// Caller must ensure row_addresses_ptr and row_sizes_ptr are valid for index i. -/// This is guaranteed when called from append_columns with indices in [row_start, row_end). -macro_rules! read_row_at { - ($row:expr, $row_addresses_ptr:expr, $row_sizes_ptr:expr, $i:expr) => {{ - // SAFETY: Caller guarantees pointers are valid for this index (see macro doc) - debug_assert!( - !$row_addresses_ptr.is_null(), - "read_row_at: null row_addresses_ptr" - ); - debug_assert!(!$row_sizes_ptr.is_null(), "read_row_at: null row_sizes_ptr"); - let row_addr = unsafe { *$row_addresses_ptr.add($i) }; - let row_size = unsafe { *$row_sizes_ptr.add($i) }; - $row.point_to(row_addr, row_size); - }}; -} - -/// Appends a batch of list values to the list builder with a single type dispatch. -/// This moves type dispatch from O(rows) to O(1), significantly improving performance -/// for large batches. -#[allow(clippy::too_many_arguments)] -fn append_list_column_batch( - row_addresses_ptr: *mut jlong, - row_sizes_ptr: *mut jint, - row_start: usize, - row_end: usize, - schema: &[DataType], - column_idx: usize, - element_type: &DataType, - list_builder: &mut ListBuilder>, -) -> Result<(), CometError> { - let mut row = SparkUnsafeRow::new(schema); - - // Helper macro for primitive element types - gets builder fresh each iteration - // to avoid borrow conflicts with list_builder.append() - macro_rules! 
process_primitive_lists { - ($builder_type:ty, $append_fn:ident) => {{ - for i in row_start..row_end { - read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); - - if row.is_null_at(column_idx) { - list_builder.append_null(); - } else { - let array = row.get_array(column_idx); - // Get values builder fresh each iteration to avoid borrow conflict - let values_builder = list_builder - .values() - .as_any_mut() - .downcast_mut::<$builder_type>() - .expect(stringify!($builder_type)); - array.$append_fn::(values_builder); - list_builder.append(true); - } - } - }}; - } - - match element_type { - DataType::Boolean => { - process_primitive_lists!(BooleanBuilder, append_booleans_to_builder); - } - DataType::Int8 => { - process_primitive_lists!(Int8Builder, append_bytes_to_builder); - } - DataType::Int16 => { - process_primitive_lists!(Int16Builder, append_shorts_to_builder); - } - DataType::Int32 => { - process_primitive_lists!(Int32Builder, append_ints_to_builder); - } - DataType::Int64 => { - process_primitive_lists!(Int64Builder, append_longs_to_builder); - } - DataType::Float32 => { - process_primitive_lists!(Float32Builder, append_floats_to_builder); - } - DataType::Float64 => { - process_primitive_lists!(Float64Builder, append_doubles_to_builder); - } - DataType::Date32 => { - process_primitive_lists!(Date32Builder, append_dates_to_builder); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - process_primitive_lists!(TimestampMicrosecondBuilder, append_timestamps_to_builder); - } - // For complex element types, fall back to per-row dispatch - _ => { - for i in row_start..row_end { - read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); - - if row.is_null_at(column_idx) { - list_builder.append_null(); - } else { - append_list_element(element_type, list_builder, &row.get_array(column_idx))?; - } - } - } - } - - Ok(()) -} - -/// Appends a batch of map values to the map builder with a single type dispatch. 
-/// This moves type dispatch from O(rows × 2) to O(2), improving performance for maps. -#[allow(clippy::too_many_arguments)] -fn append_map_column_batch( - row_addresses_ptr: *mut jlong, - row_sizes_ptr: *mut jint, - row_start: usize, - row_end: usize, - schema: &[DataType], - column_idx: usize, - field: &arrow::datatypes::FieldRef, - map_builder: &mut MapBuilder, Box>, -) -> Result<(), CometError> { - let mut row = SparkUnsafeRow::new(schema); - let (key_field, value_field, _) = get_map_key_value_fields(field)?; - let key_type = key_field.data_type(); - let value_type = value_field.data_type(); - - // Helper macro for processing maps with primitive key/value types - // Uses scoped borrows to avoid borrow checker conflicts - macro_rules! process_primitive_maps { - ($key_builder:ty, $key_append:ident, $val_builder:ty, $val_append:ident) => {{ - for i in row_start..row_end { - read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); - - if row.is_null_at(column_idx) { - map_builder.append(false)?; - } else { - let map = row.get_map(column_idx); - // Process keys in a scope so borrow ends - { - let keys_builder = map_builder - .keys() - .as_any_mut() - .downcast_mut::<$key_builder>() - .expect(stringify!($key_builder)); - map.keys.$key_append::(keys_builder); - } - // Process values in a scope so borrow ends - { - let values_builder = map_builder - .values() - .as_any_mut() - .downcast_mut::<$val_builder>() - .expect(stringify!($val_builder)); - map.values.$val_append::(values_builder); - } - map_builder.append(true)?; - } - } - }}; - } - - // Optimize common map type combinations - match (key_type, value_type) { - // Map - (DataType::Int64, DataType::Int64) => { - process_primitive_maps!( - Int64Builder, - append_longs_to_builder, - Int64Builder, - append_longs_to_builder - ); - } - // Map - (DataType::Int64, DataType::Float64) => { - process_primitive_maps!( - Int64Builder, - append_longs_to_builder, - Float64Builder, - append_doubles_to_builder - ); - } - // Map - 
(DataType::Int32, DataType::Int32) => { - process_primitive_maps!( - Int32Builder, - append_ints_to_builder, - Int32Builder, - append_ints_to_builder - ); - } - // Map - (DataType::Int32, DataType::Int64) => { - process_primitive_maps!( - Int32Builder, - append_ints_to_builder, - Int64Builder, - append_longs_to_builder - ); - } - // For other types, fall back to per-row dispatch - _ => { - for i in row_start..row_end { - read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); - - if row.is_null_at(column_idx) { - map_builder.append(false)?; - } else { - append_map_elements(field, map_builder, &row.get_map(column_idx))?; - } - } - } - } - - Ok(()) -} - -/// Appends struct fields to the struct builder using field-major order. -/// This processes one field at a time across all rows, which moves type dispatch -/// outside the row loop (O(fields) dispatches instead of O(rows × fields)). -#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] -fn append_struct_fields_field_major( - row_addresses_ptr: *mut jlong, - row_sizes_ptr: *mut jint, - row_start: usize, - row_end: usize, - parent_row: &mut SparkUnsafeRow, - column_idx: usize, - struct_builder: &mut StructBuilder, - fields: &arrow::datatypes::Fields, -) -> Result<(), CometError> { - let num_rows = row_end - row_start; - let num_fields = fields.len(); - - // First pass: Build struct validity and collect which structs are null - // We use a Vec for simplicity; could use a bitset for better memory - let mut struct_is_null = Vec::with_capacity(num_rows); - - for i in row_start..row_end { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - - let is_null = parent_row.is_null_at(column_idx); - struct_is_null.push(is_null); - - if is_null { - struct_builder.append_null(); - } else { - struct_builder.append(true); - } - } - - // Helper macro for processing primitive fields - macro_rules! 
process_field { - ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ - let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); - - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - // Struct is null, field is also null - field_builder.append_null(); - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let nested_row = parent_row.get_struct(column_idx, num_fields); - - if nested_row.is_null_at($field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value($get_value(&nested_row, $field_idx)); - } - } - } - }}; - } - - // Second pass: Process each field across all rows - for (field_idx, field) in fields.iter().enumerate() { - match field.data_type() { - DataType::Boolean => { - process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_boolean(idx)); - } - DataType::Int8 => { - process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_byte(idx)); - } - DataType::Int16 => { - process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_short(idx)); - } - DataType::Int32 => { - process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_int(idx)); - } - DataType::Int64 => { - process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_long(idx)); - } - DataType::Float32 => { - process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_float(idx)); - } - DataType::Float64 => { - process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_double(idx)); - } - DataType::Date32 => { - process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row - .get_date(idx)); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - process_field!( - TimestampMicrosecondBuilder, - field_idx, - |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) - ); - } - DataType::Binary => { - let field_builder = 
get_field_builder!(struct_builder, BinaryBuilder, field_idx); - - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let nested_row = parent_row.get_struct(column_idx, num_fields); - - if nested_row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(nested_row.get_binary(field_idx)); - } - } - } - } - DataType::Utf8 => { - let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); - - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let nested_row = parent_row.get_struct(column_idx, num_fields); - - if nested_row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(nested_row.get_string(field_idx)); - } - } - } - } - DataType::Decimal128(p, _) => { - let p = *p; - let field_builder = - get_field_builder!(struct_builder, Decimal128Builder, field_idx); - - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - field_builder.append_null(); - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let nested_row = parent_row.get_struct(column_idx, num_fields); - - if nested_row.is_null_at(field_idx) { - field_builder.append_null(); - } else { - field_builder.append_value(nested_row.get_decimal(field_idx, p)); - } - } - } - } - // For nested structs, apply field-major processing recursively - DataType::Struct(nested_fields) => { - let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); - - // Collect nested struct addresses and sizes in one pass, building validity - let mut nested_addresses: Vec = Vec::with_capacity(num_rows); - let mut nested_sizes: Vec = Vec::with_capacity(num_rows); - let mut 
nested_is_null: Vec = Vec::with_capacity(num_rows); - - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - // Parent struct is null, nested struct is also null - nested_builder.append_null(); - nested_is_null.push(true); - nested_addresses.push(0); - nested_sizes.push(0); - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let parent_struct = parent_row.get_struct(column_idx, num_fields); - - if parent_struct.is_null_at(field_idx) { - nested_builder.append_null(); - nested_is_null.push(true); - nested_addresses.push(0); - nested_sizes.push(0); - } else { - nested_builder.append(true); - nested_is_null.push(false); - // Get nested struct address and size - let nested_row = - parent_struct.get_struct(field_idx, nested_fields.len()); - nested_addresses.push(nested_row.get_row_addr()); - nested_sizes.push(nested_row.get_row_size()); - } - } - } - - // Recursively process nested struct fields in field-major order - append_nested_struct_fields_field_major( - &nested_addresses, - &nested_sizes, - &nested_is_null, - nested_builder, - nested_fields, - )?; - } - // For list and map, fall back to append_field since they have variable-length elements - dt @ (DataType::List(_) | DataType::Map(_, _)) => { - for (row_idx, i) in (row_start..row_end).enumerate() { - if struct_is_null[row_idx] { - let null_row = SparkUnsafeRow::default(); - append_field(dt, struct_builder, &null_row, field_idx)?; - } else { - read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); - let nested_row = parent_row.get_struct(column_idx, num_fields); - append_field(dt, struct_builder, &nested_row, field_idx)?; - } - } - } - _ => { - unreachable!( - "Unsupported data type of struct field: {:?}", - field.data_type() - ) - } - } - } - - Ok(()) -} - -/// Appends column of top rows to the given array builder. 
-/// -/// # Safety -/// -/// The caller must ensure: -/// - `row_addresses_ptr` points to an array of at least `row_end` jlong values -/// - `row_sizes_ptr` points to an array of at least `row_end` jint values -/// - Each address in `row_addresses_ptr[row_start..row_end]` points to valid Spark UnsafeRow data -/// - The memory remains valid for the duration of this function call -/// -/// These invariants are guaranteed when called from JNI with arrays provided by the JVM. -#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] -fn append_columns( - row_addresses_ptr: *mut jlong, - row_sizes_ptr: *mut jint, - row_start: usize, - row_end: usize, - schema: &[DataType], - column_idx: usize, - builder: &mut Box, - prefer_dictionary_ratio: f64, -) -> Result<(), CometError> { - /// A macro for generating code of appending values into Arrow array builders. - macro_rules! append_column_to_builder { - ($builder_type:ty, $accessor:expr) => {{ - let element_builder = builder - .as_any_mut() - .downcast_mut::<$builder_type>() - .expect(stringify!($builder_type)); - let mut row = SparkUnsafeRow::new(schema); - - for i in row_start..row_end { - read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); - - let is_null = row.is_null_at(column_idx); - - if is_null { - // The element value is null. - // Append a null value to the element builder. 
- element_builder.append_null(); - } else { - $accessor(element_builder, &row, column_idx); - } - } - }}; - } - - let dt = &schema[column_idx]; - - match dt { - DataType::Boolean => { - append_column_to_builder!( - BooleanBuilder, - |builder: &mut BooleanBuilder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_boolean(idx)) - ); - } - DataType::Int8 => { - append_column_to_builder!( - Int8Builder, - |builder: &mut Int8Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_byte(idx)) - ); - } - DataType::Int16 => { - append_column_to_builder!( - Int16Builder, - |builder: &mut Int16Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_short(idx)) - ); - } - DataType::Int32 => { - append_column_to_builder!( - Int32Builder, - |builder: &mut Int32Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_int(idx)) - ); - } - DataType::Int64 => { - append_column_to_builder!( - Int64Builder, - |builder: &mut Int64Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_long(idx)) - ); - } - DataType::Float32 => { - append_column_to_builder!( - Float32Builder, - |builder: &mut Float32Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_float(idx)) - ); - } - DataType::Float64 => { - append_column_to_builder!( - Float64Builder, - |builder: &mut Float64Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_double(idx)) - ); - } - DataType::Decimal128(p, _) => { - append_column_to_builder!( - Decimal128Builder, - |builder: &mut Decimal128Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_decimal(idx, *p)) - ); - } - DataType::Utf8 => { - if prefer_dictionary_ratio > 1.0 { - append_column_to_builder!( - StringDictionaryBuilder, - |builder: &mut StringDictionaryBuilder, - row: &SparkUnsafeRow, - idx| builder.append_value(row.get_string(idx)) - ); - } else { - append_column_to_builder!( - StringBuilder, - |builder: &mut StringBuilder, row: &SparkUnsafeRow, idx| builder - 
.append_value(row.get_string(idx)) - ); - } - } - DataType::Binary => { - if prefer_dictionary_ratio > 1.0 { - append_column_to_builder!( - BinaryDictionaryBuilder, - |builder: &mut BinaryDictionaryBuilder, - row: &SparkUnsafeRow, - idx| builder.append_value(row.get_binary(idx)) - ); - } else { - append_column_to_builder!( - BinaryBuilder, - |builder: &mut BinaryBuilder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_binary(idx)) - ); - } - } - DataType::Date32 => { - append_column_to_builder!( - Date32Builder, - |builder: &mut Date32Builder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_date(idx)) - ); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - append_column_to_builder!( - TimestampMicrosecondBuilder, - |builder: &mut TimestampMicrosecondBuilder, row: &SparkUnsafeRow, idx| builder - .append_value(row.get_timestamp(idx)) - ); - } - DataType::Map(field, _) => { - let map_builder = downcast_builder_ref!( - MapBuilder, Box>, - builder - ); - // Use batched processing for better performance - append_map_column_batch( - row_addresses_ptr, - row_sizes_ptr, - row_start, - row_end, - schema, - column_idx, - field, - map_builder, - )?; - } - DataType::List(field) => { - let list_builder = downcast_builder_ref!(ListBuilder>, builder); - // Use batched processing for better performance - append_list_column_batch( - row_addresses_ptr, - row_sizes_ptr, - row_start, - row_end, - schema, - column_idx, - field.data_type(), - list_builder, - )?; - } - DataType::Struct(fields) => { - let struct_builder = builder - .as_any_mut() - .downcast_mut::() - .expect("StructBuilder"); - let mut row = SparkUnsafeRow::new(schema); - - // Use field-major processing to avoid per-row type dispatch - append_struct_fields_field_major( - row_addresses_ptr, - row_sizes_ptr, - row_start, - row_end, - &mut row, - column_idx, - struct_builder, - fields, - )?; - } - _ => { - unreachable!("Unsupported data type of column: {:?}", dt) - } - } - - Ok(()) -} - -fn 
make_builders( - dt: &DataType, - row_num: usize, - prefer_dictionary_ratio: f64, -) -> Result, CometError> { - let builder: Box = match dt { - DataType::Boolean => Box::new(BooleanBuilder::with_capacity(row_num)), - DataType::Int8 => Box::new(Int8Builder::with_capacity(row_num)), - DataType::Int16 => Box::new(Int16Builder::with_capacity(row_num)), - DataType::Int32 => Box::new(Int32Builder::with_capacity(row_num)), - DataType::Int64 => Box::new(Int64Builder::with_capacity(row_num)), - DataType::Float32 => Box::new(Float32Builder::with_capacity(row_num)), - DataType::Float64 => Box::new(Float64Builder::with_capacity(row_num)), - DataType::Decimal128(_, _) => { - Box::new(Decimal128Builder::with_capacity(row_num).with_data_type(dt.clone())) - } - DataType::Utf8 => { - if prefer_dictionary_ratio > 1.0 { - Box::new(StringDictionaryBuilder::::with_capacity( - row_num / 2, - row_num, - 1024, - )) - } else { - Box::new(StringBuilder::with_capacity(row_num, 1024)) - } - } - DataType::Binary => { - if prefer_dictionary_ratio > 1.0 { - Box::new(BinaryDictionaryBuilder::::with_capacity( - row_num / 2, - row_num, - 1024, - )) - } else { - Box::new(BinaryBuilder::with_capacity(row_num, 1024)) - } - } - DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)), - DataType::Timestamp(TimeUnit::Microsecond, _) => { - Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) - } - DataType::Map(field, _) => { - let (key_field, value_field, map_field_names) = get_map_key_value_fields(field)?; - let key_dt = key_field.data_type(); - let value_dt = value_field.data_type(); - let key_builder = make_builders(key_dt, NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; - let value_builder = make_builders(value_dt, NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; - - Box::new( - MapBuilder::new(Some(map_field_names), key_builder, value_builder) - .with_values_field(Arc::clone(value_field)), - ) - } - DataType::List(field) => { - // Disable dictionary encoding for array 
element - let value_builder = - make_builders(field.data_type(), NESTED_TYPE_BUILDER_CAPACITY, 1.0)?; - - // Needed to overwrite default ListBuilder creation having the incoming field schema to be driving - let value_field = Arc::clone(field); - - Box::new(ListBuilder::new(value_builder).with_field(value_field)) - } - DataType::Struct(fields) => { - let field_builders = fields - .iter() - // Disable dictionary encoding for struct fields - .map(|field| make_builders(field.data_type(), row_num, 1.0)) - .collect::, _>>()?; - - Box::new(StructBuilder::new(fields.clone(), field_builders)) - } - _ => return Err(CometError::Internal(format!("Unsupported type: {dt:?}"))), - }; - - Ok(builder) -} - -/// Processes a sorted row partition and writes the result to the given output path. -#[allow(clippy::too_many_arguments)] -pub fn process_sorted_row_partition( - row_num: usize, - batch_size: usize, - row_addresses_ptr: *mut jlong, - row_sizes_ptr: *mut jint, - schema: &[DataType], - output_path: String, - prefer_dictionary_ratio: f64, - checksum_enabled: bool, - checksum_algo: i32, - // This is the checksum value passed in from Spark side, and is getting updated for - // each shuffle partition Spark processes. It is called "initial" here to indicate - // this is the initial checksum for this method, as it also gets updated iteratively - // inside the loop within the method across batches. - initial_checksum: Option, - codec: &CompressionCodec, -) -> Result<(i64, Option), CometError> { - // The current row number we are reading - let mut current_row = 0; - // Total number of bytes written - let mut written = 0; - // The current checksum value. This is updated incrementally in the following loop. - let mut current_checksum = if checksum_enabled { - Some(Checksum::try_new(checksum_algo, initial_checksum)?) - } else { - None - }; - - // Create builders once and reuse them across batches. - // After finish() is called, builders are reset and can be reused. 
- let mut data_builders: Vec> = vec![]; - schema.iter().try_for_each(|dt| { - make_builders(dt, batch_size, prefer_dictionary_ratio) - .map(|builder| data_builders.push(builder))?; - Ok::<(), CometError>(()) - })?; - - // Open the output file once and reuse it across batches - let mut output_data = OpenOptions::new() - .create(true) - .append(true) - .open(&output_path)?; - - // Reusable buffer for serialized batch data - let mut frozen: Vec = Vec::new(); - - while current_row < row_num { - let n = std::cmp::min(batch_size, row_num - current_row); - - // Appends rows to the array builders. - // For each column, iterating over rows and appending values to corresponding array - // builder. - for (idx, builder) in data_builders.iter_mut().enumerate() { - append_columns( - row_addresses_ptr, - row_sizes_ptr, - current_row, - current_row + n, - schema, - idx, - builder, - prefer_dictionary_ratio, - )?; - } - - // Writes a record batch generated from the array builders to the output file. - // Note: builder_to_array calls finish() which resets the builder, making it reusable for the next batch. 
- let array_refs: Result, _> = data_builders - .iter_mut() - .zip(schema.iter()) - .map(|(builder, datatype)| builder_to_array(builder, datatype, prefer_dictionary_ratio)) - .collect(); - let batch = make_batch(array_refs?, n)?; - - frozen.clear(); - let mut cursor = Cursor::new(&mut frozen); - - // we do not collect metrics in Native_writeSortedFileNative - let ipc_time = Time::default(); - let block_writer = ShuffleBlockWriter::try_new(batch.schema().as_ref(), codec.clone())?; - written += block_writer.write_batch(&batch, &mut cursor, &ipc_time)?; - - if let Some(checksum) = &mut current_checksum { - checksum.update(&mut cursor)?; - } - - output_data.write_all(&frozen)?; - current_row += n; - } - - Ok((written as i64, current_checksum.map(|c| c.finalize()))) -} - -fn builder_to_array( - builder: &mut Box, - datatype: &DataType, - prefer_dictionary_ratio: f64, -) -> Result { - match datatype { - // We don't have redundant dictionary values which are not referenced by any key. - // So the reasonable ratio must be larger than 1.0. - DataType::Utf8 if prefer_dictionary_ratio > 1.0 => { - let builder = builder - .as_any_mut() - .downcast_mut::>() - .expect("StringDictionaryBuilder"); - - let dict_array = builder.finish(); - let num_keys = dict_array.keys().len(); - let num_values = dict_array.values().len(); - - if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { - // The number of keys in the dictionary is less than a ratio of the number of - // values. The dictionary is efficient, so we return it directly. - Ok(Arc::new(dict_array)) - } else { - // If the dictionary is not efficient, we convert it to a plain string array. - Ok(cast(&dict_array, &DataType::Utf8)?) 
- } - } - DataType::Binary if prefer_dictionary_ratio > 1.0 => { - let builder = builder - .as_any_mut() - .downcast_mut::>() - .expect("BinaryDictionaryBuilder"); - - let dict_array = builder.finish(); - let num_keys = dict_array.keys().len(); - let num_values = dict_array.values().len(); - - if num_keys as f64 > num_values as f64 * prefer_dictionary_ratio { - // The number of keys in the dictionary is less than a ratio of the number of - // values. The dictionary is efficient, so we return it directly. - Ok(Arc::new(dict_array)) - } else { - // If the dictionary is not efficient, we convert it to a plain string array. - Ok(cast(&dict_array, &DataType::Binary)?) - } - } - _ => Ok(builder.finish()), - } -} - -fn make_batch(arrays: Vec, row_count: usize) -> Result { - let fields = arrays - .iter() - .enumerate() - .map(|(i, array)| Field::new(format!("c{i}"), array.data_type().clone(), true)) - .collect::>(); - let schema = Arc::new(Schema::new(fields)); - let options = RecordBatchOptions::new().with_row_count(Option::from(row_count)); - RecordBatch::try_new_with_options(schema, arrays, &options) -} - -#[cfg(test)] -mod test { - use arrow::datatypes::Fields; - - use super::*; - - #[test] - fn test_append_null_row_to_struct_builder() { - let data_type = DataType::Struct(Fields::from(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("b", DataType::Boolean, true), - ])); - let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]); - let mut struct_builder = StructBuilder::from_fields(fields, 1); - let row = SparkUnsafeRow::default(); - append_field(&data_type, &mut struct_builder, &row, 0).expect("append field"); - struct_builder.append_null(); - let struct_array = struct_builder.finish(); - assert_eq!(struct_array.len(), 1); - assert!(struct_array.is_null(0)); - } - - #[test] - fn test_append_null_struct_field_to_struct_builder() { - let data_type = DataType::Struct(Fields::from(vec![ - Field::new("a", DataType::Boolean, true), - 
Field::new("b", DataType::Boolean, true), - ])); - let fields = Fields::from(vec![Field::new("st", data_type.clone(), true)]); - let mut struct_builder = StructBuilder::from_fields(fields, 1); - let mut row = SparkUnsafeRow::new_with_num_fields(1); - // 8 bytes null bitset + 8 bytes field value = 16 bytes - // Set bit 0 in the null bitset to mark field 0 as null - // Use aligned buffer to match real Spark UnsafeRow layout (8-byte aligned) - #[repr(align(8))] - struct Aligned([u8; 16]); - let mut data = Aligned([0u8; 16]); - data.0[0] = 1; - row.point_to_slice(&data.0); - append_field(&data_type, &mut struct_builder, &row, 0).expect("append field"); - struct_builder.append_null(); - let struct_array = struct_builder.finish(); - assert_eq!(struct_array.len(), 1); - assert!(struct_array.is_null(0)); - } -} diff --git a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs b/native/core/src/execution/shuffle/writers/buf_batch_writer.rs deleted file mode 100644 index 8d056d7bb0..0000000000 --- a/native/core/src/execution/shuffle/writers/buf_batch_writer.rs +++ /dev/null @@ -1,142 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::execution::shuffle::ShuffleBlockWriter; -use arrow::array::RecordBatch; -use arrow::compute::kernels::coalesce::BatchCoalescer; -use datafusion::physical_plan::metrics::Time; -use std::borrow::Borrow; -use std::io::{Cursor, Seek, SeekFrom, Write}; - -/// Write batches to writer while using a buffer to avoid frequent system calls. -/// The record batches were first written by ShuffleBlockWriter into an internal buffer. -/// Once the buffer exceeds the max size, the buffer will be flushed to the writer. -/// -/// Small batches are coalesced using Arrow's [`BatchCoalescer`] before serialization, -/// producing exactly `batch_size`-row output batches to reduce per-block IPC schema overhead. -/// The coalescer is lazily initialized on the first write. -pub(crate) struct BufBatchWriter, W: Write> { - shuffle_block_writer: S, - writer: W, - buffer: Vec, - buffer_max_size: usize, - /// Coalesces small batches into target_batch_size before serialization. - /// Lazily initialized on first write to capture the schema. - coalescer: Option, - /// Target batch size for coalescing - batch_size: usize, -} - -impl, W: Write> BufBatchWriter { - pub(crate) fn new( - shuffle_block_writer: S, - writer: W, - buffer_max_size: usize, - batch_size: usize, - ) -> Self { - Self { - shuffle_block_writer, - writer, - buffer: vec![], - buffer_max_size, - coalescer: None, - batch_size, - } - } - - pub(crate) fn write( - &mut self, - batch: &RecordBatch, - encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result { - let coalescer = self - .coalescer - .get_or_insert_with(|| BatchCoalescer::new(batch.schema(), self.batch_size)); - coalescer.push_batch(batch.clone())?; - - // Drain completed batches into a local vec so the coalescer borrow ends - // before we call write_batch_to_buffer (which borrows &mut self). 
- let mut completed = Vec::new(); - while let Some(batch) = coalescer.next_completed_batch() { - completed.push(batch); - } - - let mut bytes_written = 0; - for batch in &completed { - bytes_written += self.write_batch_to_buffer(batch, encode_time, write_time)?; - } - Ok(bytes_written) - } - - /// Serialize a single batch into the byte buffer, flushing to the writer if needed. - fn write_batch_to_buffer( - &mut self, - batch: &RecordBatch, - encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result { - let mut cursor = Cursor::new(&mut self.buffer); - cursor.seek(SeekFrom::End(0))?; - let bytes_written = - self.shuffle_block_writer - .borrow() - .write_batch(batch, &mut cursor, encode_time)?; - let pos = cursor.position(); - if pos >= self.buffer_max_size as u64 { - let mut write_timer = write_time.timer(); - self.writer.write_all(&self.buffer)?; - write_timer.stop(); - self.buffer.clear(); - } - Ok(bytes_written) - } - - pub(crate) fn flush( - &mut self, - encode_time: &Time, - write_time: &Time, - ) -> datafusion::common::Result<()> { - // Finish any remaining buffered rows in the coalescer - let mut remaining = Vec::new(); - if let Some(coalescer) = &mut self.coalescer { - coalescer.finish_buffered_batch()?; - while let Some(batch) = coalescer.next_completed_batch() { - remaining.push(batch); - } - } - for batch in &remaining { - self.write_batch_to_buffer(batch, encode_time, write_time)?; - } - - // Flush the byte buffer to the underlying writer - let mut write_timer = write_time.timer(); - if !self.buffer.is_empty() { - self.writer.write_all(&self.buffer)?; - } - self.writer.flush()?; - write_timer.stop(); - self.buffer.clear(); - Ok(()) - } -} - -impl, W: Write + Seek> BufBatchWriter { - pub(crate) fn writer_stream_position(&mut self) -> datafusion::common::Result { - self.writer.stream_position().map_err(Into::into) - } -} diff --git a/native/core/src/execution/shuffle/writers/mod.rs b/native/core/src/execution/shuffle/writers/mod.rs deleted 
file mode 100644 index d41363b7fb..0000000000 --- a/native/core/src/execution/shuffle/writers/mod.rs +++ /dev/null @@ -1,22 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod buf_batch_writer; -mod partition_writer; - -pub(super) use buf_batch_writer::BufBatchWriter; -pub(super) use partition_writer::PartitionWriter; diff --git a/native/core/src/execution/shuffle/writers/partition_writer.rs b/native/core/src/execution/shuffle/writers/partition_writer.rs deleted file mode 100644 index 7c2dbe0444..0000000000 --- a/native/core/src/execution/shuffle/writers/partition_writer.rs +++ /dev/null @@ -1,124 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::execution::shuffle::metrics::ShufflePartitionerMetrics; -use crate::execution::shuffle::partitioners::PartitionedBatchIterator; -use crate::execution::shuffle::writers::buf_batch_writer::BufBatchWriter; -use crate::execution::shuffle::ShuffleBlockWriter; -use datafusion::common::DataFusionError; -use datafusion::execution::disk_manager::RefCountedTempFile; -use datafusion::execution::runtime_env::RuntimeEnv; -use std::fs::{File, OpenOptions}; - -struct SpillFile { - temp_file: RefCountedTempFile, - file: File, -} - -pub(crate) struct PartitionWriter { - /// Spill file for intermediate shuffle output for this partition. Each spill event - /// will append to this file and the contents will be copied to the shuffle file at - /// the end of processing. 
- spill_file: Option, - /// Writer that performs encoding and compression - shuffle_block_writer: ShuffleBlockWriter, -} - -impl PartitionWriter { - pub(crate) fn try_new( - shuffle_block_writer: ShuffleBlockWriter, - ) -> datafusion::common::Result { - Ok(Self { - spill_file: None, - shuffle_block_writer, - }) - } - - fn ensure_spill_file_created( - &mut self, - runtime: &RuntimeEnv, - ) -> datafusion::common::Result<()> { - if self.spill_file.is_none() { - // Spill file is not yet created, create it - let spill_file = runtime - .disk_manager - .create_tmp_file("shuffle writer spill")?; - let spill_data = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open(spill_file.path()) - .map_err(|e| { - DataFusionError::Execution(format!("Error occurred while spilling {e}")) - })?; - self.spill_file = Some(SpillFile { - temp_file: spill_file, - file: spill_data, - }); - } - Ok(()) - } - - pub(crate) fn spill( - &mut self, - iter: &mut PartitionedBatchIterator, - runtime: &RuntimeEnv, - metrics: &ShufflePartitionerMetrics, - write_buffer_size: usize, - batch_size: usize, - ) -> datafusion::common::Result { - if let Some(batch) = iter.next() { - self.ensure_spill_file_created(runtime)?; - - let total_bytes_written = { - let mut buf_batch_writer = BufBatchWriter::new( - &mut self.shuffle_block_writer, - &mut self.spill_file.as_mut().unwrap().file, - write_buffer_size, - batch_size, - ); - let mut bytes_written = - buf_batch_writer.write(&batch?, &metrics.encode_time, &metrics.write_time)?; - for batch in iter { - let batch = batch?; - bytes_written += buf_batch_writer.write( - &batch, - &metrics.encode_time, - &metrics.write_time, - )?; - } - buf_batch_writer.flush(&metrics.encode_time, &metrics.write_time)?; - bytes_written - }; - - Ok(total_bytes_written) - } else { - Ok(0) - } - } - - pub(crate) fn path(&self) -> Option<&std::path::Path> { - self.spill_file - .as_ref() - .map(|spill_file| spill_file.temp_file.path()) - } - - #[cfg(test)] - pub(crate) 
fn has_spill_file(&self) -> bool { - self.spill_file.is_some() - } -} From 17ba839838a09c95c19770c08ad664762a48a552 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:31:22 -0600 Subject: [PATCH 6/9] feat: move shuffle benchmarks to shuffle crate Move shuffle_writer and row_columnar benchmarks from native/core to native/shuffle, updating imports to reference datafusion-comet-shuffle directly and removing the stale bench entries from core's Cargo.toml. --- native/core/Cargo.toml | 8 - native/core/benches/row_columnar.rs | 393 ----------------------- native/core/benches/shuffle_writer.rs | 212 ------------ native/shuffle/benches/row_columnar.rs | 377 +++++++++++++++++++++- native/shuffle/benches/shuffle_writer.rs | 196 ++++++++++- 5 files changed, 571 insertions(+), 615 deletions(-) delete mode 100644 native/core/benches/row_columnar.rs delete mode 100644 native/core/benches/shuffle_writer.rs diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 256bf39e20..ee58d2f1fc 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -123,14 +123,6 @@ harness = false name = "bit_util" harness = false -[[bench]] -name = "row_columnar" -harness = false - -[[bench]] -name = "shuffle_writer" -harness = false - [[bench]] name = "parquet_decode" harness = false diff --git a/native/core/benches/row_columnar.rs b/native/core/benches/row_columnar.rs deleted file mode 100644 index 4ee1539060..0000000000 --- a/native/core/benches/row_columnar.rs +++ /dev/null @@ -1,393 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Benchmarks for JVM shuffle row-to-columnar conversion. -//! -//! Measures `process_sorted_row_partition()` performance for converting Spark -//! UnsafeRow data to Arrow arrays, covering primitive, struct (flat/nested), -//! list, and map types. - -use arrow::datatypes::{DataType as ArrowDataType, Field, Fields}; -use comet::execution::shuffle::spark_unsafe::row::{ - process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, -}; -use comet::execution::shuffle::CompressionCodec; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; -use std::sync::Arc; -use tempfile::Builder; - -const BATCH_SIZE: usize = 5000; - -/// Size of an Int64 value in bytes. -const INT64_SIZE: usize = 8; - -/// Size of a pointer in Spark's UnsafeRow format. Encodes a 32-bit offset -/// (upper bits) and 32-bit size (lower bits) — always 8 bytes regardless of -/// hardware architecture. -const UNSAFE_ROW_POINTER_SIZE: usize = 8; - -/// Size of the element-count field in UnsafeRow array/map headers. -const ARRAY_HEADER_SIZE: usize = 8; - -// ─── UnsafeRow helpers ────────────────────────────────────────────────────── - -/// Write an UnsafeRow offset+size pointer at `pos` in `data`. -fn write_pointer(data: &mut [u8], pos: usize, offset: usize, size: usize) { - let value = ((offset as i64) << 32) | (size as i64); - data[pos..pos + UNSAFE_ROW_POINTER_SIZE].copy_from_slice(&value.to_le_bytes()); -} - -/// Byte size of a null-bitset for `n` elements (64-bit words, rounded up). 
-fn null_bitset_size(n: usize) -> usize { - n.div_ceil(64) * 8 -} - -// ─── Schema builders ──────────────────────────────────────────────────────── - -/// Create a struct schema with `depth` nesting levels and `num_leaf_fields` -/// Int64 leaf fields. -/// -/// - depth=1: `Struct` -/// - depth=2: `Struct>` -/// - depth=3: `Struct>>` -fn make_struct_schema(depth: usize, num_leaf_fields: usize) -> ArrowDataType { - let leaf_fields: Vec = (0..num_leaf_fields) - .map(|i| Field::new(format!("f{i}"), ArrowDataType::Int64, true)) - .collect(); - let mut dt = ArrowDataType::Struct(Fields::from(leaf_fields)); - for _ in 0..depth - 1 { - dt = ArrowDataType::Struct(Fields::from(vec![Field::new("nested", dt, true)])); - } - dt -} - -fn make_list_schema() -> ArrowDataType { - ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))) -} - -fn make_map_schema() -> ArrowDataType { - let entries = Field::new( - "entries", - ArrowDataType::Struct(Fields::from(vec![ - Field::new("key", ArrowDataType::Int64, false), - Field::new("value", ArrowDataType::Int64, true), - ])), - false, - ); - ArrowDataType::Map(Arc::new(entries), false) -} - -// ─── Row data builders ────────────────────────────────────────────────────── - -/// Build a binary UnsafeRow containing a struct column with `depth` nesting -/// levels and `num_leaf_fields` Int64 fields at the innermost level. 
-fn build_struct_row(depth: usize, num_leaf_fields: usize) -> Vec { - let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); - let inter_bitset = SparkUnsafeRow::get_row_bitset_width(1); - let leaf_bitset = SparkUnsafeRow::get_row_bitset_width(num_leaf_fields); - - let inter_level_size = inter_bitset + UNSAFE_ROW_POINTER_SIZE; - let leaf_level_size = leaf_bitset + num_leaf_fields * INT64_SIZE; - - let total = - top_bitset + UNSAFE_ROW_POINTER_SIZE + (depth - 1) * inter_level_size + leaf_level_size; - let mut data = vec![0u8; total]; - - // Absolute start position of each struct level in the buffer - let mut struct_starts = Vec::with_capacity(depth); - let mut pos = top_bitset + UNSAFE_ROW_POINTER_SIZE; - for level in 0..depth { - struct_starts.push(pos); - if level < depth - 1 { - pos += inter_level_size; - } - } - - // Top-level pointer → first struct (absolute offset from row start) - let first_size = if depth == 1 { - leaf_level_size - } else { - inter_level_size - }; - write_pointer(&mut data, top_bitset, struct_starts[0], first_size); - - // Intermediate struct pointers (offsets relative to their own struct start) - for level in 0..depth - 1 { - let next_size = if level + 1 == depth - 1 { - leaf_level_size - } else { - inter_level_size - }; - write_pointer( - &mut data, - struct_starts[level] + inter_bitset, - struct_starts[level + 1] - struct_starts[level], - next_size, - ); - } - - // Fill leaf struct with sample data - let leaf_start = *struct_starts.last().unwrap(); - for i in 0..num_leaf_fields { - let off = leaf_start + leaf_bitset + i * INT64_SIZE; - data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); - } - - data -} - -/// Build a binary UnsafeRow containing a `List` column. 
-fn build_list_row(num_elements: usize) -> Vec { - let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); - let elem_null_bitset = null_bitset_size(num_elements); - let list_size = ARRAY_HEADER_SIZE + elem_null_bitset + num_elements * INT64_SIZE; - let total = top_bitset + UNSAFE_ROW_POINTER_SIZE + list_size; - let mut data = vec![0u8; total]; - - let list_offset = top_bitset + UNSAFE_ROW_POINTER_SIZE; - write_pointer(&mut data, top_bitset, list_offset, list_size); - - // Element count - data[list_offset..list_offset + ARRAY_HEADER_SIZE] - .copy_from_slice(&(num_elements as i64).to_le_bytes()); - - // Element values - let data_start = list_offset + ARRAY_HEADER_SIZE + elem_null_bitset; - for i in 0..num_elements { - let off = data_start + i * INT64_SIZE; - data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); - } - - data -} - -/// Build a binary UnsafeRow containing a `Map` column. -fn build_map_row(num_entries: usize) -> Vec { - let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); - let entry_null_bitset = null_bitset_size(num_entries); - let array_size = ARRAY_HEADER_SIZE + entry_null_bitset + num_entries * INT64_SIZE; - // Map layout: [key_array_size header] [key_array] [value_array] - let map_size = ARRAY_HEADER_SIZE + 2 * array_size; - let total = top_bitset + UNSAFE_ROW_POINTER_SIZE + map_size; - let mut data = vec![0u8; total]; - - let map_offset = top_bitset + UNSAFE_ROW_POINTER_SIZE; - write_pointer(&mut data, top_bitset, map_offset, map_size); - - // Key array size header - data[map_offset..map_offset + ARRAY_HEADER_SIZE] - .copy_from_slice(&(array_size as i64).to_le_bytes()); - - // Key array: [element count] [null bitset] [data] - let key_offset = map_offset + ARRAY_HEADER_SIZE; - data[key_offset..key_offset + ARRAY_HEADER_SIZE] - .copy_from_slice(&(num_entries as i64).to_le_bytes()); - let key_data = key_offset + ARRAY_HEADER_SIZE + entry_null_bitset; - for i in 0..num_entries { - let off = key_data + i * INT64_SIZE; 
- data[off..off + INT64_SIZE].copy_from_slice(&(i as i64).to_le_bytes()); - } - - // Value array: [element count] [null bitset] [data] - let val_offset = key_offset + array_size; - data[val_offset..val_offset + ARRAY_HEADER_SIZE] - .copy_from_slice(&(num_entries as i64).to_le_bytes()); - let val_data = val_offset + ARRAY_HEADER_SIZE + entry_null_bitset; - for i in 0..num_entries { - let off = val_data + i * INT64_SIZE; - data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); - } - - data -} - -// ─── Benchmark runner ─────────────────────────────────────────────────────── - -/// Common benchmark harness: wraps raw row bytes in SparkUnsafeRow and runs -/// `process_sorted_row_partition` under Criterion. -fn run_benchmark( - group: &mut criterion::BenchmarkGroup, - name: &str, - param: &str, - schema: &[ArrowDataType], - rows: &[Vec], - num_top_level_fields: usize, -) { - let num_rows = rows.len(); - - let spark_rows: Vec = rows - .iter() - .map(|data| { - let mut row = SparkUnsafeRow::new_with_num_fields(num_top_level_fields); - row.point_to_slice(data); - for i in 0..num_top_level_fields { - row.set_not_null_at(i); - } - row - }) - .collect(); - - let mut addrs: Vec = spark_rows.iter().map(|r| r.get_row_addr()).collect(); - let mut sizes: Vec = spark_rows.iter().map(|r| r.get_row_size()).collect(); - let addr_ptr = addrs.as_mut_ptr(); - let size_ptr = sizes.as_mut_ptr(); - - group.bench_with_input(BenchmarkId::new(name, param), &num_rows, |b, &n| { - b.iter(|| { - let tmp = Builder::new().tempfile().unwrap(); - process_sorted_row_partition( - n, - BATCH_SIZE, - addr_ptr, - size_ptr, - schema, - tmp.path().to_str().unwrap().to_string(), - 1.0, - false, - 0, - None, - &CompressionCodec::Zstd(1), - ) - .unwrap(); - }); - }); - - drop(spark_rows); -} - -// ─── Benchmarks ───────────────────────────────────────────────────────────── - -/// 100 primitive Int64 columns — baseline without complex-type overhead. 
-fn benchmark_primitive_columns(c: &mut Criterion) { - let mut group = c.benchmark_group("primitive_columns"); - const NUM_COLS: usize = 100; - let bitset = SparkUnsafeRow::get_row_bitset_width(NUM_COLS); - let row_size = bitset + NUM_COLS * INT64_SIZE; - - for num_rows in [1000, 10000] { - let schema = vec![ArrowDataType::Int64; NUM_COLS]; - let rows: Vec> = (0..num_rows) - .map(|_| { - let mut data = vec![0u8; row_size]; - for (i, byte) in data.iter_mut().enumerate().take(row_size).skip(bitset) { - *byte = i as u8; - } - data - }) - .collect(); - - run_benchmark( - &mut group, - "cols_100", - &format!("rows_{num_rows}"), - &schema, - &rows, - NUM_COLS, - ); - } - - group.finish(); -} - -/// Struct columns at varying nesting depths (1 = flat, 2 = nested, 3 = deeply nested). -fn benchmark_struct_conversion(c: &mut Criterion) { - let mut group = c.benchmark_group("struct_conversion"); - - for (depth, label) in [(1, "flat"), (2, "nested"), (3, "deeply_nested")] { - for num_fields in [5, 10, 20] { - for num_rows in [1000, 10000] { - let schema = vec![make_struct_schema(depth, num_fields)]; - let rows: Vec> = (0..num_rows) - .map(|_| build_struct_row(depth, num_fields)) - .collect(); - - run_benchmark( - &mut group, - &format!("{label}_fields_{num_fields}"), - &format!("rows_{num_rows}"), - &schema, - &rows, - 1, - ); - } - } - } - - group.finish(); -} - -/// List columns with varying element counts. -fn benchmark_list_conversion(c: &mut Criterion) { - let mut group = c.benchmark_group("list_conversion"); - - for num_elements in [10, 100] { - for num_rows in [1000, 10000] { - let schema = vec![make_list_schema()]; - let rows: Vec> = (0..num_rows) - .map(|_| build_list_row(num_elements)) - .collect(); - - run_benchmark( - &mut group, - &format!("elements_{num_elements}"), - &format!("rows_{num_rows}"), - &schema, - &rows, - 1, - ); - } - } - - group.finish(); -} - -/// Map columns with varying entry counts. 
-fn benchmark_map_conversion(c: &mut Criterion) { - let mut group = c.benchmark_group("map_conversion"); - - for num_entries in [10, 100] { - for num_rows in [1000, 10000] { - let schema = vec![make_map_schema()]; - let rows: Vec> = (0..num_rows).map(|_| build_map_row(num_entries)).collect(); - - run_benchmark( - &mut group, - &format!("entries_{num_entries}"), - &format!("rows_{num_rows}"), - &schema, - &rows, - 1, - ); - } - } - - group.finish(); -} - -fn config() -> Criterion { - Criterion::default() -} - -criterion_group! { - name = benches; - config = config(); - targets = benchmark_primitive_columns, - benchmark_struct_conversion, - benchmark_list_conversion, - benchmark_map_conversion -} -criterion_main!(benches); diff --git a/native/core/benches/shuffle_writer.rs b/native/core/benches/shuffle_writer.rs deleted file mode 100644 index 0857ef78c6..0000000000 --- a/native/core/benches/shuffle_writer.rs +++ /dev/null @@ -1,212 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use arrow::array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; -use arrow::array::{builder::StringBuilder, Array, Int32Array, RecordBatch}; -use arrow::datatypes::{DataType, Field, Schema}; -use arrow::row::{RowConverter, SortField}; -use comet::execution::shuffle::{ - CometPartitioning, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, -}; -use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion::datasource::memory::MemorySourceConfig; -use datafusion::datasource::source::DataSourceExec; -use datafusion::physical_expr::expressions::{col, Column}; -use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion::physical_plan::metrics::Time; -use datafusion::{ - physical_plan::{common::collect, ExecutionPlan}, - prelude::SessionContext, -}; -use itertools::Itertools; -use std::io::Cursor; -use std::sync::Arc; -use tokio::runtime::Runtime; - -fn criterion_benchmark(c: &mut Criterion) { - let batch = create_batch(8192, true); - let mut group = c.benchmark_group("shuffle_writer"); - for compression_codec in &[ - CompressionCodec::None, - CompressionCodec::Lz4Frame, - CompressionCodec::Snappy, - CompressionCodec::Zstd(1), - CompressionCodec::Zstd(6), - ] { - let name = format!("shuffle_writer: write encoded (compression={compression_codec:?})"); - group.bench_function(name, |b| { - let mut buffer = vec![]; - let ipc_time = Time::default(); - let w = - ShuffleBlockWriter::try_new(&batch.schema(), compression_codec.clone()).unwrap(); - b.iter(|| { - buffer.clear(); - let mut cursor = Cursor::new(&mut buffer); - w.write_batch(&batch, &mut cursor, &ipc_time).unwrap(); - }); - }); - } - - for compression_codec in [ - CompressionCodec::None, - CompressionCodec::Lz4Frame, - CompressionCodec::Snappy, - CompressionCodec::Zstd(1), - CompressionCodec::Zstd(6), - ] { - group.bench_function( - format!("shuffle_writer: end to end (compression = {compression_codec:?})"), - |b| { - let ctx = SessionContext::new(); - let exec 
= create_shuffle_writer_exec( - compression_codec.clone(), - CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16), - ); - b.iter(|| { - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx).unwrap(); - let rt = Runtime::new().unwrap(); - rt.block_on(collect(stream)).unwrap(); - }); - }, - ); - } - - let lex_ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( - col("c0", batch.schema().as_ref()).unwrap(), - )]) - .unwrap(); - - let sort_fields: Vec = batch - .columns() - .iter() - .zip(&lex_ordering) - .map(|(array, sort_expr)| { - SortField::new_with_options(array.data_type().clone(), sort_expr.options) - }) - .collect(); - let row_converter = RowConverter::new(sort_fields).unwrap(); - - // These are hard-coded values based on the benchmark params of 8192 rows per batch, and 16 - // partitions. If these change, these values need to be recalculated, or bring over the - // bounds-finding logic from shuffle_write_test in shuffle_writer.rs. - let bounds_ints = vec![ - 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, - ]; - let bounds_array: Arc = Arc::new(Int32Array::from(bounds_ints)); - let bounds_rows = row_converter - .convert_columns(vec![bounds_array].as_slice()) - .unwrap(); - - let owned_rows = bounds_rows.iter().map(|row| row.owned()).collect_vec(); - - for partitioning in [ - CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16), - CometPartitioning::RangePartitioning(lex_ordering, 16, Arc::new(row_converter), owned_rows), - ] { - let compression_codec = CompressionCodec::None; - group.bench_function( - format!("shuffle_writer: end to end (partitioning={partitioning:?})"), - |b| { - let ctx = SessionContext::new(); - let exec = - create_shuffle_writer_exec(compression_codec.clone(), partitioning.clone()); - b.iter(|| { - let task_ctx = ctx.task_ctx(); - let stream = exec.execute(0, task_ctx).unwrap(); - let rt = Runtime::new().unwrap(); - 
rt.block_on(collect(stream)).unwrap(); - }); - }, - ); - } -} - -fn create_shuffle_writer_exec( - compression_codec: CompressionCodec, - partitioning: CometPartitioning, -) -> ShuffleWriterExec { - let batches = create_batches(8192, 10); - let schema = batches[0].schema(); - let partitions = &[batches]; - ShuffleWriterExec::try_new( - Arc::new(DataSourceExec::new(Arc::new( - MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(), - ))), - partitioning, - compression_codec, - "/tmp/data.out".to_string(), - "/tmp/index.out".to_string(), - false, - 1024 * 1024, - ) - .unwrap() -} - -fn create_batches(size: usize, count: usize) -> Vec { - let batch = create_batch(size, true); - let mut batches = Vec::new(); - for _ in 0..count { - batches.push(batch.clone()); - } - batches -} - -fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![ - Field::new("c0", DataType::Int32, true), - Field::new("c1", DataType::Utf8, true), - Field::new("c2", DataType::Date32, true), - Field::new("c3", DataType::Decimal128(11, 2), true), - ])); - let mut a = Int32Builder::new(); - let mut b = StringBuilder::new(); - let mut c = Date32Builder::new(); - let mut d = Decimal128Builder::new() - .with_precision_and_scale(11, 2) - .unwrap(); - for i in 0..num_rows { - a.append_value(i as i32); - c.append_value(i as i32); - d.append_value((i * 1000000) as i128); - if allow_nulls && i % 10 == 0 { - b.append_null(); - } else { - b.append_value(format!("this is string number {i}")); - } - } - let a = a.finish(); - let b = b.finish(); - let c = c.finish(); - let d = d.finish(); - RecordBatch::try_new( - schema.clone(), - vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)], - ) - .unwrap() -} - -fn config() -> Criterion { - Criterion::default() -} - -criterion_group! 
{ - name = benches; - config = config(); - targets = criterion_benchmark -} -criterion_main!(benches); diff --git a/native/shuffle/benches/row_columnar.rs b/native/shuffle/benches/row_columnar.rs index f0e5818f8b..2e2edbb5e6 100644 --- a/native/shuffle/benches/row_columnar.rs +++ b/native/shuffle/benches/row_columnar.rs @@ -15,4 +15,379 @@ // specific language governing permissions and limitations // under the License. -fn main() {} +//! Benchmarks for JVM shuffle row-to-columnar conversion. +//! +//! Measures `process_sorted_row_partition()` performance for converting Spark +//! UnsafeRow data to Arrow arrays, covering primitive, struct (flat/nested), +//! list, and map types. + +use arrow::datatypes::{DataType as ArrowDataType, Field, Fields}; +use datafusion_comet_shuffle::spark_unsafe::row::{ + process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, +}; +use datafusion_comet_shuffle::CompressionCodec; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use std::sync::Arc; +use tempfile::Builder; + +const BATCH_SIZE: usize = 5000; + +/// Size of an Int64 value in bytes. +const INT64_SIZE: usize = 8; + +/// Size of a pointer in Spark's UnsafeRow format. Encodes a 32-bit offset +/// (upper bits) and 32-bit size (lower bits) — always 8 bytes regardless of +/// hardware architecture. +const UNSAFE_ROW_POINTER_SIZE: usize = 8; + +/// Size of the element-count field in UnsafeRow array/map headers. +const ARRAY_HEADER_SIZE: usize = 8; + +// ─── UnsafeRow helpers ────────────────────────────────────────────────────── + +/// Write an UnsafeRow offset+size pointer at `pos` in `data`. +fn write_pointer(data: &mut [u8], pos: usize, offset: usize, size: usize) { + let value = ((offset as i64) << 32) | (size as i64); + data[pos..pos + UNSAFE_ROW_POINTER_SIZE].copy_from_slice(&value.to_le_bytes()); +} + +/// Byte size of a null-bitset for `n` elements (64-bit words, rounded up). 
+fn null_bitset_size(n: usize) -> usize { + n.div_ceil(64) * 8 +} + +// ─── Schema builders ──────────────────────────────────────────────────────── + +/// Create a struct schema with `depth` nesting levels and `num_leaf_fields` +/// Int64 leaf fields. +/// +/// - depth=1: `Struct` +/// - depth=2: `Struct>` +/// - depth=3: `Struct>>` +fn make_struct_schema(depth: usize, num_leaf_fields: usize) -> ArrowDataType { + let leaf_fields: Vec = (0..num_leaf_fields) + .map(|i| Field::new(format!("f{i}"), ArrowDataType::Int64, true)) + .collect(); + let mut dt = ArrowDataType::Struct(Fields::from(leaf_fields)); + for _ in 0..depth - 1 { + dt = ArrowDataType::Struct(Fields::from(vec![Field::new("nested", dt, true)])); + } + dt +} + +fn make_list_schema() -> ArrowDataType { + ArrowDataType::List(Arc::new(Field::new("item", ArrowDataType::Int64, true))) +} + +fn make_map_schema() -> ArrowDataType { + let entries = Field::new( + "entries", + ArrowDataType::Struct(Fields::from(vec![ + Field::new("key", ArrowDataType::Int64, false), + Field::new("value", ArrowDataType::Int64, true), + ])), + false, + ); + ArrowDataType::Map(Arc::new(entries), false) +} + +// ─── Row data builders ────────────────────────────────────────────────────── + +/// Build a binary UnsafeRow containing a struct column with `depth` nesting +/// levels and `num_leaf_fields` Int64 fields at the innermost level. 
+fn build_struct_row(depth: usize, num_leaf_fields: usize) -> Vec { + let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); + let inter_bitset = SparkUnsafeRow::get_row_bitset_width(1); + let leaf_bitset = SparkUnsafeRow::get_row_bitset_width(num_leaf_fields); + + let inter_level_size = inter_bitset + UNSAFE_ROW_POINTER_SIZE; + let leaf_level_size = leaf_bitset + num_leaf_fields * INT64_SIZE; + + let total = + top_bitset + UNSAFE_ROW_POINTER_SIZE + (depth - 1) * inter_level_size + leaf_level_size; + let mut data = vec![0u8; total]; + + // Absolute start position of each struct level in the buffer + let mut struct_starts = Vec::with_capacity(depth); + let mut pos = top_bitset + UNSAFE_ROW_POINTER_SIZE; + for level in 0..depth { + struct_starts.push(pos); + if level < depth - 1 { + pos += inter_level_size; + } + } + + // Top-level pointer → first struct (absolute offset from row start) + let first_size = if depth == 1 { + leaf_level_size + } else { + inter_level_size + }; + write_pointer(&mut data, top_bitset, struct_starts[0], first_size); + + // Intermediate struct pointers (offsets relative to their own struct start) + for level in 0..depth - 1 { + let next_size = if level + 1 == depth - 1 { + leaf_level_size + } else { + inter_level_size + }; + write_pointer( + &mut data, + struct_starts[level] + inter_bitset, + struct_starts[level + 1] - struct_starts[level], + next_size, + ); + } + + // Fill leaf struct with sample data + let leaf_start = *struct_starts.last().unwrap(); + for i in 0..num_leaf_fields { + let off = leaf_start + leaf_bitset + i * INT64_SIZE; + data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); + } + + data +} + +/// Build a binary UnsafeRow containing a `List` column. 
+fn build_list_row(num_elements: usize) -> Vec { + let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); + let elem_null_bitset = null_bitset_size(num_elements); + let list_size = ARRAY_HEADER_SIZE + elem_null_bitset + num_elements * INT64_SIZE; + let total = top_bitset + UNSAFE_ROW_POINTER_SIZE + list_size; + let mut data = vec![0u8; total]; + + let list_offset = top_bitset + UNSAFE_ROW_POINTER_SIZE; + write_pointer(&mut data, top_bitset, list_offset, list_size); + + // Element count + data[list_offset..list_offset + ARRAY_HEADER_SIZE] + .copy_from_slice(&(num_elements as i64).to_le_bytes()); + + // Element values + let data_start = list_offset + ARRAY_HEADER_SIZE + elem_null_bitset; + for i in 0..num_elements { + let off = data_start + i * INT64_SIZE; + data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); + } + + data +} + +/// Build a binary UnsafeRow containing a `Map` column. +fn build_map_row(num_entries: usize) -> Vec { + let top_bitset = SparkUnsafeRow::get_row_bitset_width(1); + let entry_null_bitset = null_bitset_size(num_entries); + let array_size = ARRAY_HEADER_SIZE + entry_null_bitset + num_entries * INT64_SIZE; + // Map layout: [key_array_size header] [key_array] [value_array] + let map_size = ARRAY_HEADER_SIZE + 2 * array_size; + let total = top_bitset + UNSAFE_ROW_POINTER_SIZE + map_size; + let mut data = vec![0u8; total]; + + let map_offset = top_bitset + UNSAFE_ROW_POINTER_SIZE; + write_pointer(&mut data, top_bitset, map_offset, map_size); + + // Key array size header + data[map_offset..map_offset + ARRAY_HEADER_SIZE] + .copy_from_slice(&(array_size as i64).to_le_bytes()); + + // Key array: [element count] [null bitset] [data] + let key_offset = map_offset + ARRAY_HEADER_SIZE; + data[key_offset..key_offset + ARRAY_HEADER_SIZE] + .copy_from_slice(&(num_entries as i64).to_le_bytes()); + let key_data = key_offset + ARRAY_HEADER_SIZE + entry_null_bitset; + for i in 0..num_entries { + let off = key_data + i * INT64_SIZE; 
+ data[off..off + INT64_SIZE].copy_from_slice(&(i as i64).to_le_bytes()); + } + + // Value array: [element count] [null bitset] [data] + let val_offset = key_offset + array_size; + data[val_offset..val_offset + ARRAY_HEADER_SIZE] + .copy_from_slice(&(num_entries as i64).to_le_bytes()); + let val_data = val_offset + ARRAY_HEADER_SIZE + entry_null_bitset; + for i in 0..num_entries { + let off = val_data + i * INT64_SIZE; + data[off..off + INT64_SIZE].copy_from_slice(&((i as i64) * 100).to_le_bytes()); + } + + data +} + +// ─── Benchmark runner ─────────────────────────────────────────────────────── + +/// Common benchmark harness: wraps raw row bytes in SparkUnsafeRow and runs +/// `process_sorted_row_partition` under Criterion. +fn run_benchmark( + group: &mut criterion::BenchmarkGroup, + name: &str, + param: &str, + schema: &[ArrowDataType], + rows: &[Vec], + num_top_level_fields: usize, +) { + let num_rows = rows.len(); + + let spark_rows: Vec = rows + .iter() + .map(|data| { + let mut row = SparkUnsafeRow::new_with_num_fields(num_top_level_fields); + row.point_to_slice(data); + for i in 0..num_top_level_fields { + row.set_not_null_at(i); + } + row + }) + .collect(); + + let mut addrs: Vec = spark_rows.iter().map(|r| r.get_row_addr()).collect(); + let mut sizes: Vec = spark_rows.iter().map(|r| r.get_row_size()).collect(); + let addr_ptr = addrs.as_mut_ptr(); + let size_ptr = sizes.as_mut_ptr(); + + group.bench_with_input(BenchmarkId::new(name, param), &num_rows, |b, &n| { + b.iter(|| { + let tmp = Builder::new().tempfile().unwrap(); + process_sorted_row_partition( + n, + BATCH_SIZE, + addr_ptr, + size_ptr, + schema, + tmp.path().to_str().unwrap().to_string(), + 1.0, + false, + 0, + None, + &CompressionCodec::Zstd(1), + ) + .unwrap(); + }); + }); + + drop(spark_rows); +} + +// ─── Benchmarks ───────────────────────────────────────────────────────────── + +/// 100 primitive Int64 columns — baseline without complex-type overhead. 
+fn benchmark_primitive_columns(c: &mut Criterion) { + let mut group = c.benchmark_group("primitive_columns"); + const NUM_COLS: usize = 100; + let bitset = SparkUnsafeRow::get_row_bitset_width(NUM_COLS); + let row_size = bitset + NUM_COLS * INT64_SIZE; + + for num_rows in [1000, 10000] { + let schema = vec![ArrowDataType::Int64; NUM_COLS]; + let rows: Vec> = (0..num_rows) + .map(|_| { + let mut data = vec![0u8; row_size]; + for (i, byte) in data.iter_mut().enumerate().take(row_size).skip(bitset) { + *byte = i as u8; + } + data + }) + .collect(); + + run_benchmark( + &mut group, + "cols_100", + &format!("rows_{num_rows}"), + &schema, + &rows, + NUM_COLS, + ); + } + + group.finish(); +} + +/// Struct columns at varying nesting depths (1 = flat, 2 = nested, 3 = deeply nested). +fn benchmark_struct_conversion(c: &mut Criterion) { + let mut group = c.benchmark_group("struct_conversion"); + + for (depth, label) in [(1, "flat"), (2, "nested"), (3, "deeply_nested")] { + for num_fields in [5, 10, 20] { + for num_rows in [1000, 10000] { + let schema = vec![make_struct_schema(depth, num_fields)]; + let rows: Vec> = (0..num_rows) + .map(|_| build_struct_row(depth, num_fields)) + .collect(); + + run_benchmark( + &mut group, + &format!("{label}_fields_{num_fields}"), + &format!("rows_{num_rows}"), + &schema, + &rows, + 1, + ); + } + } + } + + group.finish(); +} + +/// List columns with varying element counts. +fn benchmark_list_conversion(c: &mut Criterion) { + let mut group = c.benchmark_group("list_conversion"); + + for num_elements in [10, 100] { + for num_rows in [1000, 10000] { + let schema = vec![make_list_schema()]; + let rows: Vec> = (0..num_rows) + .map(|_| build_list_row(num_elements)) + .collect(); + + run_benchmark( + &mut group, + &format!("elements_{num_elements}"), + &format!("rows_{num_rows}"), + &schema, + &rows, + 1, + ); + } + } + + group.finish(); +} + +/// Map columns with varying entry counts. 
+fn benchmark_map_conversion(c: &mut Criterion) { + let mut group = c.benchmark_group("map_conversion"); + + for num_entries in [10, 100] { + for num_rows in [1000, 10000] { + let schema = vec![make_map_schema()]; + let rows: Vec> = (0..num_rows).map(|_| build_map_row(num_entries)).collect(); + + run_benchmark( + &mut group, + &format!("entries_{num_entries}"), + &format!("rows_{num_rows}"), + &schema, + &rows, + 1, + ); + } + } + + group.finish(); +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = benchmark_primitive_columns, + benchmark_struct_conversion, + benchmark_list_conversion, + benchmark_map_conversion +} +criterion_main!(benches); diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs index f0e5818f8b..11f2a73e2e 100644 --- a/native/shuffle/benches/shuffle_writer.rs +++ b/native/shuffle/benches/shuffle_writer.rs @@ -15,4 +15,198 @@ // specific language governing permissions and limitations // under the License. 
-fn main() {} +use arrow::array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; +use arrow::array::{builder::StringBuilder, Array, Int32Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow::row::{RowConverter, SortField}; +use datafusion_comet_shuffle::{ + CometPartitioning, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, +}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion::datasource::memory::MemorySourceConfig; +use datafusion::datasource::source::DataSourceExec; +use datafusion::physical_expr::expressions::{col, Column}; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion::physical_plan::metrics::Time; +use datafusion::{ + physical_plan::{common::collect, ExecutionPlan}, + prelude::SessionContext, +}; +use itertools::Itertools; +use std::io::Cursor; +use std::sync::Arc; +use tokio::runtime::Runtime; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_batch(8192, true); + let mut group = c.benchmark_group("shuffle_writer"); + for compression_codec in &[ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Snappy, + CompressionCodec::Zstd(1), + CompressionCodec::Zstd(6), + ] { + let name = format!("shuffle_writer: write encoded (compression={compression_codec:?})"); + group.bench_function(name, |b| { + let mut buffer = vec![]; + let ipc_time = Time::default(); + let w = + ShuffleBlockWriter::try_new(&batch.schema(), compression_codec.clone()).unwrap(); + b.iter(|| { + buffer.clear(); + let mut cursor = Cursor::new(&mut buffer); + w.write_batch(&batch, &mut cursor, &ipc_time).unwrap(); + }); + }); + } + + for compression_codec in [ + CompressionCodec::None, + CompressionCodec::Lz4Frame, + CompressionCodec::Snappy, + CompressionCodec::Zstd(1), + CompressionCodec::Zstd(6), + ] { + group.bench_function( + format!("shuffle_writer: end to end (compression = {compression_codec:?})"), + |b| { + let ctx = SessionContext::new(); 
+ let exec = create_shuffle_writer_exec( + compression_codec.clone(), + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16), + ); + b.iter(|| { + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + let rt = Runtime::new().unwrap(); + rt.block_on(collect(stream)).unwrap(); + }); + }, + ); + } + + let lex_ordering = LexOrdering::new(vec![PhysicalSortExpr::new_default( + col("c0", batch.schema().as_ref()).unwrap(), + )]) + .unwrap(); + + let sort_fields: Vec = batch + .columns() + .iter() + .zip(&lex_ordering) + .map(|(array, sort_expr)| { + SortField::new_with_options(array.data_type().clone(), sort_expr.options) + }) + .collect(); + let row_converter = RowConverter::new(sort_fields).unwrap(); + + // These are hard-coded values based on the benchmark params of 8192 rows per batch, and 16 + // partitions. If these change, these values need to be recalculated, or bring over the + // bounds-finding logic from shuffle_write_test in shuffle_writer.rs. 
+ let bounds_ints = vec![ + 512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, + ]; + let bounds_array: Arc = Arc::new(Int32Array::from(bounds_ints)); + let bounds_rows = row_converter + .convert_columns(vec![bounds_array].as_slice()) + .unwrap(); + + let owned_rows = bounds_rows.iter().map(|row| row.owned()).collect_vec(); + + for partitioning in [ + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], 16), + CometPartitioning::RangePartitioning(lex_ordering, 16, Arc::new(row_converter), owned_rows), + ] { + let compression_codec = CompressionCodec::None; + group.bench_function( + format!("shuffle_writer: end to end (partitioning={partitioning:?})"), + |b| { + let ctx = SessionContext::new(); + let exec = + create_shuffle_writer_exec(compression_codec.clone(), partitioning.clone()); + b.iter(|| { + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).unwrap(); + let rt = Runtime::new().unwrap(); + rt.block_on(collect(stream)).unwrap(); + }); + }, + ); + } +} + +fn create_shuffle_writer_exec( + compression_codec: CompressionCodec, + partitioning: CometPartitioning, +) -> ShuffleWriterExec { + let batches = create_batches(8192, 10); + let schema = batches[0].schema(); + let partitions = &[batches]; + ShuffleWriterExec::try_new( + Arc::new(DataSourceExec::new(Arc::new( + MemorySourceConfig::try_new(partitions, Arc::clone(&schema), None).unwrap(), + ))), + partitioning, + compression_codec, + "/tmp/data.out".to_string(), + "/tmp/index.out".to_string(), + false, + 1024 * 1024, + ) + .unwrap() +} + +fn create_batches(size: usize, count: usize) -> Vec { + let batch = create_batch(size, true); + let mut batches = Vec::new(); + for _ in 0..count { + batches.push(batch.clone()); + } + batches +} + +fn create_batch(num_rows: usize, allow_nulls: bool) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + Field::new("c2", 
DataType::Date32, true), + Field::new("c3", DataType::Decimal128(11, 2), true), + ])); + let mut a = Int32Builder::new(); + let mut b = StringBuilder::new(); + let mut c = Date32Builder::new(); + let mut d = Decimal128Builder::new() + .with_precision_and_scale(11, 2) + .unwrap(); + for i in 0..num_rows { + a.append_value(i as i32); + c.append_value(i as i32); + d.append_value((i * 1000000) as i128); + if allow_nulls && i % 10 == 0 { + b.append_null(); + } else { + b.append_value(format!("this is string number {i}")); + } + } + let a = a.finish(); + let b = b.finish(); + let c = c.finish(); + let d = d.finish(); + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(a), Arc::new(b), Arc::new(c), Arc::new(d)], + ) + .unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); From 70720420e9db670bea529701905b5629f392994c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:35:01 -0600 Subject: [PATCH 7/9] chore: add missing READMEs and fix clippy warning Add README files for common, jni-bridge, and shuffle crates. Add Default impl for Recorder to satisfy clippy. --- native/common/README.md | 25 +++++++++++++++++++++++++ native/common/src/tracing.rs | 6 ++++++ native/jni-bridge/README.md | 25 +++++++++++++++++++++++++ native/shuffle/README.md | 25 +++++++++++++++++++++++++ 4 files changed, 81 insertions(+) create mode 100644 native/common/README.md create mode 100644 native/jni-bridge/README.md create mode 100644 native/shuffle/README.md diff --git a/native/common/README.md b/native/common/README.md new file mode 100644 index 0000000000..842b441b53 --- /dev/null +++ b/native/common/README.md @@ -0,0 +1,25 @@ + + +# datafusion-comet-common: Common Types + +This crate provides common types shared across Apache DataFusion Comet crates and is maintained as part of the +[Apache DataFusion Comet] subproject. 
+ +[Apache DataFusion Comet]: https://github.com/apache/datafusion-comet/ diff --git a/native/common/src/tracing.rs b/native/common/src/tracing.rs index 76598fd5ac..58bea64a7a 100644 --- a/native/common/src/tracing.rs +++ b/native/common/src/tracing.rs @@ -29,6 +29,12 @@ pub struct Recorder { writer: Arc>>, } +impl Default for Recorder { + fn default() -> Self { + Self::new() + } +} + impl Recorder { pub fn new() -> Self { let file = OpenOptions::new() diff --git a/native/jni-bridge/README.md b/native/jni-bridge/README.md new file mode 100644 index 0000000000..d49a3c2565 --- /dev/null +++ b/native/jni-bridge/README.md @@ -0,0 +1,25 @@ + + +# datafusion-comet-jni-bridge: JNI Bridge + +This crate provides the JNI interaction layer for Apache DataFusion Comet and is maintained as part of the +[Apache DataFusion Comet] subproject. + +[Apache DataFusion Comet]: https://github.com/apache/datafusion-comet/ diff --git a/native/shuffle/README.md b/native/shuffle/README.md new file mode 100644 index 0000000000..8fba6b0323 --- /dev/null +++ b/native/shuffle/README.md @@ -0,0 +1,25 @@ + + +# datafusion-comet-shuffle: Shuffle Writer and Reader + +This crate provides the shuffle writer and reader implementation for Apache DataFusion Comet and is maintained as part +of the [Apache DataFusion Comet] subproject. 
+ +[Apache DataFusion Comet]: https://github.com/apache/datafusion-comet/ From 2744ab1305b0d4e25101915d6c97898a2621c8ca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 20 Mar 2026 12:52:00 -0600 Subject: [PATCH 8/9] cargo fmt --- native/shuffle/benches/row_columnar.rs | 2 +- native/shuffle/benches/shuffle_writer.rs | 6 +++--- native/shuffle/src/codec.rs | 2 +- native/shuffle/src/partitioners/multi_partition.rs | 2 +- native/shuffle/src/shuffle_writer.rs | 2 +- native/shuffle/src/spark_unsafe/list.rs | 2 +- native/shuffle/src/spark_unsafe/map.rs | 2 +- native/shuffle/src/spark_unsafe/row.rs | 4 ++-- 8 files changed, 11 insertions(+), 11 deletions(-) diff --git a/native/shuffle/benches/row_columnar.rs b/native/shuffle/benches/row_columnar.rs index 2e2edbb5e6..7d3951b4d5 100644 --- a/native/shuffle/benches/row_columnar.rs +++ b/native/shuffle/benches/row_columnar.rs @@ -22,11 +22,11 @@ //! list, and map types. use arrow::datatypes::{DataType as ArrowDataType, Field, Fields}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_comet_shuffle::spark_unsafe::row::{ process_sorted_row_partition, SparkUnsafeObject, SparkUnsafeRow, }; use datafusion_comet_shuffle::CompressionCodec; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use std::sync::Arc; use tempfile::Builder; diff --git a/native/shuffle/benches/shuffle_writer.rs b/native/shuffle/benches/shuffle_writer.rs index 11f2a73e2e..27abd919fa 100644 --- a/native/shuffle/benches/shuffle_writer.rs +++ b/native/shuffle/benches/shuffle_writer.rs @@ -19,9 +19,6 @@ use arrow::array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; use arrow::array::{builder::StringBuilder, Array, Int32Array, RecordBatch}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::row::{RowConverter, SortField}; -use datafusion_comet_shuffle::{ - CometPartitioning, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, -}; use criterion::{criterion_group, 
criterion_main, Criterion}; use datafusion::datasource::memory::MemorySourceConfig; use datafusion::datasource::source::DataSourceExec; @@ -32,6 +29,9 @@ use datafusion::{ physical_plan::{common::collect, ExecutionPlan}, prelude::SessionContext, }; +use datafusion_comet_shuffle::{ + CometPartitioning, CompressionCodec, ShuffleBlockWriter, ShuffleWriterExec, +}; use itertools::Itertools; use std::io::Cursor; use std::sync::Arc; diff --git a/native/shuffle/src/codec.rs b/native/shuffle/src/codec.rs index c18489115a..c8edc2468c 100644 --- a/native/shuffle/src/codec.rs +++ b/native/shuffle/src/codec.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_comet_jni_bridge::errors::{CometError, CometResult}; use arrow::array::RecordBatch; use arrow::datatypes::Schema; use arrow::ipc::reader::StreamReader; @@ -25,6 +24,7 @@ use crc32fast::Hasher; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::metrics::Time; +use datafusion_comet_jni_bridge::errors::{CometError, CometResult}; use simd_adler32::Adler32; use std::io::{Cursor, Seek, SeekFrom, Write}; diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index c83f6fb9c8..42290c5510 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -22,7 +22,6 @@ use crate::partitioners::partitioned_batch_iterator::{ use crate::partitioners::ShufflePartitioner; use crate::writers::{BufBatchWriter, PartitionWriter}; use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; -use datafusion_comet_common::tracing::{with_trace, with_trace_async}; use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion::common::utils::proxy::VecAllocExt; @@ -30,6 +29,7 @@ use datafusion::common::DataFusionError; use 
datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::metrics::Time; +use datafusion_comet_common::tracing::{with_trace, with_trace_async}; use datafusion_comet_spark_expr::murmur3::create_murmur3_hashes; use itertools::Itertools; use std::fmt; diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 4b3f08a826..e649aaac69 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -22,7 +22,6 @@ use crate::partitioners::{ MultiPartitionShuffleRepartitioner, ShufflePartitioner, SinglePartitionShufflePartitioner, }; use crate::{CometPartitioning, CompressionCodec}; -use datafusion_comet_common::tracing::with_trace_async; use async_trait::async_trait; use datafusion::common::exec_datafusion_err; use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; @@ -39,6 +38,7 @@ use datafusion::{ Statistics, }, }; +use datafusion_comet_common::tracing::with_trace_async; use futures::{StreamExt, TryFutureExt, TryStreamExt}; use std::{ any::Any, diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs index d3cdf0dd04..4eb293895c 100644 --- a/native/shuffle/src/spark_unsafe/list.rs +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion_comet_jni_bridge::errors::CometError; use crate::spark_unsafe::{ map::append_map_elements, row::{ @@ -32,6 +31,7 @@ use arrow::array::{ MapBuilder, }; use arrow::datatypes::{DataType, TimeUnit}; +use datafusion_comet_jni_bridge::errors::CometError; /// Generates bulk append methods for primitive types in SparkUnsafeArray. 
/// diff --git a/native/shuffle/src/spark_unsafe/map.rs b/native/shuffle/src/spark_unsafe/map.rs index efc3069a6e..57444cee7a 100644 --- a/native/shuffle/src/spark_unsafe/map.rs +++ b/native/shuffle/src/spark_unsafe/map.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use datafusion_comet_jni_bridge::errors::CometError; use crate::spark_unsafe::list::{append_to_builder, SparkUnsafeArray}; use arrow::array::builder::{ArrayBuilder, MapBuilder, MapFieldNames}; use arrow::datatypes::{DataType, FieldRef}; +use datafusion_comet_jni_bridge::errors::CometError; pub struct SparkUnsafeMap { pub(crate) keys: SparkUnsafeArray, diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index 13a80998db..da980af8f9 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -17,8 +17,6 @@ //! Utils for supporting native sort-based columnar shuffle. -use datafusion_comet_jni_bridge::errors::CometError; -use datafusion_comet_common::bytes_to_i128; use crate::codec::{Checksum, ShuffleBlockWriter}; use crate::spark_unsafe::{ list::{append_list_element, SparkUnsafeArray}, @@ -38,6 +36,8 @@ use arrow::compute::cast; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use arrow::error::ArrowError; use datafusion::physical_plan::metrics::Time; +use datafusion_comet_common::bytes_to_i128; +use datafusion_comet_jni_bridge::errors::CometError; use jni::sys::{jint, jlong}; use std::{ fs::OpenOptions, From 12b008289a922f928a3d606e611190bc7a76f2fe Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 22 Mar 2026 08:40:53 -0600 Subject: [PATCH 9/9] chore(deps): remove unused compression dependencies from core crate Remove crc32fast, lz4_flex, simd-adler32, snap, and zstd which are no longer used after the shuffle read path was moved to the shuffle crate. 
--- native/Cargo.lock | 5 ----- native/core/Cargo.toml | 6 ------ 2 files changed, 11 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index ffd9773f83..bc015fc388 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1834,7 +1834,6 @@ dependencies = [ "aws-config", "aws-credential-types", "bytes", - "crc32fast", "criterion", "datafusion", "datafusion-comet-common", @@ -1858,7 +1857,6 @@ dependencies = [ "lazy_static", "log", "log4rs", - "lz4_flex 0.13.0", "mimalloc", "num", "object_store", @@ -1874,15 +1872,12 @@ dependencies = [ "rand 0.10.0", "reqwest", "serde_json", - "simd-adler32", - "snap", "tempfile", "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", "url", "uuid", - "zstd", ] [[package]] diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index ee58d2f1fc..b66830ecb5 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -47,10 +47,6 @@ log = "0.4" log4rs = "1.4.0" prost = "0.14.3" jni = "0.21" -snap = "1.1" -# we disable default features in lz4_flex to force the use of the faster unsafe encoding and decoding implementation -lz4_flex = { version = "0.13.0", default-features = false, features = ["frame"] } -zstd = "0.13.3" rand = { workspace = true } num = { workspace = true } bytes = { workspace = true } @@ -62,8 +58,6 @@ datafusion-physical-expr-adapter = { workspace = true } datafusion-datasource = { workspace = true } datafusion-spark = { workspace = true } once_cell = "1.18.0" -crc32fast = "1.3.2" -simd-adler32 = "0.3.7" datafusion-comet-common = { workspace = true } datafusion-comet-spark-expr = { workspace = true } datafusion-comet-jni-bridge = { workspace = true }