Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 47 additions & 17 deletions vortex-array/src/expr/exprs/list_contains.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ use vortex_error::vortex_err;
use vortex_mask::Mask;
use vortex_vector::BoolDatum;
use vortex_vector::Datum;
use vortex_vector::ScalarOps;
use vortex_vector::Vector;
use vortex_vector::VectorMutOps;
use vortex_vector::VectorOps;
use vortex_vector::bool::BoolScalar;
use vortex_vector::bool::BoolVector;
use vortex_vector::listview::ListViewScalar;
use vortex_vector::listview::ListViewVector;
use vortex_vector::match_each_pvector;
use vortex_vector::primitive::PScalar;
use vortex_vector::primitive::PVector;

use crate::ArrayRef;
Expand Down Expand Up @@ -128,30 +129,34 @@ impl VTable for ListContains {
.try_into()
.map_err(|_| vortex_err!("Wrong number of arguments for ListContains expression"))?;

let matches = match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
match (lhs.as_scalar().is_some(), rhs.as_scalar().is_some()) {
(true, true) => {
// Early return with Scalar to avoid allocating BitBuffer.
let list = lhs.into_scalar().vortex_expect("scalar").into_list();
let needle = rhs.into_scalar().vortex_expect("scalar");
// Convert the needle scalar to a vector with row_count
// elements and reuse constant_list_scalar_contains
let needle_vector = needle.repeat(args.row_count).freeze();
constant_list_scalar_contains(list, needle_vector)?
let found = list_contains_scalar_scalar(&list, &needle)?;
Ok(Datum::Scalar(BoolScalar::new(Some(found)).into()))
}
(true, false) => {
let matches = constant_list_scalar_contains(
lhs.into_scalar().vortex_expect("scalar").into_list(),
rhs.into_vector().vortex_expect("vector"),
)?;
Ok(Datum::Vector(matches.into()))
}
(false, true) => {
let matches = list_contains_scalar(
lhs.unwrap_into_vector(args.row_count).into_list(),
rhs.into_scalar().vortex_expect("scalar").into_list(),
)?;
Ok(Datum::Vector(matches.into()))
}
(true, false) => constant_list_scalar_contains(
lhs.into_scalar().vortex_expect("scalar").into_list(),
rhs.into_vector().vortex_expect("vector"),
)?,
(false, true) => list_contains_scalar(
lhs.unwrap_into_vector(args.row_count).into_list(),
rhs.into_scalar().vortex_expect("scalar").into_list(),
)?,
(false, false) => {
vortex_bail!(
"ListContains currently only supports constant needle (RHS) or constant list (LHS)"
)
}
};
Ok(Datum::Vector(matches.into()))
}
}

fn stat_falsification(
Expand Down Expand Up @@ -330,6 +335,31 @@ fn constant_list_scalar_contains(list: ListViewScalar, values: Vector) -> Vortex
Ok(result)
}

/// Used when both needle and list are scalars.
fn list_contains_scalar_scalar(
list: &ListViewScalar,
needle: &vortex_vector::Scalar,
Comment on lines +340 to +341
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this method be in vortex-compute?

) -> VortexResult<bool> {
let elements = list.value().elements();

// Downcast to `PVector` and access slice directly to avoid `scalar_at` overhead.
let found = if let Vector::Primitive(prim) = &**elements {
match_each_pvector!(prim, |pvec| {
let slice: &[_] = pvec.as_ref();
let validity = pvec.validity();
slice
.iter()
.enumerate()
.any(|(i, &elem)| needle == &PScalar::new(Some(elem)).into() && validity.value(i))
})
} else {
// Fallback for non-primitive vectors
(0..elements.len()).any(|i| needle == &elements.scalar_at(i))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use a array compare here with the elements and a constant array?

Copy link
Contributor

@joseph-isaacs joseph-isaacs Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't need to special case anything

};

Ok(found)
}

/// Returns a [`BitBuffer`] where each bit represents if a list contains the scalar, derived from a
/// [`BoolArray`] of matches on the child elements array.
///
Expand Down
Loading