Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 300 additions & 16 deletions crates/core/src/codec.rs

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ mod udf;
pub mod udtf;
mod udwf;

// Re-export helpers previously consumed by downstream Rust crates.
// Modules stay private to keep the public Rust API surface small.
pub use udaf::to_rust_accumulator;
pub use udwf::{MultiColumnWindowUDF, PythonFunctionWindowUDF, to_rust_partition_evaluator};

#[cfg(feature = "mimalloc")]
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
Expand Down
190 changes: 175 additions & 15 deletions crates/core/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::ScalarValue;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion::logical_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use datafusion_python_util::parse_volatility;
Expand Down Expand Up @@ -144,15 +145,168 @@ impl Accumulator for RustAccumulator {
}
}

fn instantiate_accumulator(accum: &Py<PyAny>) -> Result<Box<dyn Accumulator>> {
let instance = Python::attach(|py| {
accum
.call0(py)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})?;
Ok(Box::new(RustAccumulator::new(instance)))
}

/// Wrap a Python accumulator factory in an `AccumulatorFactoryFunction`.
///
/// Retained for downstream callers that previously consumed this
/// helper to build a [`AccumulatorFactoryFunction`] for `create_udaf`
/// or similar factory-based APIs. New in-crate code should construct
/// a [`PythonFunctionAggregateUDF`] directly so the codec can downcast
/// and ship it inline.
pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
let accum = Python::attach(|py| {
accum
.call0(py)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})?;
Ok(Box::new(RustAccumulator::new(accum)))
})
Arc::new(move |_args| instantiate_accumulator(&accum))
}

/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs.
/// Holds the Python accumulator factory directly so the codec can
/// downcast and cloudpickle it across process boundaries.
#[derive(Debug)]
pub(crate) struct PythonFunctionAggregateUDF {
name: String,
accumulator: Py<PyAny>,
signature: Signature,
return_type: DataType,
state_fields: Vec<FieldRef>,
}

impl PythonFunctionAggregateUDF {
fn new(
name: String,
accumulator: Py<PyAny>,
input_types: Vec<DataType>,
return_type: DataType,
state_types: Vec<DataType>,
volatility: Volatility,
) -> Self {
let signature = Signature::exact(input_types, volatility);
let state_fields = state_types
.into_iter()
.enumerate()
.map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true)))
.collect();
Self {
name,
accumulator,
signature,
return_type,
state_fields,
}
}

/// Stored Python callable that returns a fresh accumulator instance
/// per partition. Consumed by the codec to cloudpickle the factory
/// across process boundaries.
pub(crate) fn accumulator(&self) -> &Py<PyAny> {
&self.accumulator
}

pub(crate) fn return_type(&self) -> &DataType {
&self.return_type
}

pub(crate) fn state_fields_ref(&self) -> &[FieldRef] {
&self.state_fields
}

/// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted
/// by the codec. `state_fields` carries the full state schema
/// (names, data types, nullability, metadata) — the codec extracts
/// it from the IPC payload, so the post-decode state schema is
/// identical to the pre-encode one. Use [`Self::new`] when only
/// `Vec<DataType>` is available (e.g. the Python constructor path,
/// where field names are synthesized).
pub(crate) fn from_parts(
name: String,
accumulator: Py<PyAny>,
input_types: Vec<DataType>,
return_type: DataType,
state_fields: Vec<FieldRef>,
volatility: Volatility,
) -> Self {
Self {
name,
accumulator,
signature: Signature::exact(input_types, volatility),
return_type,
state_fields,
}
}
}

impl Eq for PythonFunctionAggregateUDF {}
impl PartialEq for PythonFunctionAggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name == other.name
&& self.signature == other.signature
&& self.return_type == other.return_type
&& self.state_fields == other.state_fields
// Pointer-identity fast path: `Arc`-shared clones of the
// same UDF skip the GIL roundtrip. Falls through to Python
// `__eq__` only for two distinct callables.
&& (self.accumulator.as_ptr() == other.accumulator.as_ptr()
|| Python::attach(|py| {
// See `PythonFunctionScalarUDF::eq` for the
// rationale on swallowing the exception as `false`
// and logging at `debug`. FIXME: revisit if
// upstream `AggregateUDFImpl` exposes a fallible
// `PartialEq`.
self.accumulator
.bind(py)
.eq(other.accumulator.bind(py))
.unwrap_or_else(|e| {
log::debug!(
target: "datafusion_python::udaf",
"PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}",
self.name,
);
false
})
}))
}
}

impl std::hash::Hash for PythonFunctionAggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// See `PythonFunctionScalarUDF`'s `Hash` impl for the
// rationale: hash the identifying header only and let
// `PartialEq` disambiguate callables.
self.name.hash(state);
self.signature.hash(state);
self.return_type.hash(state);
for f in &self.state_fields {
f.hash(state);
}
}
}

impl AggregateUDFImpl for PythonFunctionAggregateUDF {
fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
instantiate_accumulator(&self.accumulator)
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Ok(self.state_fields.clone())
}
}

fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
Expand Down Expand Up @@ -190,14 +344,15 @@ impl PyAggregateUDF {
state_type: PyArrowType<Vec<DataType>>,
volatility: &str,
) -> PyResult<Self> {
let function = create_udaf(
name,
let py_udf = PythonFunctionAggregateUDF::new(
name.to_string(),
accumulator,
input_type.0,
Arc::new(return_type.0),
return_type.0,
state_type.0,
parse_volatility(volatility)?,
to_rust_accumulator(accumulator),
Arc::new(state_type.0),
);
let function = AggregateUDF::new_from_impl(py_udf);
Ok(Self { function })
}

Expand Down Expand Up @@ -231,4 +386,9 @@ impl PyAggregateUDF {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("AggregateUDF({})", self.function.name()))
}

#[getter]
fn name(&self) -> &str {
self.function.name()
}
}
Loading
Loading