Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 103 additions & 62 deletions datafusion/functions/src/string/split_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use crate::strings::{
BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder,
};
use crate::utils::utf8_to_str_type;
use arrow::array::{
Array, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array,
StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder,
Array, ArrayRef, AsArray, ByteView, Int64Array, StringArrayType, StringViewArray,
make_view, new_null_array,
};
use arrow::buffer::ScalarBuffer;
use arrow::buffer::{NullBuffer, ScalarBuffer};
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_common::cast::as_int64_array;
Expand Down Expand Up @@ -167,7 +169,7 @@ impl ScalarUDFImpl for SplitPartFunc {
let result = match args[0].data_type() {
DataType::Utf8View => split_part_for_delimiter_type!(
&args[0].as_string_view(),
StringViewBuilder::with_capacity(inferred_length)
StringViewArrayBuilder::with_capacity(inferred_length)
),
DataType::Utf8 => {
let str_arr = &args[0].as_string::<i32>();
Expand All @@ -176,7 +178,7 @@ impl ScalarUDFImpl for SplitPartFunc {
// pre-allocating the full input data size.
split_part_for_delimiter_type!(
str_arr,
GenericStringBuilder::<i32>::with_capacity(
GenericStringArrayBuilder::<i32>::with_capacity(
inferred_length,
inferred_length,
)
Expand All @@ -187,7 +189,7 @@ impl ScalarUDFImpl for SplitPartFunc {
// Conservative under-estimate; see Utf8 comment above.
split_part_for_delimiter_type!(
str_arr,
GenericStringBuilder::<i64>::with_capacity(
GenericStringArrayBuilder::<i64>::with_capacity(
inferred_length,
inferred_length,
)
Expand Down Expand Up @@ -293,7 +295,7 @@ fn split_part_scalar(
arr,
delimiter,
position,
GenericStringBuilder::<i32>::with_capacity(arr.len(), arr.len()),
GenericStringArrayBuilder::<i32>::with_capacity(arr.len(), arr.len()),
)
}
DataType::LargeUtf8 => {
Expand All @@ -303,7 +305,7 @@ fn split_part_scalar(
arr,
delimiter,
position,
GenericStringBuilder::<i64>::with_capacity(arr.len(), arr.len()),
GenericStringArrayBuilder::<i64>::with_capacity(arr.len(), arr.len()),
)
}
other => exec_err!("Unsupported string type {other:?} for split_part"),
Expand All @@ -323,7 +325,7 @@ fn split_part_scalar_impl<'a, S, B>(
) -> Result<ArrayRef>
where
S: StringArrayType<'a> + Copy,
B: StringLikeArrayBuilder,
B: BulkNullStringArrayBuilder,
{
if delimiter.is_empty() {
// PostgreSQL: empty delimiter treats input as a single field,
Expand Down Expand Up @@ -367,16 +369,31 @@ where
fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result<ArrayRef>
where
S: StringArrayType<'a> + Copy,
B: StringLikeArrayBuilder,
B: BulkNullStringArrayBuilder,
F: Fn(&'a str) -> Option<&'a str>,
{
for string in string_array.iter() {
match string {
Some(s) => builder.append_value(f(s).unwrap_or("")),
None => builder.append_null(),
let item_len = string_array.len();
let nulls = string_array.nulls().cloned();

if let Some(ref n) = nulls {
for i in 0..item_len {
if n.is_null(i) {
builder.append_placeholder();
} else {
// SAFETY: `n.is_null(i)` was false in the branch above.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(f(s).unwrap_or(""));
}
}
} else {
for i in 0..item_len {
// SAFETY: no null buffer means every index is valid.
let s = unsafe { string_array.value_unchecked(i) };
builder.append_value(f(s).unwrap_or(""));
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)

builder.finish(nulls)
}

/// Finds the `n`th (0-based) split part using a pre-built `memmem::Finder`.
Expand Down Expand Up @@ -543,58 +560,82 @@ fn split_part_impl<'a, StringArrType, DelimiterArrType, B>(
where
StringArrType: StringArrayType<'a>,
DelimiterArrType: StringArrayType<'a>,
B: StringLikeArrayBuilder,
B: BulkNullStringArrayBuilder,
{
for ((string, delimiter), n) in string_array
.iter()
.zip(delimiter_array.iter())
.zip(n_array.iter())
{
match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
let result = match n.cmp(&0) {
std::cmp::Ordering::Greater => {
let idx: usize = (n - 1).try_into().map_err(|_| {
exec_datafusion_err!(
"split_part index {n} exceeds maximum supported value"
)
})?;
if delimiter.is_empty() {
// Match PostgreSQL's behavior: empty delimiter
// treats input as a single field, so only position
// 1 returns data.
(n == 1).then_some(string)
} else {
split_nth(string, delimiter, idx)
}
}
std::cmp::Ordering::Less => {
let idx: usize =
(n.unsigned_abs() - 1).try_into().map_err(|_| {
exec_datafusion_err!(
"split_part index {n} exceeds minimum supported value"
)
})?;
if delimiter.is_empty() {
// Match PostgreSQL's behavior: empty delimiter
// treats input as a single field, so only position
// -1 returns data.
(n == -1).then_some(string)
} else {
rsplit_nth(string, delimiter, idx)
}
}
std::cmp::Ordering::Equal => {
return exec_err!("field position must not be zero");
}
};
builder.append_value(result.unwrap_or(""));
let nulls = NullBuffer::union_many([
string_array.nulls(),
delimiter_array.nulls(),
n_array.nulls(),
]);

if let Some(ref n) = nulls {
for i in 0..string_array.len() {
if n.is_null(i) {
builder.append_placeholder();
continue;
}
_ => builder.append_null(),

// SAFETY: the union null buffer is valid at `i`, so each input is valid.
let string = unsafe { string_array.value_unchecked(i) };
let delimiter = unsafe { delimiter_array.value_unchecked(i) };
let position = unsafe { n_array.value_unchecked(i) };
append_split_part(string, delimiter, position, &mut builder)?;
}
} else {
for i in 0..string_array.len() {
// SAFETY: no input has a null buffer, so every index is valid.
let string = unsafe { string_array.value_unchecked(i) };
let delimiter = unsafe { delimiter_array.value_unchecked(i) };
let position = unsafe { n_array.value_unchecked(i) };
append_split_part(string, delimiter, position, &mut builder)?;
}
}

Ok(Arc::new(builder.finish()) as ArrayRef)
builder.finish(nulls)
}

/// Evaluates `split_part` for a single row and appends the outcome to `builder`.
///
/// `n` is the 1-based field position. A positive `n` is resolved through
/// `split_nth` with the 0-based index `n - 1`; a negative `n` is resolved
/// through `rsplit_nth` with the 0-based index `|n| - 1`. An empty
/// `delimiter` follows PostgreSQL: the input is treated as a single field, so
/// only positions `1` and `-1` select `string`. When no field is selected,
/// the empty string is appended.
///
/// # Errors
///
/// Returns an execution error when `n` is zero, or when the derived index
/// cannot be represented as `usize`.
#[inline]
fn append_split_part<B: BulkNullStringArrayBuilder>(
    string: &str,
    delimiter: &str,
    n: i64,
    builder: &mut B,
) -> Result<()> {
    // Position zero is rejected before any other work.
    if n == 0 {
        return exec_err!("field position must not be zero");
    }

    let single_field = delimiter.is_empty();

    let field = if n > 0 {
        // The index conversion runs (and may fail) before the empty-delimiter
        // shortcut.
        let idx: usize = (n - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {n} exceeds maximum supported value"
            )
        })?;
        if single_field {
            // Empty delimiter: only position 1 returns data.
            if n == 1 { Some(string) } else { None }
        } else {
            split_nth(string, delimiter, idx)
        }
    } else {
        // `unsigned_abs` avoids overflow for `i64::MIN`.
        let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| {
            exec_datafusion_err!(
                "split_part index {n} exceeds minimum supported value"
            )
        })?;
        if single_field {
            // Empty delimiter: only position -1 returns data.
            if n == -1 { Some(string) } else { None }
        } else {
            rsplit_nth(string, delimiter, idx)
        }
    };

    builder.append_value(field.unwrap_or(""));
    Ok(())
}

#[cfg(test)]
Expand Down
Loading