diff --git a/vortex-array/src/expr/exprs/binary.rs b/vortex-array/src/expr/exprs/binary.rs index cf7765c9cca..fce11a74d70 100644 --- a/vortex-array/src/expr/exprs/binary.rs +++ b/vortex-array/src/expr/exprs/binary.rs @@ -3,10 +3,19 @@ use std::fmt::Formatter; -use arrow_ord::cmp; use prost::Message; -use vortex_compute::arrow::IntoArrow; -use vortex_compute::arrow::IntoVector; +use vortex_compute::arithmetic::Add as AddOp; +use vortex_compute::arithmetic::Arithmetic; +use vortex_compute::arithmetic::Div as DivOp; +use vortex_compute::arithmetic::Mul as MulOp; +use vortex_compute::arithmetic::Sub as SubOp; +use vortex_compute::comparison::Compare; +use vortex_compute::comparison::Equal; +use vortex_compute::comparison::GreaterThan; +use vortex_compute::comparison::GreaterThanOrEqual; +use vortex_compute::comparison::LessThan; +use vortex_compute::comparison::LessThanOrEqual; +use vortex_compute::comparison::NotEqual; use vortex_compute::logical::LogicalAndKleene; use vortex_compute::logical::LogicalOrKleene; use vortex_dtype::DType; @@ -16,7 +25,6 @@ use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_proto::expr as pb; use vortex_vector::Datum; -use vortex_vector::VectorOps; use crate::ArrayRef; use crate::compute; @@ -136,54 +144,32 @@ impl VTable for Binary { .try_into() .map_err(|_| vortex_err!("Wrong arg count"))?; - match op { + let result: Datum = match op { + Operator::Eq => Compare::::compare(lhs, rhs).into(), + Operator::NotEq => Compare::::compare(lhs, rhs).into(), + Operator::Lt => Compare::::compare(lhs, rhs).into(), + Operator::Lte => Compare::::compare(lhs, rhs).into(), + Operator::Gt => Compare::::compare(lhs, rhs).into(), + Operator::Gte => Compare::::compare(lhs, rhs).into(), Operator::And => { - return Ok(LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into()); + LogicalAndKleene::and_kleene(&lhs.into_bool(), &rhs.into_bool()).into() } - Operator::Or => { - return Ok(LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into()); - } - _ => {} - } - - let lhs = lhs.into_arrow()?; - let rhs = rhs.into_arrow()?; - - let vector = match op { - Operator::Eq => cmp::eq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), - Operator::NotEq => cmp::neq(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), - Operator::Gt => cmp::gt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), - Operator::Gte => cmp::gt_eq(lhs.as_ref(), rhs.as_ref())? - .into_vector()? - .into(), - Operator::Lt => cmp::lt(lhs.as_ref(), rhs.as_ref())?.into_vector()?.into(), - Operator::Lte => cmp::lt_eq(lhs.as_ref(), rhs.as_ref())? - .into_vector()? - .into(), - + Operator::Or => LogicalOrKleene::or_kleene(&lhs.into_bool(), &rhs.into_bool()).into(), Operator::Add => { - arrow_arith::numeric::add(lhs.as_ref(), rhs.as_ref())?.into_vector()? + Arithmetic::::eval(lhs.into_primitive(), rhs.into_primitive()).into() } Operator::Sub => { - arrow_arith::numeric::sub(lhs.as_ref(), rhs.as_ref())?.into_vector()? + Arithmetic::::eval(lhs.into_primitive(), rhs.into_primitive()).into() } Operator::Mul => { - arrow_arith::numeric::mul(lhs.as_ref(), rhs.as_ref())?.into_vector()? + Arithmetic::::eval(lhs.into_primitive(), rhs.into_primitive()).into() } Operator::Div => { - arrow_arith::numeric::div(lhs.as_ref(), rhs.as_ref())?.into_vector()? - } - Operator::And | Operator::Or => { - unreachable!("Already dealt with above") + Arithmetic::::eval(lhs.into_primitive(), rhs.into_primitive()).into() } }; - // Arrow computed over scalar datums - if vector.len() == 1 && args.row_count != 1 { - return Ok(Datum::Scalar(vector.scalar_at(0))); - } - - Ok(Datum::Vector(vector)) + Ok(result) } fn stat_falsification(