diff --git a/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala index 052cca5d1..01e67e004 100644 --- a/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala +++ b/auron-spark-tests/spark31/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala @@ -23,3 +23,4 @@ class AuronSparkTestSettings extends SparkTestSettings { override def getOverwriteSQLQueryTests: Set[String] = Set.empty } } + diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala index 55672fdd4..5e96aa646 100644 --- a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala @@ -25,8 +25,6 @@ class AuronSparkTestSettings extends SparkTestSettings { } enableSuite[AuronStringFunctionsSuite] - // See https://github.com/apache/auron/issues/1724 - .exclude("string / binary substring function") enableSuite[AuronDataFrameAggregateSuite] // See https://github.com/apache/auron/issues/1840 diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index 7eb1f63a6..3916d63b3 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -75,6 +75,7 @@ pub fn create_auron_ext_function( "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws), "Spark_StringLower" => Arc::new(spark_strings::string_lower), "Spark_StringUpper" => Arc::new(spark_strings::string_upper), + "Spark_Substring" => Arc::new(spark_strings::spark_substring), "Spark_InitCap" => Arc::new(spark_initcap::string_initcap), "Spark_Year" => Arc::new(spark_dates::spark_year), "Spark_Month" => Arc::new(spark_dates::spark_month), diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 43b1f136f..b805346d0 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs @@ -16,13 +16,15 @@ use std::sync::Arc; use arrow::{ - array::{Array, ArrayRef, AsArray, ListArray, ListBuilder, StringArray, StringBuilder}, + array::{ + Array, ArrayRef, AsArray, BinaryArray, ListArray, ListBuilder, StringArray, StringBuilder, + }, datatypes::DataType, }; use datafusion::{ common::{ Result, ScalarValue, - cast::{as_int32_array, as_list_array, as_string_array}, + cast::{as_binary_array, as_int32_array, as_list_array, as_string_array}, }, physical_plan::ColumnarValue, }; @@ -114,6 +116,80 @@ pub fn string_split(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(Arc::new(splitted_builder.finish()))) } +pub fn spark_substring(args: &[ColumnarValue]) -> Result { + let pos = match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => *value, + _ => df_execution_err!("substring pos only supports literal int64")?, + }; + let len = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => *value, + _ => df_execution_err!("substring len only supports literal int64")?, + }; + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let result: StringArray = as_string_array(array)? + .iter() + .map(|value| { + value.map(|value| { + let chars = value.chars().collect::>(); + let (start, end) = substring_range(chars.len(), pos, len); + chars[start..end].iter().collect::() + }) + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(result))) + } + DataType::Binary => { + let result: BinaryArray = as_binary_array(array)? + .iter() + .map(|value| { + value.map(|value| { + let (start, end) = substring_range(value.len(), pos, len); + &value[start..end] + }) + }) + .collect(); + Ok(ColumnarValue::Array(Arc::new(result))) + } + other => df_execution_err!("substring only supports utf8 or binary, got {other:?}"), + }, + ColumnarValue::Scalar(ScalarValue::Utf8(value)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(value.as_ref().map(|value| { + let chars = value.chars().collect::>(); + let (start, end) = substring_range(chars.len(), pos, len); + chars[start..end].iter().collect::() + })), + )), + ColumnarValue::Scalar(ScalarValue::Binary(value)) => Ok(ColumnarValue::Scalar( + ScalarValue::Binary(value.as_ref().map(|value| { + let (start, end) = substring_range(value.len(), pos, len); + value[start..end].to_vec() + })), + )), + other => df_execution_err!("substring only supports utf8 or binary, got {:?}", other), + } +} + +fn substring_range(total_len: usize, pos: i64, len: i64) -> (usize, usize) { + if len <= 0 { + return (0, 0); + } + + let total_len_i64 = total_len as i64; + let raw_start = if pos > 0 { + pos - 1 + } else if pos < 0 { + total_len_i64.saturating_add(pos) + } else { + 0 + }; + let start = raw_start.clamp(0, total_len_i64) as usize; + let end = (start as i64).saturating_add(len).clamp(0, total_len_i64) as usize; + (start, end) +} + /// concat() function compatible with spark (returns null if any param is null) /// concat('abcde', 2, 22) = 'abcde222 /// concat('abcde', 2, NULL, 22) = NULL @@ -322,19 +398,19 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result { mod test { use std::sync::Arc; - use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder}; + use arrow::array::{BinaryArray, Int32Array, ListBuilder, StringArray, StringBuilder}; use datafusion::{ common::{ Result, ScalarValue, - cast::{as_list_array, as_string_array}, + cast::{as_binary_array, as_list_array, as_string_array}, }, physical_plan::ColumnarValue, }; use datafusion_ext_commons::df_execution_err; use crate::spark_strings::{ - string_concat, string_concat_ws, string_lower, string_repeat, string_space, string_split, - string_upper, + spark_substring, string_concat, string_concat_ws, string_lower, string_repeat, + string_space, string_split, string_upper, }; #[test] @@ -395,6 +471,170 @@ mod test { } } + #[test] + fn test_spark_substring_string_array() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + vec![Some("bcd"), Some("据fu"), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_binary_array() -> Result<()> { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[9_u8, 8][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(3_i64)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + vec![Some(&[2_u8, 3, 4][..]), Some(&[8_u8][..]), None] + ); + Ok(()) + } + + #[test] + fn test_spark_substring_edge_cases() -> Result<()> { + for (name, pos, len, expected) in [ + ( + "zero pos", + 0_i64, + 3_i64, + vec![Some("abc"), Some("数据f"), None], + ), + ( + "negative pos", + -3_i64, + 2_i64, + vec![Some("de"), Some("io"), None], + ), + ("zero len", 2_i64, 0_i64, vec![Some(""), Some(""), None]), + ( + "pos past end", + 10_i64, + 3_i64, + vec![Some(""), Some(""), None], + ), + ( + "len past end", + 2_i64, + 100_i64, + vec![Some("bcdef"), Some("据fusion"), None], + ), + ( + "i64 min pos", + i64::MIN, + 3_i64, + vec![Some("abc"), Some("数据f"), None], + ), + ( + "i64 max len", + 2_i64, + i64::MAX, + vec![Some("bcdef"), Some("据fusion"), None], + ), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![ + Some("abcdef".to_string()), + Some("数据fusion".to_string()), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + let s = r.into_array(3)?; + assert_eq!( + as_string_array(&s)?.into_iter().collect::>(), + expected, + "string array case: {name}" + ); + } + + for (name, pos, len, expected) in [ + ( + "zero len", + 2_i64, + 0_i64, + vec![Some(&[][..]), Some(&[][..]), None], + ), + ( + "pos past end", + 10_i64, + 3_i64, + vec![Some(&[][..]), Some(&[][..]), None], + ), + ( + "i64 max len", + 2_i64, + i64::MAX, + vec![Some(&[2_u8, 3, 4, 5][..]), Some(&[][..]), None], + ), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Array(Arc::new(BinaryArray::from_iter(vec![ + Some(&[1_u8, 2, 3, 4, 5][..]), + Some(&[][..]), + None, + ]))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + let b = r.into_array(3)?; + assert_eq!( + as_binary_array(&b)?.iter().collect::>(), + expected, + "binary array case: {name}" + ); + } + + for (name, input, pos, len, expected) in [ + ("i64 min pos", "abcdef", i64::MIN, 2_i64, "ab"), + ("empty string", "", 2_i64, i64::MAX, ""), + ] { + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(input.to_string()))), + ColumnarValue::Scalar(ScalarValue::from(pos)), + ColumnarValue::Scalar(ScalarValue::from(len)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { + assert_eq!(value, expected, "scalar string case: {name}") + } + other => df_execution_err!("Expected scalar Utf8 substring, got: {:?}", other)?, + } + } + + let r = spark_substring(&vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![1_u8, 2, 3, 4, 5]))), + ColumnarValue::Scalar(ScalarValue::from(2_i64)), + ColumnarValue::Scalar(ScalarValue::from(i64::MAX)), + ])?; + match r { + ColumnarValue::Scalar(ScalarValue::Binary(Some(value))) => { + assert_eq!(value, vec![2_u8, 3, 4, 5], "scalar binary case") + } + other => df_execution_err!("Expected scalar Binary substring, got: {:?}", other)?, + } + Ok(()) + } + #[test] fn test_string_repeat() -> Result<()> { // positive case diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index dbe1781ee..be2f1a640 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -1025,10 +1025,10 @@ object NativeConverters extends Logging { if pos.asInstanceOf[Int] > 0 && len.asInstanceOf[Int] >= 0 => val longPos = pos.asInstanceOf[Int].toLong val longLen = len.asInstanceOf[Int].toLong - buildScalarFunction( - pb.ScalarFunction.Substr, + buildExtScalarFunction( + "Spark_Substring", str :: Literal(longPos) :: Literal(longLen) :: Nil, - StringType) + str.dataType) case StringSpace(n) => buildExtScalarFunction("Spark_StringSpace", n :: Nil, StringType)