diff --git a/vortex-array/src/expr/exprs/list_contains.rs b/vortex-array/src/expr/exprs/list_contains.rs index 68d7ed31bad..376be519414 100644 --- a/vortex-array/src/expr/exprs/list_contains.rs +++ b/vortex-array/src/expr/exprs/list_contains.rs @@ -20,13 +20,14 @@ use vortex_error::vortex_err; use vortex_mask::Mask; use vortex_vector::BoolDatum; use vortex_vector::Datum; -use vortex_vector::ScalarOps; use vortex_vector::Vector; -use vortex_vector::VectorMutOps; use vortex_vector::VectorOps; +use vortex_vector::bool::BoolScalar; use vortex_vector::bool::BoolVector; use vortex_vector::listview::ListViewScalar; use vortex_vector::listview::ListViewVector; +use vortex_vector::match_each_pvector; +use vortex_vector::primitive::PScalar; use vortex_vector::primitive::PVector; use crate::ArrayRef; @@ -128,30 +129,34 @@ impl VTable for ListContains { .try_into() .map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?; - let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) { + match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) { (true, true) => { + // Early return with Scalar to avoid allocating BitBuffer. let list = lhs.into_scalar().vortex_expect("scalar").into_list(); let needle = rhs.into_scalar().vortex_expect("scalar"); - // Convert the needle scalar to a vector with row_count - // elements and reuse constant_list_scalar_contains - let needle_vector = needle.repeat(args.row_count).freeze(); - constant_list_scalar_contains(list, needle_vector)? + let found = list_contains_scalar_scalar(&list, &needle)?; + Ok(Datum::Scalar(BoolScalar::new(Some(found)).into())) + } + (true, false) => { + let matches = constant_list_scalar_contains( + lhs.into_scalar().vortex_expect("scalar").into_list(), + rhs.into_vector().vortex_expect("vector"), + )?; + Ok(Datum::Vector(matches.into())) + } + (false, true) => { + let matches = list_contains_scalar( + lhs.unwrap_into_vector(args.row_count).into_list(), + rhs.into_scalar().vortex_expect("scalar").into_list(), + )?; + Ok(Datum::Vector(matches.into())) } - (true, false) => constant_list_scalar_contains( - lhs.into_scalar().vortex_expect("scalar").into_list(), - rhs.into_vector().vortex_expect("vector"), - )?, - (false, true) => list_contains_scalar( - lhs.unwrap_into_vector(args.row_count).into_list(), - rhs.into_scalar().vortex_expect("scalar").into_list(), - )?, (false, false) => { vortex_bail!( "ListContains currently only supports constant needle (RHS) or constant list (LHS)" ) } - }; - Ok(Datum::Vector(matches.into())) + } } fn stat_falsification( @@ -330,6 +335,31 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex Ok(result) } +/// Used when both needle and list are scalars. +fn list_contains_scalar_scalar( + list: &ListViewScalar, + needle: &vortex_vector::Scalar, +) -> VortexResult { + let elements = list.value().elements(); + + // Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead. + let found = if let Vector::Primitive(prim) = &**elements { + match_each_pvector!(prim, |pvec| { + let slice: &[_] = pvec.as_ref(); + let validity = pvec.validity(); + slice + .iter() + .enumerate() + .any(|(i, &elem)| needle == &PScalar::new(Some(elem)).into() && validity.value(i)) + }) + } else { + // Fallback for non-primitive vectors + (0..elements.len()).any(|i| needle == &elements.scalar_at(i)) + }; + + Ok(found) +} + /// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a /// [`BoolArray`] of matches on the child elements array. ///