diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index c92d434e34abe..6ed6e78852c14 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -75,3 +75,20 @@ pub fn is_datetime(dt: &DataType) -> bool { DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _) ) } + +pub fn is_binary(dt: &DataType) -> bool { + matches!( + dt, + DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + | DataType::BinaryView + ) +} + +pub fn is_string(dt: &DataType) -> bool { + matches!( + dt, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} diff --git a/datafusion/functions/src/binaries.rs b/datafusion/functions/src/binaries.rs new file mode 100644 index 0000000000000..6d325ef31bdc3 --- /dev/null +++ b/datafusion/functions/src/binaries.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::strings::{ColumnarValueRef, ConcatBuilder}; +use arrow::array::{ + Array, ArrayDataBuilder, ArrayRef, BinaryViewArray, GenericBinaryArray, + OffsetSizeTrait, make_view, +}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, NullBuffer, ScalarBuffer}; +use datafusion_common::{Result, exec_datafusion_err, exec_err, internal_err}; +use std::marker::PhantomData; +use std::sync::Arc; + +pub(crate) struct ConcatGenericBinaryBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, + _phantom: PhantomData, +} +pub(crate) type ConcatBinaryBuilder = ConcatGenericBinaryBuilder; +pub(crate) type ConcatLargeBinaryBuilder = ConcatGenericBinaryBuilder; + +impl ConcatGenericBinaryBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let capacity = item_capacity + .checked_add(1) + .map(|i| i.saturating_mul(size_of::())) + .expect("capacity integer overflow"); + + let mut offsets_buffer = MutableBuffer::with_capacity(capacity); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(O::usize_as(0)) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + _phantom: PhantomData, + } + } +} + +impl ConcatBuilder + for ConcatGenericBinaryBuilder +{ + fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) -> Result<()> { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NullableLargeBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NullableBinaryViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NullableFixedSizeBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NonNullableBinaryArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NonNullableLargeBinaryArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NonNullableBinaryViewArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NonNullableFixedSizeBinaryArray(array) => { + self.value_buffer.extend_from_slice(array.value(i)); + } + _ => { + return exec_err!( + "concat: unexpected column type for binary builder: {column:?}" + ); + } + } + Ok(()) + } + + fn append_offset(&mut self) -> Result<()> { + let next_offset: O = O::from_usize(self.value_buffer.len()) + .ok_or_else(|| exec_datafusion_err!("byte array offset overflow"))?; + self.offsets_buffer.push(next_offset); + Ok(()) + } + + /// Finalize the builder into a concrete [`GenericBinaryArray`]. + /// + /// # Errors + /// + /// Returns an error when: + /// + /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. + fn finish(self, null_buffer: Option) -> Result { + let row_count = self.offsets_buffer.len() / size_of::() - 1; + if let Some(ref null_buffer) = null_buffer + && null_buffer.len() != row_count + { + return internal_err!( + "Null buffer and offsets buffer must be the same length" + ); + } + let array_builder = ArrayDataBuilder::new(GenericBinaryArray::::DATA_TYPE) + .len(row_count) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + let array = GenericBinaryArray::::from(array_data); + Ok(Arc::new(array)) + } +} + +/// Builder used by `concat`/`concat_ws` to assemble a [`BinaryViewArray`] one +/// row at a time from multiple input columns. +/// +/// Each row is written via repeated `write` calls (one per input +/// fragment) followed by a single `append_offset` to commit the row +/// as a single binary view. The output null buffer is supplied by the caller +/// at `finish` time, avoiding per-row NULL handling work. +/// +pub(crate) struct ConcatBinaryViewBuilder { + views: Vec, + data: Vec, + block: Vec, +} + +impl ConcatBinaryViewBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + Self { + views: Vec::with_capacity(item_capacity), + data: Vec::with_capacity(data_capacity), + block: vec![], + } + } +} + +impl ConcatBuilder for ConcatBinaryViewBuilder { + fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) -> Result<()> { + match column { + ColumnarValueRef::Scalar(s) => { + self.block.extend_from_slice(s); + } + ColumnarValueRef::NullableBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NullableLargeBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NullableBinaryViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NonNullableBinaryArray(array) => { + self.block.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NonNullableLargeBinaryArray(array) => { + self.block.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NonNullableBinaryViewArray(array) => { + self.block.extend_from_slice(array.value(i)); + } + ColumnarValueRef::NullableFixedSizeBinaryArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.extend_from_slice(array.value(i)); + } + } + ColumnarValueRef::NonNullableFixedSizeBinaryArray(array) => { + self.block.extend_from_slice(array.value(i)); + } + _ => { + return exec_err!( + "concat: unexpected column type for binary view builder: {column:?}" + ); + } + } + Ok(()) + } + + /// Finalizes the current row by converting the accumulated data into a + /// StringView and appending it to the views buffer. + fn append_offset(&mut self) -> Result<()> { + let v = &self.block; + if v.len() > 12 { + let offset: u32 = self + .data + .len() + .try_into() + .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; + self.data.extend_from_slice(v); + self.views.push(make_view(v, 0, offset)); + } else { + self.views.push(make_view(v, 0, 0)); + } + + self.block.clear(); + Ok(()) + } + + /// Finalize the builder into a concrete [`BinaryViewArray`]. + /// + /// # Errors + /// + /// Returns an error when: + /// + /// - the provided `null_buffer` length does not match the row count. + fn finish(self, null_buffer: Option) -> Result { + if let Some(ref nulls) = null_buffer + && nulls.len() != self.views.len() + { + return internal_err!( + "Null buffer length ({}) must match row count ({})", + nulls.len(), + self.views.len() + ); + } + + let buffers: Vec = if self.data.is_empty() { + vec![] + } else { + vec![Buffer::from(self.data)] + }; + + // SAFETY: views were constructed with correct lengths, offsets, and + // prefixes. + let array = unsafe { + BinaryViewArray::new_unchecked( + ScalarBuffer::from(self.views), + buffers, + null_buffer, + ) + }; + Ok(Arc::new(array)) + } +} diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 7e753d7f35eb3..14d1743770883 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -141,6 +141,7 @@ make_stub_package!(unicode, "unicode_expressions"); #[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] pub mod planner; +pub mod binaries; pub mod strings; pub mod utils; diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index b10db23472c99..83beeacfd74c8 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,22 +15,23 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, as_largestring_array}; -use arrow::datatypes::DataType; -use datafusion_expr::sort_properties::ExprProperties; -use std::sync::Arc; - +use crate::binaries::{ + ConcatBinaryBuilder, ConcatBinaryViewBuilder, ConcatLargeBinaryBuilder, +}; use crate::string::concat; use crate::strings::{ - ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder, - ConcatStringViewBuilder, + ColumnarValueRef, ConcatBuilder, ConcatLargeStringBuilder, ConcatStringBuilder, + ConcatStringViewBuilder, widest_binary_type, widest_string_type, }; -use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array}; +use arrow::array::Array; +use arrow::datatypes::DataType; use datafusion_common::{ Result, ScalarValue, exec_datafusion_err, internal_err, plan_err, }; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::type_coercion::{is_binary, is_string}; use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -67,27 +68,19 @@ impl Default for ConcatFunc { impl ConcatFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8, Binary], - Volatility::Immutable, - ), + // Use `Signature::UserDefined` to allow different argument types. + // `Variadic` requires every argument to be coerced to the same string type, + // so the UDF cannot distinguish between binary and string inputs. + signature: Signature::user_defined(Volatility::Immutable), } } } -fn deduce_return_type(arg_types: &[DataType]) -> DataType { - use DataType::*; - if arg_types.contains(&Utf8View) { - Utf8View - } else if arg_types.contains(&LargeUtf8) { - LargeUtf8 - } else { - Utf8 - } -} - +// Logic is matched with pipe operator in the following table. +// Support only string + string concatenation, +// or binary + binary concatenation. +// Mixed string + binary concatenation is rejected, impl ScalarUDFImpl for ConcatFunc { fn name(&self) -> &str { "concat" @@ -97,9 +90,18 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } - /// Match the return type to the input types to avoid unnecessary casts. On + /// Coerce all arguments to the widest type within the binary / string family + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.is_empty() { + plan_err!("concat does not support zero arguments") + } else { + coerce_arg_types(arg_types) + } + } + /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid /// potential overflow on LargeUtf8 input. + /// For binaries, use the similar hierarchy fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(deduce_return_type(arg_types)) } @@ -112,6 +114,18 @@ impl ScalarUDFImpl for ConcatFunc { let arg_types: Vec = args.iter().map(|c| c.data_type()).collect(); let return_datatype = deduce_return_type(&arg_types); + let with_binary = arg_types.iter().any(is_binary); + let with_string = arg_types.iter().any(|t| { + matches!(t, DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8) + }); + + if with_binary && with_string { + return plan_err!( + "{} does not support mixed string and binary inputs", + &self.name() + ); + } + let array_len = args.iter().find_map(|x| match x { ColumnarValue::Array(array) => Some(array.len()), _ => None, @@ -126,7 +140,16 @@ impl ScalarUDFImpl for ConcatFunc { }; if let ScalarValue::Binary(Some(value)) = scalar { values.push(value); + } else if let ScalarValue::LargeBinary(Some(value)) = scalar { + values.push(value); + } else if let ScalarValue::BinaryView(Some(value)) = scalar { + values.push(value); + } else if let ScalarValue::FixedSizeBinary(_, Some(value)) = scalar { + values.push(value); + } else if scalar.is_null() { + // null binary scalar: skip (consistent with null string behaviour) } else { + // String case match scalar.try_as_str() { Some(Some(v)) => values.push(v.as_bytes()), Some(None) => {} // null literal @@ -138,20 +161,42 @@ impl ScalarUDFImpl for ConcatFunc { } } let concat_bytes = values.concat(); - let result = std::str::from_utf8(&concat_bytes) - .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))? - .to_string(); return match return_datatype { DataType::Utf8View => { + let result = std::str::from_utf8(&concat_bytes) + .map_err(|_| { + exec_datafusion_err!("invalid UTF-8 in binary literal") + })? + .to_string(); Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) } DataType::Utf8 => { + let result = std::str::from_utf8(&concat_bytes) + .map_err(|_| { + exec_datafusion_err!("invalid UTF-8 in binary literal") + })? + .to_string(); Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) } DataType::LargeUtf8 => { + let result = std::str::from_utf8(&concat_bytes) + .map_err(|_| { + exec_datafusion_err!("invalid UTF-8 in binary literal") + })? + .to_string(); Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) } + DataType::Binary => Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some( + concat_bytes, + )))), + // Serves LargeBinary and FixedSizeBinary inputs + DataType::LargeBinary => Ok(ColumnarValue::Scalar( + ScalarValue::LargeBinary(Some(concat_bytes)), + )), + DataType::BinaryView => Ok(ColumnarValue::Scalar( + ScalarValue::BinaryView(Some(concat_bytes)), + )), other => { plan_err!("Concat function does not support datatype of {other}") } @@ -164,121 +209,46 @@ impl ScalarUDFImpl for ConcatFunc { let mut columns = Vec::with_capacity(args.len()); for arg in &args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); - } - } - ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => { - if let Some(b) = maybe_value { - // data_size is a capacity hint, so doesn't matter if it is chars or bytes - data_size += b.len() * len; - columns.push(ColumnarValueRef::Scalar(b.as_slice())); - } - } - ColumnarValue::Array(array) => { - match array.data_type() { - DataType::Utf8 => { - let string_array = as_string_array(array)?; - - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - }; - columns.push(column); - } - DataType::LargeUtf8 => { - let string_array = as_largestring_array(array); - - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableLargeStringArray(string_array) - } else { - ColumnarValueRef::NonNullableLargeStringArray( - string_array, - ) - }; - columns.push(column); - } - DataType::Utf8View => { - let string_array = as_string_view_array(array)?; - - // This is an estimate; in particular, it will - // undercount arrays of short strings (<= 12 bytes). - data_size += string_array.total_buffer_bytes_used(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableStringViewArray(string_array) - } else { - ColumnarValueRef::NonNullableStringViewArray(string_array) - }; - columns.push(column); - } - DataType::Binary => { - let string_array = as_binary_array(array)?; - - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableBinaryArray(string_array) - } else { - ColumnarValueRef::NonNullableBinaryArray(string_array) - }; - columns.push(column); - } - other => { - return plan_err!( - "Input was {other} which is not a supported datatype for concat function" - ); - } - }; - } - _ => unreachable!("concat"), + if let Some(column) = + ColumnarValueRef::from_columnar_value(arg, &mut data_size, len, 1, false)? + { + columns.push(column); } } match return_datatype { - DataType::Utf8 => { - let mut builder = ConcatStringBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset()?; - } - - let string_array = builder.finish(None)?; - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::Utf8View => { - let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset()?; - } - - let string_array = builder.finish(None)?; - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - DataType::LargeUtf8 => { - let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset()?; - } - - let string_array = builder.finish(None)?; - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - _ => unreachable!(), + DataType::Utf8 => build_concat( + ConcatStringBuilder::with_capacity(len, data_size), + &columns, + len, + ), + DataType::Utf8View => build_concat( + ConcatStringViewBuilder::with_capacity(len, data_size), + &columns, + len, + ), + DataType::LargeUtf8 => build_concat( + ConcatLargeStringBuilder::with_capacity(len, data_size), + &columns, + len, + ), + DataType::Binary => build_concat( + ConcatBinaryBuilder::with_capacity(len, data_size), + &columns, + len, + ), + // Serves LargeBinary and FixedSizeBinary inputs + DataType::LargeBinary => build_concat( + ConcatLargeBinaryBuilder::with_capacity(len, data_size), + &columns, + len, + ), + DataType::BinaryView => build_concat( + ConcatBinaryViewBuilder::with_capacity(len, data_size), + &columns, + len, + ), + _ => unreachable!("concat"), } } @@ -307,7 +277,67 @@ impl ScalarUDFImpl for ConcatFunc { } } +pub(crate) fn deduce_return_type(arg_types: &[DataType]) -> DataType { + use DataType::*; + if arg_types.contains(&BinaryView) { + BinaryView + } else if arg_types.contains(&LargeBinary) { + LargeBinary + } else if arg_types.contains(&Binary) { + Binary + } else if arg_types.iter().any(|dt| matches!(dt, FixedSizeBinary(_))) { + LargeBinary + } else if arg_types.contains(&Utf8View) { + Utf8View + } else if arg_types.contains(&LargeUtf8) { + LargeUtf8 + } else { + Utf8 + } +} + +/// Coerce all arguments to the widest type within the binary / string family +pub(crate) fn coerce_arg_types(arg_types: &[DataType]) -> Result> { + let has_binary = arg_types.iter().any(is_binary); + let has_string = arg_types.iter().any(is_string); + if has_binary && has_string { + plan_err!("function does not support mixed string and binary inputs") + } else if has_binary { + Ok(vec![widest_binary_type(arg_types); arg_types.len()]) + } else { + Ok(vec![widest_string_type(arg_types); arg_types.len()]) + } +} + +/// Build a `concats` output array using a generic [`ConcatBuilder`]. +fn build_concat( + mut builder: B, + columns: &[ColumnarValueRef], + len: usize, +) -> Result { + for i in 0..len { + for column in columns { + builder.write::(column, i)?; + } + builder.append_offset()?; + } + + let array = builder.finish(None)?; + Ok(ColumnarValue::Array(array)) +} + pub(crate) fn simplify_concat(args: Vec) -> Result { + // Skip simplification when binary literals are present, because it + // handles only strings + for arg in &args { + match arg { + Expr::Literal(dt, _) if is_binary(&dt.data_type()) => { + return Ok(ExprSimplifyResult::Original(args)); + } + _ => {} + } + } + let mut new_args = Vec::with_capacity(args.len()); let mut contiguous_scalar = "".to_string(); @@ -396,10 +426,13 @@ mod tests { use super::*; use crate::utils::test::test_function; use DataType::*; - use arrow::array::{ArrayRef, StringArray}; + use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, LargeBinaryArray, StringArray, + }; use arrow::array::{LargeStringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; + use std::sync::Arc; #[test] fn test_functions() -> Result<()> { @@ -471,38 +504,95 @@ mod tests { Utf8View, StringViewArray ); + Ok(()) + } + + #[test] + fn test_scalar_binary() -> Result<()> { test_function!( ConcatFunc::new(), vec![ ColumnarValue::Scalar(ScalarValue::Binary(Some( "Café".as_bytes().into() ))), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))), + ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))), ], - Ok(Some("Cafécc")), - &str, - Utf8, - StringArray + Ok(Some("Cafécc".as_bytes())), + &[u8], + Binary, + BinaryArray ); test_function!( ConcatFunc::new(), vec![ - ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from( - "Café".as_bytes() - )))), - ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))), + ColumnarValue::Scalar(ScalarValue::Binary(Some( + "Café".as_bytes().into() + ))), + ColumnarValue::Scalar(ScalarValue::LargeBinary(Some( + "cc".as_bytes().into() + ))), ], - Ok(Some("Cafécc")), - &str, - Utf8, - StringArray + Ok(Some("Cafécc".as_bytes())), + &[u8], + LargeBinary, + LargeBinaryArray + ); + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Binary(Some( + "Café".as_bytes().into() + ))), + ColumnarValue::Scalar(ScalarValue::BinaryView(Some( + "cc".as_bytes().into() + ))), + ], + Ok(Some("Cafécc".as_bytes())), + &[u8], + BinaryView, + BinaryViewArray + ); + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::BinaryView(Some( + "Café".as_bytes().into() + ))), + ColumnarValue::Scalar(ScalarValue::BinaryView(Some( + "cc".as_bytes().into() + ))), + ], + Ok(Some("Cafécc".as_bytes())), + &[u8], + BinaryView, + BinaryViewArray + ); + // Skip one Binary(None) + test_function!( + ConcatFunc::new(), + vec![ + ColumnarValue::Scalar(ScalarValue::Binary(None)), + ColumnarValue::Scalar(ScalarValue::Binary(Some(b"hello".to_vec()))), + ], + Ok(Some(b"hello".as_ref())), + &[u8], + Binary, + BinaryArray + ); + // Skip all Binary(None), producing an empty array + test_function!( + ConcatFunc::new(), + vec![ColumnarValue::Scalar(ScalarValue::Binary(None))], + Ok(Some(b"".as_ref())), + &[u8], + Binary, + BinaryArray ); Ok(()) } #[test] - fn concat() -> Result<()> { + fn test_array_string() -> Result<()> { let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); @@ -548,4 +638,55 @@ mod tests { } Ok(()) } + + #[test] + fn test_array_binary() -> Result<()> { + let c0 = ColumnarValue::Array(Arc::new(BinaryArray::from_vec(vec![ + b"foo", b"bar", b"baz", + ]))); + let c1 = ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(b",".to_vec()))); + let c2 = ColumnarValue::Array(Arc::new(BinaryArray::from_opt_vec(vec![ + Some(b"x"), + None, + Some(b"z"), + ]))); + let c3 = ColumnarValue::Scalar(ScalarValue::BinaryView(Some(b",".to_vec()))); + let c4 = ColumnarValue::Array(Arc::new(BinaryViewArray::from_iter(vec![ + Some(b"a"), + None, + Some(b"b"), + ]))); + let arg_fields = vec![ + Field::new("a", Binary, true), + Field::new("a", LargeBinary, true), + Field::new("a", Binary, true), + Field::new("a", BinaryView, true), + Field::new("a", BinaryView, true), + ] + .into_iter() + .map(Arc::new) + .collect::>(); + + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2, c3, c4], + arg_fields, + number_rows: 3, + return_field: Field::new("f", BinaryView, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatFunc::new().invoke_with_args(args)?; + let expected = Arc::new(BinaryViewArray::from_iter(vec![ + Some(b"foo,x,a".to_vec()), + Some(b"bar,,".to_vec()), + Some(b"baz,z,b".to_vec()), + ])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + Ok(()) + } } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 2c2d4bd42165b..1d1d15b5cd5a8 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -16,23 +16,22 @@ // under the License. use arrow::array::Array; -use std::sync::Arc; - use arrow::datatypes::DataType; +use crate::binaries::{ + ConcatBinaryBuilder, ConcatBinaryViewBuilder, ConcatLargeBinaryBuilder, +}; use crate::string::concat; -use crate::string::concat::simplify_concat; +use crate::string::concat::{coerce_arg_types, deduce_return_type, simplify_concat}; use crate::string::concat_ws; use crate::strings::{ - ColumnarValueRef, ConcatLargeStringBuilder, ConcatStringBuilder, + ColumnarValueRef, ConcatBuilder, ConcatLargeStringBuilder, ConcatStringBuilder, ConcatStringViewBuilder, }; -use datafusion_common::cast::{ - as_large_string_array, as_string_array, as_string_view_array, -}; use datafusion_common::{Result, ScalarValue, exec_err, internal_err, plan_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext}; +use datafusion_expr::type_coercion::is_binary; use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature}; use datafusion_macros::user_doc; @@ -76,12 +75,11 @@ impl Default for ConcatWsFunc { impl ConcatWsFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::variadic( - vec![Utf8View, Utf8, LargeUtf8], - Volatility::Immutable, - ), + // Use `Signature::UserDefined` to allow different argument types. + // `Variadic` requires every argument to be coerced to the same string type, + // so the UDF cannot distinguish between binary and string inputs. + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -95,20 +93,23 @@ impl ScalarUDFImpl for ConcatWsFunc { &self.signature } - /// Match the return type to the input types to avoid unnecessary casts. On - /// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid - /// potential overflow on LargeUtf8 input. - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - if arg_types.contains(&Utf8View) { - Ok(Utf8View) - } else if arg_types.contains(&LargeUtf8) { - Ok(LargeUtf8) + /// Coerce all arguments to the widest type within the binary / string family + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 { + plan_err!( + "concat_ws expects at least 2 arguments, got {}", + arg_types.len() + ) } else { - Ok(Utf8) + coerce_arg_types(arg_types) } } + /// Match the return type to the input types. Delegates to `concat` implementation. + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(deduce_return_type(arg_types)) + } + /// Concatenates all but the first argument, with separators. The first /// argument is used as the separator string, and should not be NULL. Other /// NULL arguments are ignored. @@ -123,14 +124,18 @@ impl ScalarUDFImpl for ConcatWsFunc { ); } - let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View) - { - DataType::Utf8View - } else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) { - DataType::LargeUtf8 - } else { - DataType::Utf8 - }; + let arg_types: Vec = args.iter().map(|c| c.data_type()).collect(); + let return_datatype = deduce_return_type(&arg_types); + + let with_binary = arg_types.iter().any(is_binary); + let with_string = arg_types.iter().any(|t| { + matches!(t, DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8) + }); + if with_binary && with_string { + return plan_err!( + "concat_ws does not support mixed string and binary inputs" + ); + } let array_len = args.iter().find_map(|x| match x { ColumnarValue::Array(array) => Some(array.len()), @@ -142,47 +147,101 @@ impl ScalarUDFImpl for ConcatWsFunc { let ColumnarValue::Scalar(scalar) = &args[0] else { unreachable!() }; - let sep = match scalar.try_as_str() { - Some(Some(s)) => s, - Some(None) => { - // null literal string - return match return_datatype { - DataType::Utf8View => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) + + return if with_binary { + // Binary scalar path + let sep_bytes: &[u8] = match scalar { + ScalarValue::Binary(Some(v)) + | ScalarValue::LargeBinary(Some(v)) + | ScalarValue::BinaryView(Some(v)) => v.as_slice(), + ScalarValue::FixedSizeBinary(_, Some(v)) => v.as_slice(), + scalar if scalar.is_null() => { + return Ok(null_scalar(&return_datatype)); + } + other => { + return internal_err!("Expected binary separator, got {other:?}"); + } + }; + + let mut values: Vec<&[u8]> = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { + let ColumnarValue::Scalar(s) = arg else { + unreachable!() + }; + match s { + ScalarValue::Binary(Some(v)) + | ScalarValue::LargeBinary(Some(v)) + | ScalarValue::BinaryView(Some(v)) => values.push(v.as_slice()), + ScalarValue::FixedSizeBinary(_, Some(v)) => { + values.push(v.as_slice()) } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) + // skip null + scalar if scalar.is_null() => {} + other => { + return internal_err!("Expected binary value, got {other:?}"); } - _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - }; + } } - None => return internal_err!("Expected string literal, got {scalar:?}"), - }; + let result = values.join(sep_bytes); - let mut values = Vec::with_capacity(args.len() - 1); - for arg in &args[1..] { - let ColumnarValue::Scalar(scalar) = arg else { - unreachable!() - }; - - match scalar.try_as_str() { - Some(Some(v)) => values.push(v), - Some(None) => {} // null literal string + match return_datatype { + DataType::Binary => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(result)))) + } + DataType::LargeBinary => Ok(ColumnarValue::Scalar( + ScalarValue::LargeBinary(Some(result)), + )), + DataType::BinaryView => { + Ok(ColumnarValue::Scalar(ScalarValue::BinaryView(Some(result)))) + } + other => { + plan_err!("concat_ws does not support return type {other}") + } + } + } else { + // String scalar path + let sep = match scalar.try_as_str() { + Some(Some(s)) => s, + Some(None) => { + return Ok(null_scalar(&return_datatype)); + } None => { return internal_err!("Expected string literal, got {scalar:?}"); } - } - } - let result = values.join(sep); + }; + + let mut values = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { + let ColumnarValue::Scalar(scalar) = arg else { + unreachable!() + }; - return match return_datatype { - DataType::Utf8View => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + match scalar.try_as_str() { + Some(Some(v)) => values.push(v), + Some(None) => {} // null literal string + None => { + return internal_err!( + "Expected string literal, got {scalar:?}" + ); + } + } } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + let result = values.join(sep); + + match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + DataType::Utf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + other => { + plan_err!("concat_ws does not support return type {other}") + } } - _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))), }; } @@ -190,190 +249,66 @@ impl ScalarUDFImpl for ConcatWsFunc { let len = array_len.unwrap(); let mut data_size = 0; - // parse sep - let sep = match &args[0] { - ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { - Some(Some(s)) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - Some(None) => { - return match return_datatype { - DataType::Utf8View => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) - } - _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - }; - } - None => { - return internal_err!("Expected string separator, got {scalar:?}"); - } - }, - ColumnarValue::Array(array) => match array.data_type() { - DataType::Utf8 => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); - if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - } - } - DataType::LargeUtf8 => { - let string_array = as_large_string_array(array)?; - data_size += string_array.values().len() * (args.len() - 2); - if array.is_nullable() { - ColumnarValueRef::NullableLargeStringArray(string_array) - } else { - ColumnarValueRef::NonNullableLargeStringArray(string_array) - } - } - DataType::Utf8View => { - let string_array = as_string_view_array(array)?; - data_size += - string_array.total_buffer_bytes_used() * (args.len() - 2); - if array.is_nullable() { - ColumnarValueRef::NullableStringViewArray(string_array) - } else { - ColumnarValueRef::NonNullableStringViewArray(string_array) - } - } - other => { - return plan_err!( - "Input was {other} which is not a supported datatype for concat_ws separator" - ); - } - }, - }; + let sep_column = &args[0]; + + // A null scalar separator makes the entire result null for all rows. + if matches!(sep_column, ColumnarValue::Scalar(s) if s.is_null()) { + return Ok(null_scalar(&return_datatype)); + } + + let sep: ColumnarValueRef = ColumnarValueRef::from_columnar_value(sep_column, &mut data_size, len, args.len() - 2, true)? + .map(Ok) + .unwrap_or_else(|| plan_err!( + "Input {sep_column} which is not a supported datatype for concat_ws separator" + ))?; let mut columns = Vec::with_capacity(args.len() - 1); for arg in &args[1..] { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); - } - } - ColumnarValue::Array(array) => { - match array.data_type() { - DataType::Utf8 => { - let string_array = as_string_array(array)?; - - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) - }; - columns.push(column); - } - DataType::LargeUtf8 => { - let string_array = as_large_string_array(array)?; - - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableLargeStringArray(string_array) - } else { - ColumnarValueRef::NonNullableLargeStringArray( - string_array, - ) - }; - columns.push(column); - } - DataType::Utf8View => { - let string_array = as_string_view_array(array)?; - - // This is an estimate; in particular, it will - // undercount arrays of short strings (<= 12 bytes). - data_size += string_array.total_buffer_bytes_used(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableStringViewArray(string_array) - } else { - ColumnarValueRef::NonNullableStringViewArray(string_array) - }; - columns.push(column); - } - other => { - return plan_err!( - "Input was {other} which is not a supported datatype for concat_ws function." - ); - } - }; - } - _ => unreachable!(), + if let Some(column) = + ColumnarValueRef::from_columnar_value(arg, &mut data_size, len, 1, false)? + { + columns.push(column); } } match return_datatype { - DataType::Utf8View => { - let mut builder = ConcatStringViewBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset()?; - continue; - } - let mut first = true; - for column in &columns { - if column.is_valid(i) { - if !first { - builder.write::(&sep, i); - } - builder.write::(column, i); - first = false; - } - } - builder.append_offset()?; - } - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) - } - DataType::LargeUtf8 => { - let mut builder = ConcatLargeStringBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset()?; - continue; - } - let mut first = true; - for column in &columns { - if column.is_valid(i) { - if !first { - builder.write::(&sep, i); - } - builder.write::(column, i); - first = false; - } - } - builder.append_offset()?; - } - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) - } - _ => { - let mut builder = ConcatStringBuilder::with_capacity(len, data_size); - for i in 0..len { - if !sep.is_valid(i) { - builder.append_offset()?; - continue; - } - let mut first = true; - for column in &columns { - if column.is_valid(i) { - if !first { - builder.write::(&sep, i); - } - builder.write::(column, i); - first = false; - } - } - builder.append_offset()?; - } - Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?))) - } + DataType::Utf8 => build_concat_ws( + ConcatStringBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + DataType::LargeUtf8 => build_concat_ws( + ConcatLargeStringBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + DataType::Utf8View => build_concat_ws( + ConcatStringViewBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + DataType::Binary => build_concat_ws( + ConcatBinaryBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + DataType::LargeBinary => build_concat_ws( + ConcatLargeBinaryBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + DataType::BinaryView => build_concat_ws( + ConcatBinaryViewBuilder::with_capacity(len, data_size), + &sep, + &columns, + len, + ), + other => plan_err!("concat_ws does not support return type {other}"), } } @@ -398,6 +333,41 @@ impl ScalarUDFImpl for ConcatWsFunc { } } +/// Build a `concat_ws` output array using a generic [`ConcatBuilder`]. +/// Write non-null column values per row, inserting the separator between them +fn build_concat_ws( + mut builder: B, + sep: &ColumnarValueRef, + columns: &[ColumnarValueRef], + len: usize, +) -> Result { + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset()?; + continue; + } + let mut first = true; + for column in columns { + if column.is_valid(i) { + if !first { + builder.write::(sep, i)?; + } + builder.write::(column, i)?; + first = false; + } + } + builder.append_offset()?; + } + let array = builder.finish(sep.nulls())?; + Ok(ColumnarValue::Array(array)) +} + +fn null_scalar(dt: &DataType) -> ColumnarValue { + ColumnarValue::Scalar( + ScalarValue::try_new_null(dt).unwrap_or(ScalarValue::Utf8(None)), + ) +} + fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { // Preserve the delimiter's string type for any new literals produced // during simplification. @@ -406,6 +376,17 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result DataType::Utf8, }; + // Shortcut for binary delimiters + if is_binary(&delimiter_type) { + let mut args = args + .iter() + .filter(|x| !is_null(x)) + .cloned() + .collect::>(); + args.insert(0, delimiter.clone()); + return Ok(ExprSimplifyResult::Original(args)); + } + let typed_lit = |s: String| -> Expr { match delimiter_type { DataType::LargeUtf8 => lit(ScalarValue::LargeUtf8(Some(s))), @@ -532,8 +513,11 @@ mod tests { use std::sync::Arc; use crate::string::concat_ws::ConcatWsFunc; - use arrow::array::{Array, ArrayRef, LargeStringArray, StringArray, StringViewArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; + use arrow::array::{ + Array, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, + }; + use arrow::datatypes::DataType::{Binary, LargeBinary, LargeUtf8, Utf8, Utf8View}; use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -934,4 +918,116 @@ mod tests { Ok(()) } + + #[test] + fn concat_ws_binary_scalars() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::Binary(Some(b"|".to_vec()))); + let c1 = ColumnarValue::Scalar(ScalarValue::Binary(Some(b"aa".to_vec()))); + let c2 = ColumnarValue::Scalar(ScalarValue::Binary(None)); + let c3 = ColumnarValue::Scalar(ScalarValue::Binary(Some(b"cc".to_vec()))); + + let arg_fields = vec![ + Field::new("a", Binary, true).into(), + Field::new("a", Binary, true).into(), + Field::new("a", Binary, true).into(), + Field::new("a", Binary, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2, c3], + arg_fields, + number_rows: 1, + return_field: Field::new("f", Binary, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + let result = ConcatWsFunc::new().invoke_with_args(args)?; + match result { + ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) => { + assert_eq!(v, b"aa|cc"); + } + other => panic!("Expected Binary scalar, got {other:?}"), + } + + Ok(()) + } + + #[test] + fn concat_ws_binary_arrays() -> Result<()> { + let c0 = ColumnarValue::Scalar(ScalarValue::Binary(Some(b",".to_vec()))); + let c1 = ColumnarValue::Array(Arc::new(BinaryArray::from_vec(vec![ + b"foo".as_ref(), + b"bar", + b"baz", + ]))); + let c2 = ColumnarValue::Array(Arc::new(LargeBinaryArray::from_opt_vec(vec![ + Some(b"x".as_ref()), + None, + Some(b"z"), + ]))); + + let arg_fields = vec![ + Field::new("a", Binary, true).into(), + Field::new("a", Binary, true).into(), + Field::new("a", LargeBinary, true).into(), + ]; + let args = ScalarFunctionArgs { + args: vec![c0, c1, c2], + arg_fields, + number_rows: 3, + return_field: Field::new("f", LargeBinary, true).into(), + config_options: Arc::new(ConfigOptions::default()), + }; + + let result = ConcatWsFunc::new().invoke_with_args(args)?; + let expected = Arc::new(LargeBinaryArray::from_opt_vec(vec![ + Some(b"foo,x".as_ref()), + Some(b"bar"), + Some(b"baz,z"), + ])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => assert_eq!(&expected, array), + _ => panic!("Expected array result"), + } + + Ok(()) + } + // + // #[test] + // fn concat_ws_large_binary_arrays() -> Result<()> { + // let c0 = ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(b",".to_vec()))); + // let c1 = ColumnarValue::Array(Arc::new(LargeBinaryArray::from_vec(vec![ + // b"foo".as_ref(), + // b"bar", + // b"baz", + // ]))); + // let c2 = ColumnarValue::Array(Arc::new(LargeBinaryArray::from_opt_vec(vec![ + // Some(b"x".as_ref()), + // None, + // Some(b"z"), + // ]))); + // + // let arg_fields = vec![ + // Field::new("a", LargeBinary, true).into(), + // Field::new("a", LargeBinary, true).into(), + // Field::new("a", LargeBinary, true).into(), + // ]; + // let args = ScalarFunctionArgs { + // args: vec![c0, c1, c2], + // arg_fields, + // number_rows: 3, + // return_field: Field::new("f", LargeBinary, true).into(), + // config_options: Arc::new(ConfigOptions::default()), + // }; + // + // let result = ConcatWsFunc::new().invoke_with_args(args)?; + // let expected = Arc::new(LargeBinaryArray::from_opt_vec(vec![ + // Some(b"foo,x".as_ref()), + // Some(b"bar"), + // Some(b"baz,z"), + // ])) as ArrayRef; + // match &result { + // ColumnarValue::Array(array) => assert_eq!(&expected, array), + // _ => panic!("expected array result"), + // } + // Ok(()) + // } } diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs index 1d02def4765cc..9da932ee012d8 100644 --- a/datafusion/functions/src/strings.rs +++ b/datafusion/functions/src/strings.rs @@ -19,18 +19,40 @@ use std::marker::PhantomData; use std::mem::size_of; use std::sync::Arc; -use datafusion_common::{Result, exec_datafusion_err, internal_err}; +use datafusion_common::{ + Result, ScalarValue, exec_datafusion_err, exec_err, internal_err, plan_err, +}; use arrow::array::{ - Array, ArrayAccessor, ArrayDataBuilder, ArrayRef, BinaryArray, ByteView, - GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, - make_view, + Array, ArrayAccessor, ArrayDataBuilder, ArrayRef, BinaryArray, BinaryViewArray, + ByteView, FixedSizeBinaryArray, GenericStringArray, LargeBinaryArray, + LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, + as_largestring_array, make_view, }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer, ScalarBuffer}; use arrow::datatypes::DataType; +use arrow_buffer::ArrowNativeType; +use datafusion_common::cast::{ + as_binary_array, as_binary_view_array, as_fixed_size_binary_array, + as_large_binary_array, as_string_array, as_string_view_array, +}; +use datafusion_expr_common::columnar_value::ColumnarValue; + +/// Trait abstracting concatenating string and binary collections. +pub(crate) trait ConcatBuilder { + fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) -> Result<()>; + + fn append_offset(&mut self) -> Result<()>; -/// Builder used by `concat`/`concat_ws` to assemble a [`StringArray`] one row -/// at a time from multiple input columns. + fn finish(self, null_buffer: Option) -> Result; +} + +/// Builder used by `concat`/`concat_ws` to assemble a [`GenericStringArray`] +/// (`StringArray` or `LargeStringArray`) one row at a time from multiple input columns. /// /// Each row is written via repeated `write` calls (one per input fragment) /// followed by a single `append_offset` to commit the row. The output null @@ -39,39 +61,46 @@ use arrow::datatypes::DataType; /// /// For the common "produce one `&str` per row" pattern, prefer /// `GenericStringArrayBuilder` instead. -pub(crate) struct ConcatStringBuilder { +pub(crate) struct ConcatGenericStringBuilder { offsets_buffer: MutableBuffer, value_buffer: MutableBuffer, - /// If true, a safety check is required during the `finish` call - tainted: bool, + _phantom: PhantomData, } +pub(crate) type ConcatStringBuilder = ConcatGenericStringBuilder; +pub(crate) type ConcatLargeStringBuilder = ConcatGenericStringBuilder; -impl ConcatStringBuilder { +impl ConcatGenericStringBuilder { pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { let capacity = item_capacity .checked_add(1) - .map(|i| i.saturating_mul(size_of::())) + .map(|i| i.saturating_mul(size_of::())) .expect("capacity integer overflow"); let mut offsets_buffer = MutableBuffer::with_capacity(capacity); // SAFETY: the first offset value is definitely not going to exceed the bounds. - unsafe { offsets_buffer.push_unchecked(0_i32) }; + unsafe { offsets_buffer.push_unchecked(O::usize_as(0)) }; Self { offsets_buffer, value_buffer: MutableBuffer::with_capacity(data_capacity), - tainted: false, + _phantom: PhantomData, } } +} - pub fn write( +impl ConcatBuilder + for ConcatGenericStringBuilder +{ + fn write( &mut self, column: &ColumnarValueRef, i: usize, - ) { + ) -> Result<()> { match column { ColumnarValueRef::Scalar(s) => { + std::str::from_utf8(s).map_err(|_| { + exec_datafusion_err!("concat: scalar bytes are not valid UTF-8") + })?; self.value_buffer.extend_from_slice(s); - self.tainted = true; } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { @@ -91,12 +120,6 @@ impl ConcatStringBuilder { .extend_from_slice(array.value(i).as_bytes()); } } - ColumnarValueRef::NullableBinaryArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer.extend_from_slice(array.value(i)); - } - self.tainted = true; - } ColumnarValueRef::NonNullableArray(array) => { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); @@ -109,32 +132,31 @@ impl ConcatStringBuilder { self.value_buffer .extend_from_slice(array.value(i).as_bytes()); } - ColumnarValueRef::NonNullableBinaryArray(array) => { - self.value_buffer.extend_from_slice(array.value(i)); - self.tainted = true; + _ => { + return exec_err!( + "concat: unexpected column type for string builder: {column:?}" + ); } } + Ok(()) } - pub fn append_offset(&mut self) -> Result<()> { - let next_offset: i32 = self - .value_buffer - .len() - .try_into() - .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; + fn append_offset(&mut self) -> Result<()> { + let next_offset: O = O::from_usize(self.value_buffer.len()) + .ok_or_else(|| exec_datafusion_err!("byte array offset overflow"))?; self.offsets_buffer.push(next_offset); Ok(()) } - /// Finalize the builder into a concrete [`StringArray`]. + /// Finalize the builder into a concrete [`GenericStringArray`]. /// /// # Errors /// /// Returns an error when: /// /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. - pub fn finish(self, null_buffer: Option) -> Result { - let row_count = self.offsets_buffer.len() / size_of::() - 1; + fn finish(self, null_buffer: Option) -> Result { + let row_count = self.offsets_buffer.len() / size_of::() - 1; if let Some(ref null_buffer) = null_buffer && null_buffer.len() != row_count { @@ -142,22 +164,16 @@ impl ConcatStringBuilder { "Null buffer and offsets buffer must be the same length" ); } - let array_builder = ArrayDataBuilder::new(DataType::Utf8) + let array_builder = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) .len(row_count) .add_buffer(self.offsets_buffer.into()) .add_buffer(self.value_buffer.into()) .nulls(null_buffer); - if self.tainted { - // Raw binary arrays with possible invalid utf-8 were used, - // so let ArrayDataBuilder perform validation - let array_data = array_builder.build()?; - Ok(StringArray::from(array_data)) - } else { - // SAFETY: all data that was appended was valid UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - Ok(StringArray::from(array_data)) - } + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + let array = GenericStringArray::::from(array_data); + Ok(Arc::new(array)) } } @@ -175,8 +191,6 @@ pub(crate) struct ConcatStringViewBuilder { views: Vec, data: Vec, block: Vec, - /// If true, a safety check is required during the `append_offset` call - tainted: bool, } impl ConcatStringViewBuilder { @@ -185,19 +199,22 @@ impl ConcatStringViewBuilder { views: Vec::with_capacity(item_capacity), data: Vec::with_capacity(data_capacity), block: vec![], - tainted: false, } } +} - pub fn write( +impl ConcatBuilder for ConcatStringViewBuilder { + fn write( &mut self, column: &ColumnarValueRef, i: usize, - ) { + ) -> Result<()> { match column { ColumnarValueRef::Scalar(s) => { + std::str::from_utf8(s).map_err(|_| { + exec_datafusion_err!("concat: scalar bytes are not valid UTF-8") + })?; self.block.extend_from_slice(s); - self.tainted = true; } ColumnarValueRef::NullableArray(array) => { if !CHECK_VALID || array.is_valid(i) { @@ -214,12 +231,6 @@ impl ConcatStringViewBuilder { self.block.extend_from_slice(array.value(i).as_bytes()); } } - ColumnarValueRef::NullableBinaryArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.block.extend_from_slice(array.value(i)); - } - self.tainted = true; - } ColumnarValueRef::NonNullableArray(array) => { self.block.extend_from_slice(array.value(i).as_bytes()); } @@ -229,21 +240,18 @@ impl ConcatStringViewBuilder { ColumnarValueRef::NonNullableStringViewArray(array) => { self.block.extend_from_slice(array.value(i).as_bytes()); } - ColumnarValueRef::NonNullableBinaryArray(array) => { - self.block.extend_from_slice(array.value(i)); - self.tainted = true; + _ => { + return exec_err!( + "concat: unexpected column type for string view builder: {column:?}" + ); } } + Ok(()) } /// Finalizes the current row by converting the accumulated data into a /// StringView and appending it to the views buffer. - pub fn append_offset(&mut self) -> Result<()> { - if self.tainted { - std::str::from_utf8(&self.block) - .map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?; - } - + fn append_offset(&mut self) -> Result<()> { let v = &self.block; if v.len() > 12 { let offset: u32 = self @@ -258,7 +266,6 @@ impl ConcatStringViewBuilder { } self.block.clear(); - self.tainted = false; Ok(()) } @@ -269,7 +276,7 @@ impl ConcatStringViewBuilder { /// Returns an error when: /// /// - the provided `null_buffer` length does not match the row count. - pub fn finish(self, null_buffer: Option) -> Result { + fn finish(self, null_buffer: Option) -> Result { if let Some(ref nulls) = null_buffer && nulls.len() != self.views.len() { @@ -287,8 +294,8 @@ impl ConcatStringViewBuilder { }; // SAFETY: views were constructed with correct lengths, offsets, and - // prefixes. UTF-8 validity was checked in append_offset() for any row - // where tainted data (e.g., binary literals) was appended. + // prefixes. All input fragments came from string arrays or string + // scalars, all of which are valid UTF-8. let array = unsafe { StringViewArray::new_unchecked( ScalarBuffer::from(self.views), @@ -296,135 +303,7 @@ impl ConcatStringViewBuilder { null_buffer, ) }; - Ok(array) - } -} - -/// Builder used by `concat`/`concat_ws` to assemble a [`LargeStringArray`] one -/// row at a time from multiple input columns. See [`ConcatStringBuilder`] for -/// details on the row-composition contract. -/// -/// For the common "produce one `&str` per row" pattern, prefer -/// `GenericStringArrayBuilder` instead. -pub(crate) struct ConcatLargeStringBuilder { - offsets_buffer: MutableBuffer, - value_buffer: MutableBuffer, - /// If true, a safety check is required during the `finish` call - tainted: bool, -} - -impl ConcatLargeStringBuilder { - pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let capacity = item_capacity - .checked_add(1) - .map(|i| i.saturating_mul(size_of::())) - .expect("capacity integer overflow"); - - let mut offsets_buffer = MutableBuffer::with_capacity(capacity); - // SAFETY: the first offset value is definitely not going to exceed the bounds. - unsafe { offsets_buffer.push_unchecked(0_i64) }; - Self { - offsets_buffer, - value_buffer: MutableBuffer::with_capacity(data_capacity), - tainted: false, - } - } - - pub fn write( - &mut self, - column: &ColumnarValueRef, - i: usize, - ) { - match column { - ColumnarValueRef::Scalar(s) => { - self.value_buffer.extend_from_slice(s); - self.tainted = true; - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableLargeStringArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableStringViewArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NullableBinaryArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer.extend_from_slice(array.value(i)); - } - self.tainted = true; - } - ColumnarValueRef::NonNullableArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableLargeStringArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableStringViewArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - ColumnarValueRef::NonNullableBinaryArray(array) => { - self.value_buffer.extend_from_slice(array.value(i)); - self.tainted = true; - } - } - } - - pub fn append_offset(&mut self) -> Result<()> { - let next_offset: i64 = self - .value_buffer - .len() - .try_into() - .map_err(|_| exec_datafusion_err!("byte array offset overflow"))?; - self.offsets_buffer.push(next_offset); - Ok(()) - } - - /// Finalize the builder into a concrete [`LargeStringArray`]. - /// - /// # Errors - /// - /// Returns an error when: - /// - /// - the provided `null_buffer` is not the same length as the `offsets_buffer`. - pub fn finish(self, null_buffer: Option) -> Result { - let row_count = self.offsets_buffer.len() / size_of::() - 1; - if let Some(ref null_buffer) = null_buffer - && null_buffer.len() != row_count - { - return internal_err!( - "Null buffer and offsets buffer must be the same length" - ); - } - let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) - .len(row_count) - .add_buffer(self.offsets_buffer.into()) - .add_buffer(self.value_buffer.into()) - .nulls(null_buffer); - if self.tainted { - // Raw binary arrays with possible invalid utf-8 were used, - // so let ArrayDataBuilder perform validation - let array_data = array_builder.build()?; - Ok(LargeStringArray::from(array_data)) - } else { - // SAFETY: all data that was appended was valid Large UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - Ok(LargeStringArray::from(array_data)) - } + Ok(Arc::new(array)) } } @@ -1157,6 +1036,12 @@ pub(crate) enum ColumnarValueRef<'a> { NonNullableStringViewArray(&'a StringViewArray), NullableBinaryArray(&'a BinaryArray), NonNullableBinaryArray(&'a BinaryArray), + NullableLargeBinaryArray(&'a LargeBinaryArray), + NonNullableLargeBinaryArray(&'a LargeBinaryArray), + NullableFixedSizeBinaryArray(&'a FixedSizeBinaryArray), + NonNullableFixedSizeBinaryArray(&'a FixedSizeBinaryArray), + NullableBinaryViewArray(&'a BinaryViewArray), + NonNullableBinaryViewArray(&'a BinaryViewArray), } impl ColumnarValueRef<'_> { @@ -1167,11 +1052,17 @@ impl ColumnarValueRef<'_> { | Self::NonNullableArray(_) | Self::NonNullableLargeStringArray(_) | Self::NonNullableStringViewArray(_) - | Self::NonNullableBinaryArray(_) => true, + | Self::NonNullableBinaryArray(_) + | Self::NonNullableLargeBinaryArray(_) + | Self::NonNullableBinaryViewArray(_) + | Self::NonNullableFixedSizeBinaryArray(_) => true, Self::NullableArray(array) => array.is_valid(i), Self::NullableStringViewArray(array) => array.is_valid(i), Self::NullableLargeStringArray(array) => array.is_valid(i), Self::NullableBinaryArray(array) => array.is_valid(i), + Self::NullableLargeBinaryArray(array) => array.is_valid(i), + Self::NullableBinaryViewArray(array) => array.is_valid(i), + Self::NullableFixedSizeBinaryArray(array) => array.is_valid(i), } } @@ -1182,15 +1073,182 @@ impl ColumnarValueRef<'_> { | Self::NonNullableArray(_) | Self::NonNullableStringViewArray(_) | Self::NonNullableLargeStringArray(_) - | Self::NonNullableBinaryArray(_) => None, + | Self::NonNullableBinaryArray(_) + | Self::NonNullableLargeBinaryArray(_) + | Self::NonNullableBinaryViewArray(_) + | Self::NonNullableFixedSizeBinaryArray(_) => None, Self::NullableArray(array) => array.nulls().cloned(), Self::NullableStringViewArray(array) => array.nulls().cloned(), Self::NullableLargeStringArray(array) => array.nulls().cloned(), Self::NullableBinaryArray(array) => array.nulls().cloned(), + Self::NullableLargeBinaryArray(array) => array.nulls().cloned(), + Self::NullableBinaryViewArray(array) => array.nulls().cloned(), + Self::NullableFixedSizeBinaryArray(array) => array.nulls().cloned(), + } + } + + /// Parse a [`ColumnarValue`] argument into `ColumnarValueRef`. + /// Returns `None` when the argument is null or null scalar + /// Returns an error when a columnar value type is not supported. + /// Shared by `concat` and `concat_ws`. + pub(crate) fn from_columnar_value<'a>( + col: &'a ColumnarValue, + data_size: &mut usize, + len: usize, + size_factor: usize, + convert_to_str: bool, + ) -> Result>> { + match col { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { + if let Some(s) = maybe_value { + *data_size += s.len() * len * size_factor; + Ok(Some(ColumnarValueRef::Scalar(s.as_bytes()))) + } else { + Ok(None) + } + } + ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeBinary(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::BinaryView(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::FixedSizeBinary(_, maybe_value)) => { + if let Some(b) = maybe_value { + *data_size += b.len() * len * size_factor; + Ok(Some(ColumnarValueRef::Scalar(b.as_slice()))) + } else { + Ok(None) + } + } + ColumnarValue::Scalar(scalar) if scalar.is_null() => { + // null scalar is skipped + Ok(None) + } + ColumnarValue::Scalar(scalar) if convert_to_str => { + match scalar.try_as_str() { + Some(Some(s)) => { + *data_size += s.len() * len * size_factor; + Ok(Some(ColumnarValueRef::Scalar(s.as_bytes()))) + } + Some(None) => unreachable!("null handled above"), + None => { + internal_err!("Expected string or binary, got {scalar:?}") + } + } + } + ColumnarValue::Array(array) => match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + *data_size += string_array.values().len() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + Ok(Some(column)) + } + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + *data_size += string_array.values().len() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + }; + Ok(Some(column)) + } + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + *data_size += string_array.total_buffer_bytes_used() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + }; + Ok(Some(column)) + } + DataType::Binary => { + let binary_array = as_binary_array(array)?; + *data_size += binary_array.values().len() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableBinaryArray(binary_array) + } else { + ColumnarValueRef::NonNullableBinaryArray(binary_array) + }; + Ok(Some(column)) + } + DataType::LargeBinary => { + let binary_array = as_large_binary_array(array)?; + *data_size += binary_array.values().len() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeBinaryArray(binary_array) + } else { + ColumnarValueRef::NonNullableLargeBinaryArray(binary_array) + }; + Ok(Some(column)) + } + DataType::BinaryView => { + let binary_array = as_binary_view_array(array)?; + *data_size += binary_array.total_buffer_bytes_used() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableBinaryViewArray(binary_array) + } else { + ColumnarValueRef::NonNullableBinaryViewArray(binary_array) + }; + Ok(Some(column)) + } + DataType::FixedSizeBinary(_) => { + let binary_array = as_fixed_size_binary_array(array)?; + *data_size += binary_array.values().len() * size_factor; + let column = if array.is_nullable() { + ColumnarValueRef::NullableFixedSizeBinaryArray(binary_array) + } else { + ColumnarValueRef::NonNullableFixedSizeBinaryArray(binary_array) + }; + Ok(Some(column)) + } + other => { + plan_err!( + "Input was {other} which is not a supported datatype for concat function" + ) + } + }, + _ => { + plan_err!( + "Input was {col} which is not a supported datatype for concat function" + ) + } } } } +/// Return the widest binary type found in `types`. +/// Order: `BinaryView` > `LargeBinary` / `FixedSizeBinary` > `Binary`. +pub(crate) fn widest_binary_type(types: &[DataType]) -> DataType { + if types.iter().any(|t| matches!(t, DataType::BinaryView)) { + DataType::BinaryView + } else if types + .iter() + .any(|t| matches!(t, DataType::LargeBinary | DataType::FixedSizeBinary(_))) + { + DataType::LargeBinary + } else { + DataType::Binary + } +} + +/// Return the widest string type found in `types`. +/// Order: `Utf8View` > `LargeUtf8` > `Utf8`. +pub(crate) fn widest_string_type(types: &[DataType]) -> DataType { + if types.iter().any(|t| matches!(t, DataType::Utf8View)) { + DataType::Utf8View + } else if types.iter().any(|t| matches!(t, DataType::LargeUtf8)) { + DataType::LargeUtf8 + } else { + DataType::Utf8 + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 57fd6cadd9dde..70f1e12952308 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ b/datafusion/spark/src/function/string/concat.rs @@ -80,19 +80,15 @@ impl ScalarUDFImpl for SparkConcat { ) } fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result { - use DataType::*; - // Spark semantics: concat returns NULL if ANY input is NULL let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); - // Determine return type: Utf8View > LargeUtf8 > Utf8 - let mut dt = &Utf8; - for field in args.arg_fields { - let data_type = field.data_type(); - if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) { - dt = data_type; - } - } + let arg_types: Vec = args + .arg_fields + .iter() + .map(|f| f.data_type().clone()) + .collect(); + let dt = ConcatFunc::new().return_type(&arg_types)?; Ok(Arc::new(Field::new("concat", dt.clone(), nullable))) } @@ -113,17 +109,9 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // Handle zero-argument case: return empty string if arg_values.is_empty() { let return_type = return_field.data_type(); - return match return_type { - DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::new(), - )))), - DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8( - Some(String::new()), - ))), - _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8( - Some(String::new()), - ))), - }; + return Ok(ColumnarValue::Scalar(ScalarValue::new_default( + return_type, + )?)); } // Step 1: Check for NULL mask in incoming args @@ -132,13 +120,9 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result { // If all scalars and any is NULL, return NULL immediately if matches!(null_mask, NullMaskResolution::ReturnNull) { let return_type = return_field.data_type(); - return match return_type { - DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))), - DataType::LargeUtf8 => { - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None))) - } - _ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), - }; + return Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + return_type, + )?)); } // Step 2: Delegate to DataFusion's concat diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 8396a60137ee1..329179b5cf79e 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -856,15 +856,6 @@ datafusion public string_agg 1 IN expression String NULL false 1 datafusion public string_agg 2 IN delimiter String NULL false 1 datafusion public string_agg 1 OUT NULL String NULL false 1 -# test variable length arguments -query TTTBI rowsort -select specific_name, data_type, parameter_mode, is_variadic, rid from information_schema.parameters where specific_name = 'concat'; ----- -concat Binary IN true 0 -concat String IN true 1 -concat String OUT false 0 -concat String OUT false 1 - # test ceorcion signature query TTITI rowsort select specific_name, data_type, ordinal_position, parameter_mode, rid from information_schema.parameters where specific_name = 'repeat'; diff --git a/datafusion/sqllogictest/test_files/spark/string/concat.slt b/datafusion/sqllogictest/test_files/spark/string/concat.slt index df539a1c7a159..2ec567df9c13f 100644 --- a/datafusion/sqllogictest/test_files/spark/string/concat.slt +++ b/datafusion/sqllogictest/test_files/spark/string/concat.slt @@ -26,6 +26,7 @@ SELECT concat(arrow_cast('Spark', 'Utf8View'), arrow_cast('SQL', 'Utf8View')), a ---- SparkSQL Utf8View +# A major difference from the generic `concat` query T SELECT concat('Spark', 'SQL', NULL); ---- @@ -71,67 +72,22 @@ SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), a ---- abc Utf8View -# Test mixed types: Utf8 + Binary -query TT -SELECT concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary'))); ----- -hello world Utf8 +# Coercion rules from Binary to Utf8 do no apply compared to generic `concat`, +# so `concat` produces an explicit error +query error Error during planning: concat does not support mixed string and binary inputs +SELECT concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary')); -# Test mixed types: Utf8View + Binary -query TT -SELECT concat(arrow_cast('hello', 'Utf8View'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), arrow_cast(' world', 'Binary'))); ----- -hello world Utf8View +query error Error during planning: concat does not support mixed string and binary inputs +SELECT concat(arrow_cast('hello', 'Utf8View'), arrow_cast(' world', 'Binary')); -# Test mixed types: Binary + Binary -query TT +# Test Binary + Binary +query ?T SELECT concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary'))); ---- -hello world Utf8 - -# Test mixed types with ws: Binary + Binary -query TT -SELECT concat_ws('|', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary')), arrow_typeof(concat_ws('|', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary'))); ----- -hello|world Utf8 - -# Invalid UTF8 binaries for concatenation, scalar case -# 636166c3a9 = café , where c3a9 is a char é -# 68656c6c6f = hello -query error Execution error: invalid UTF-8 in binary literal -SELECT concat(x'636166c3', x'68656c6c6f'); - -query error Execution error: invalid UTF-8 in binary literal -SELECT concat(x'636166c3', arrow_cast(x'68656c6c6f', 'Utf8View')); +68656c6c6f20776f726c64 Binary -statement ok -create table t as values (x'636166c3', x'68656c6c6f'); - -# Invalid UTF8 sequence for concatenation, array case -query error Arrow error: Invalid argument error: Invalid UTF8 sequence at string -SELECT concat(column1, column2) from t; - -# Invalid UTF8 sequence for concatenation, array case -query error DataFusion error: Execution error: invalid UTF-8 in binary literal -SELECT concat(column1, arrow_cast(column2, 'Utf8View')) from t; - -statement ok -drop table t - -statement ok -create table t as values (x'636166c3', x'a968656c6c6f'); - -# Invalid UTF8 binaries make a valid UTF8 sequence after concatenation, array case -query T -SELECT concat(column1, column2) from t; ----- -caféhello - -statement ok -drop table t - -# Invalid UTF8 binaries make a valid UTF8 sequence after concatenation, scalar case -query T +# Test Binary + Binary, binary literals +query ? SELECT concat(x'636166c3', x'a968656c6c6f'); ---- -caféhello +636166c3a968656c6c6f diff --git a/datafusion/sqllogictest/test_files/string/concat.slt b/datafusion/sqllogictest/test_files/string/concat.slt new file mode 100644 index 0000000000000..c9656bf7988f4 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/concat.slt @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# +# tests for concat and concat_ws +# + +# Test two Utf8View inputs: value and return type +query TT +SELECT concat(arrow_cast('Foo', 'Utf8View'), arrow_cast('Bar', 'Utf8View')), arrow_typeof(concat(arrow_cast('Foo', 'Utf8View'), arrow_cast('Bar', 'Utf8View'))); +---- +FooBar Utf8View + +query T +SELECT concat('Foo', 'Bar', NULL); +---- +FooBar + +query T +SELECT concat('', '1', '', '2'); +---- +12 + +query error does not support zero arguments +SELECT concat(); + +query T +SELECT concat(''); +---- +(empty) + +query T +SELECT concat(a, b, c) from (select 'a' a, 'b' b, 'c' c union all select null a, 'b', 'c') order by 1 nulls last; +---- +abc +bc + +# Test mixed types: Utf8View + Utf8 +query TT +SELECT concat(arrow_cast('hello', 'Utf8View'), ' world'), arrow_typeof(concat(arrow_cast('hello', 'Utf8View'), ' world')); +---- +hello world Utf8View + +# Test mixed string types +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'))); +---- +ab LargeUtf8 + +# Test types mixed together +query TT +SELECT concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View')), arrow_typeof(concat('a', arrow_cast('b', 'LargeUtf8'), arrow_cast('c', 'Utf8View'))); +---- +abc Utf8View + +# Mixed Utf8 + Binary is denied +query error does not support mixed string and binary inputs +SELECT concat(arrow_cast('hello', 'Utf8'), arrow_cast(' world', 'Binary')); + +# binary separator is denied for string arguments +query error does not support mixed string and binary inputs +SELECT concat_ws(x'7c', 'hello', 'world'); + +# null separator +query T +SELECT concat_ws(NULL, 'hello', 'world'); +---- +NULL + +# Test Binary + Binary scalar concat +query ?T +SELECT concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary')), arrow_typeof(concat(arrow_cast('hello', 'Binary'), arrow_cast(' world', 'Binary'))); +---- +68656c6c6f20776f726c64 Binary + +# Test all binary types together: widened to BinaryView +query ?T +SELECT concat(arrow_cast('hello', 'Binary'), arrow_cast('there', 'BinaryView'), arrow_cast('world', 'LargeBinary')), arrow_typeof(concat(arrow_cast('hello', 'Binary'), arrow_cast('there', 'BinaryView'), arrow_cast('world', 'LargeBinary'))); +---- +68656c6c6f7468657265776f726c64 BinaryView + +# Test all binary types together with concat_ws: widened to BinaryView +query ?T +SELECT concat_ws(x'7c', arrow_cast('hello', 'Binary'), arrow_cast(' there', 'BinaryView'), arrow_cast(' world', 'LargeBinary')), arrow_typeof(concat_ws(x'7c', arrow_cast('hello', 'Binary'), arrow_cast(' there', 'BinaryView'), arrow_cast(' world', 'LargeBinary'))); +---- +68656c6c6f7c2074686572657c20776f726c64 BinaryView + +query TT +SELECT concat_ws('|', arrow_cast('hello', 'Utf8View'), 'world'), arrow_typeof(concat_ws('|', arrow_cast('hello', 'Utf8View'), ' world')); +---- +hello|world Utf8View + +query TT +SELECT concat_ws('|', arrow_cast('hello', 'Utf8View'), arrow_cast('there', 'LargeUtf8'), arrow_cast('world', 'Utf8')), arrow_typeof(concat_ws('|', arrow_cast('hello', 'Utf8View'), arrow_cast('there', 'LargeUtf8'), arrow_cast('world', 'Utf8'))); +---- +hello|there|world Utf8View + +# Test Binary + Binary scalar concat +query ?T +SELECT concat_ws(x'7c', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary')), arrow_typeof(concat_ws(x'7c', arrow_cast('hello', 'Binary'), arrow_cast('world', 'Binary'))); +---- +68656c6c6f7c776f726c64 Binary + +statement ok +create table t as values (x'636166c3a9', x'68656c6c6f'); + +# Test binary + binary array concat +query ? +SELECT concat(column1, column2) from t; +---- +636166c3a968656c6c6f + +# Test binary + binary array concat_ws +query ? +SELECT concat_ws(x'7c', column1, column2) from t; +---- +636166c3a97c68656c6c6f + +statement ok +drop table t