From 61a7decf6fe44f00fcd96a27c9ae6aeb684c256d Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Mon, 11 May 2026 22:07:44 +0800 Subject: [PATCH 1/3] implement Spark substring ext function support binary type --- .../auron/utils/AuronSparkTestSettings.scala | 2 - .../datafusion-ext-functions/src/lib.rs | 1 + .../src/spark_strings.rs | 126 +++++++++++++++++- .../spark/sql/auron/NativeConverters.scala | 6 +- 4 files changed, 124 insertions(+), 11 deletions(-) diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala index 55672fdd4..5e96aa646 100644 --- a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala @@ -25,8 +25,6 @@ class AuronSparkTestSettings extends SparkTestSettings { } enableSuite[AuronStringFunctionsSuite] - // See https://github.com/apache/auron/issues/1724 - .exclude("string / binary substring function") enableSuite[AuronDataFrameAggregateSuite] // See https://github.com/apache/auron/issues/1840 diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index 7eb1f63a6..3916d63b3 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -75,6 +75,7 @@ pub fn create_auron_ext_function( "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws), "Spark_StringLower" => Arc::new(spark_strings::string_lower), "Spark_StringUpper" => Arc::new(spark_strings::string_upper), + "Spark_Substring" => Arc::new(spark_strings::spark_substring), "Spark_InitCap" => Arc::new(spark_initcap::string_initcap), "Spark_Year" => Arc::new(spark_dates::spark_year), "Spark_Month" => Arc::new(spark_dates::spark_month), diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 43b1f136f..2fd258591 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -16,13 +16,15 @@ use std::sync::Arc; use arrow::{ - array::{Array, ArrayRef, AsArray, ListArray, ListBuilder, StringArray, StringBuilder}, + array::{ + Array, ArrayRef, AsArray, BinaryArray, ListArray, ListBuilder, StringArray, StringBuilder, + }, datatypes::DataType, }; use datafusion::{ common::{ Result, ScalarValue, - cast::{as_int32_array, as_list_array, as_string_array}, + cast::{as_binary_array, as_int32_array, as_list_array, as_string_array}, }, physical_plan::ColumnarValue, }; @@ -114,6 +116,80 @@ pub fn string_split(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(splitted_builder.finish()))) } +pub fn spark_substring(args: &[ColumnarValue]) -> Result { + let pos = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => *value, + _ => df_execution_err!("substring pos only supports literal int64")?, + }; + let len = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => *value, + _ => df_execution_err!("substring len only supports literal int64")?, + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let result: StringArray = as_string_array(array)? + .iter() + .map(|value| { + value.map(|value| { + let chars = value.chars().collect::>(); + let (start, end) = substring_range(chars.len(), pos, len); + chars[start..end].iter().collect::() + }) + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Binary => { + let result: BinaryArray = as_binary_array(array)? + .iter() + .map(|value| { + value.map(|value| { + let (start, end) = substring_range(value.len(), pos, len); + &value[start..end] + }) + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(result))) + } + other => df_execution_err!("substring only supports utf8 or binary, got {other:?}"), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(value)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(value.as_ref().map(|value| { + let chars = value.chars().collect::>(); + let (start, end) = substring_range(chars.len(), pos, len); + chars[start..end].iter().collect::() + })), + )), + ColumnarValue::Scalar(ScalarValue::Binary(value)) => Ok(ColumnarValue::Scalar( + ScalarValue::Binary(value.as_ref().map(|value| { + let (start, end) = substring_range(value.len(), pos, len); + value[start..end].to_vec() + })), + )), + other => df_execution_err!("substring only supports utf8 or binary, got {:?}", other), + } +} + +fn substring_range(total_len: usize, pos: i64, len: i64) -> (usize, usize) { + if len <= 0 { + return (0, 0); + } + + let total_len_i64 = total_len as i64; + let start = if pos > 0 { + pos - 1 + } else if pos < 0 { + total_len_i64 + pos + } else { + 0 + } + .clamp(0, total_len_i64) as usize; + let end = (start as i64 + len).clamp(0, total_len_i64) as usize; + (start, end) +} + /// concat() function compatible with spark (returns null if any param is null) /// concat('abcde', 2, 22) = 'abcde222 /// concat('abcde', 2, NULL, 22) = NULL @@ -322,19 +398,19 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { mod test { use std::sync::Arc; - use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder}; + use arrow::array::{BinaryArray, Int32Array, ListBuilder, StringArray, StringBuilder}; use datafusion::{ common::{ Result, ScalarValue, - cast::{as_list_array, as_string_array}, + cast::{as_binary_array, as_list_array, as_string_array}, }, physical_plan::ColumnarValue, }; use datafusion_ext_commons::df_execution_err; use crate::spark_strings::{ - string_concat, string_concat_ws, string_lower, string_repeat, string_space, string_split, - string_upper, + spark_substring, string_concat, string_concat_ws, string_lower, string_repeat, + string_space, string_split, string_upper, }; #[test] @@ -395,6 +471,44 @@ mod test { } } + #[test] + fn test_spark_substring_string_array() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("bcd"), Some("据fu"), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_binary_array() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[9_u8, 8][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + vec![Some(&[2_u8, 3, 4][..]), Some(&[8_u8][..]), None] + ); + Ok(()) + } + #[test] fn test_string_repeat() -> Result<()> { // positive case diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 8551f0b0b..87fa35695 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -1024,10 +1024,10 @@ object NativeConverters extends Logging { if pos.asInstanceOf[Int] > 0 && len.asInstanceOf[Int] >= 0 => val longPos = pos.asInstanceOf[Int].toLong val longLen = len.asInstanceOf[Int].toLong - buildScalarFunction( - pb.ScalarFunction.Substr, + buildExtScalarFunction( + "Spark_Substring", str :: Literal(longPos) :: Literal(longLen) :: Nil, - StringType) + str.dataType) case StringSpace(n) => buildExtScalarFunction("Spark_StringSpace", n :: Nil, StringType) From df3a98030aba7857b02dcf57986485af43ad868d Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Tue, 12 May 2026 22:01:41 +0800 Subject: [PATCH 2/3] apply suggestions --- .../auron/utils/AuronSparkTestSettings.scala | 1 + .../src/spark_strings.rs | 200 +++++++++++++++++- 2 files changed, 196 insertions(+), 5 deletions(-) diff --git a/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala index 052cca5d1..01e67e004 100644 --- a/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala +++ b/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala @@ -23,3 +23,4 @@ class AuronSparkTestSettings extends SparkTestSettings { override def getOverwriteSQLQueryTests: Set[String] = Set.empty } } + diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 2fd258591..bfd715066 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -178,15 +178,15 @@ fn substring_range(total_len: usize, pos: i64, len: i64) -> (usize, usize) { } let total_len_i64 = total_len as i64; - let start = if pos > 0 { + let raw_start = if pos > 0 { pos - 1 } else if pos < 0 { - total_len_i64 + pos + total_len_i64.saturating_add(pos) } else { 0 - } - .clamp(0, total_len_i64) as usize; - let end = (start as i64 + len).clamp(0, total_len_i64) as usize; + }; + let start = raw_start.clamp(0, total_len_i64) as usize; + let end = (start as i64).saturating_add(len).clamp(0, total_len_i64) as usize; (start, end) } @@ -490,6 +490,126 @@ mod test { Ok(()) } + #[test] + fn test_spark_substring_string_array_with_zero_and_negative_pos() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(0_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("abc"), Some("数据f"), None] + ); + + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(-3_i64)), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("de"), Some("io"), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_string_array_with_non_positive_len() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(0_i64)), + ])?; + let s = r.into_array(2)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some(""), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_string_array_with_clamped_ranges() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(10_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some(""), Some(""), None] + ); + + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(100_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("bcdef"), Some(""), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_string_array_with_extreme_args() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("abc"), Some(""), None] + ); + + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("bcdef"), Some(""), None] + ); + Ok(()) + } + #[test] fn test_spark_substring_binary_array() -> Result<()> { let r = spark_substring(&vec![ @@ -509,6 +629,76 @@ mod test { Ok(()) } + #[test] + fn test_spark_substring_binary_array_with_clamped_ranges() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(10_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + vec![Some(&[][..]), Some(&[][..]), None] + ); + + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + vec![Some(&[2_u8, 3, 4, 5][..]), Some(&[][..]), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_scalar_with_edge_cases() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcdef".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => assert_eq!(value, "ab"), + other => df_execution_err!("Expected scalar Utf8 substring, got: {:?}", other)?, + } + + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => assert_eq!(value, ""), + other => df_execution_err!("Expected empty scalar Utf8 substring, got: {:?}", other)?, + } + + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![1_u8, 2, 3, 4, 5]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Binary(Some(value))) => { + assert_eq!(value, vec![2_u8, 3, 4, 5]) + } + other => df_execution_err!("Expected scalar Binary substring, got: {:?}", other)?, + } + Ok(()) + } + #[test] fn test_string_repeat() -> Result<()> { // positive case From e7e5f8c82af3f1a996a0f8b799a4d9505263c941 Mon Sep 17 00:00:00 2001 From: lyne7-sc <734432041@qq.com> Date: Tue, 12 May 2026 22:14:06 +0800 Subject: [PATCH 3/3] clean code --- .../src/spark_strings.rs | 278 +++++++----------- 1 file changed, 107 insertions(+), 171 deletions(-) diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index bfd715066..b805346d0 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -490,126 +490,6 @@ mod test { Ok(()) } - #[test] - fn test_spark_substring_string_array_with_zero_and_negative_pos() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("数据fusion".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(0_i64)), - ColumnarValue::Scalar(ScalarValue::from(3_i64)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some("abc"), Some("数据f"), None] - ); - - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("数据fusion".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(-3_i64)), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some("de"), Some("io"), None] - ); - Ok(()) - } - - #[test] - fn test_spark_substring_string_array_with_non_positive_len() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ColumnarValue::Scalar(ScalarValue::from(0_i64)), - ])?; - let s = r.into_array(2)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some(""), None] - ); - Ok(()) - } - - #[test] - fn test_spark_substring_string_array_with_clamped_ranges() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(10_i64)), - ColumnarValue::Scalar(ScalarValue::from(3_i64)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some(""), Some(""), None] - ); - - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ColumnarValue::Scalar(ScalarValue::from(100_i64)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some("bcdef"), Some(""), None] - ); - Ok(()) - } - - #[test] - fn test_spark_substring_string_array_with_extreme_args() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), - ColumnarValue::Scalar(ScalarValue::from(3_i64)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some("abc"), Some(""), None] - ); - - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ - Some("abcdef".to_string()), - Some("".to_string()), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), - ])?; - let s = r.into_array(3)?; - assert_eq!( - as_string_array(&s)?.into_iter().collect::>(), - vec![Some("bcdef"), Some(""), None] - ); - Ok(()) - } - #[test] fn test_spark_substring_binary_array() -> Result<()> { let r = spark_substring(&vec![ @@ -630,59 +510,115 @@ mod test { } #[test] - fn test_spark_substring_binary_array_with_clamped_ranges() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ - Some(&[1_u8, 2, 3, 4, 5][..]), - Some(&[][..]), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(10_i64)), - ColumnarValue::Scalar(ScalarValue::from(3_i64)), - ])?; - let b = r.into_array(3)?; - assert_eq!( - as_binary_array(&b)?.iter().collect::>(), - vec![Some(&[][..]), Some(&[][..]), None] - ); - - let r = spark_substring(&vec![ - ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ - Some(&[1_u8, 2, 3, 4, 5][..]), - Some(&[][..]), - None, - ]))), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), - ])?; - let b = r.into_array(3)?; - assert_eq!( - as_binary_array(&b)?.iter().collect::>(), - vec![Some(&[2_u8, 3, 4, 5][..]), Some(&[][..]), None] - ); - Ok(()) - } + fn test_spark_substring_edge_cases() -> Result<()> { + for (name, pos, len, expected) in [ + ( + "zero pos", + 0_i64, + 3_i64, + vec![Some("abc"), Some("数据f"), None], + ), + ( + "negative pos", + -3_i64, + 2_i64, + vec![Some("de"), Some("io"), None], + ), + ("zero len", 2_i64, 0_i64, vec![Some(""), Some(""), None]), + ( + "pos past end", + 10_i64, + 3_i64, + vec![Some(""), Some(""), None], + ), + ( + "len past end", + 2_i64, + 100_i64, + vec![Some("bcdef"), Some("据fusion"), None], + ), + ( + "i64 min pos", + i64::MIN, + 3_i64, + vec![Some("abc"), Some("数据f"), None], + ), + ( + "i64 max len", + 2_i64, + i64::MAX, + vec![Some("bcdef"), Some("据fusion"), None], + ), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + expected, + "string array case: {name}" + ); + } - #[test] - fn test_spark_substring_scalar_with_edge_cases() -> Result<()> { - let r = spark_substring(&vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcdef".to_string()))), - ColumnarValue::Scalar(ScalarValue::from(i64::MIN)), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ])?; - match r { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => assert_eq!(value, "ab"), - other => df_execution_err!("Expected scalar Utf8 substring, got: {:?}", other)?, + for (name, pos, len, expected) in [ + ( + "zero len", + 2_i64, + 0_i64, + vec![Some(&[][..]), Some(&[][..]), None], + ), + ( + "pos past end", + 10_i64, + 3_i64, + vec![Some(&[][..]), Some(&[][..]), None], + ), + ( + "i64 max len", + 2_i64, + i64::MAX, + vec![Some(&[2_u8, 3, 4, 5][..]), Some(&[][..]), None], + ), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + expected, + "binary array case: {name}" + ); } - let r = spark_substring(&vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("".to_string()))), - ColumnarValue::Scalar(ScalarValue::from(2_i64)), - ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), - ])?; - match r { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => assert_eq!(value, ""), - other => df_execution_err!("Expected empty scalar Utf8 substring, got: {:?}", other)?, + for (name, input, pos, len, expected) in [ + ("i64 min pos", "abcdef", i64::MIN, 2_i64, "ab"), + ("empty string", "", 2_i64, i64::MAX, ""), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(input.to_string()))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { + assert_eq!(value, expected, "scalar string case: {name}") + } + other => df_execution_err!("Expected scalar Utf8 substring, got: {:?}", other)?, + } } let r = spark_substring(&vec![ @@ -692,7 +628,7 @@ mod test { ])?; match r { ColumnarValue::Scalar(ScalarValue::Binary(Some(value))) => { - assert_eq!(value, vec![2_u8, 3, 4, 5]) + assert_eq!(value, vec![2_u8, 3, 4, 5], "scalar binary case") } other => df_execution_err!("Expected scalar Binary substring, got: {:?}", other)?, }