Commit dac9ec6

timsaucer and claude authored
feat: enable pickling for Python aggregate and window UDFs (#1545)
* feat: inline encoding for Python aggregate and window UDFs

Extends the PythonLogicalCodec / PythonPhysicalCodec inline encoding introduced for scalar UDFs to also cover Python-defined aggregate and window UDFs. The cloudpickle tuple shape per family is:

  DFPYUDA (agg)     (name, accumulator_factory, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str)
  DFPYUDW (window)  (name, evaluator_factory, input_schema_bytes, return_schema_bytes, volatility_str)

Same wire-framing as scalar (family magic + version byte + cloudpickle blob), same schema serde (arrow-rs native IPC), same cached cloudpickle handle. The agg state schema is encoded as a full IPC schema so the post-decode UDF reports the same names + nullability + metadata as the sender, which is relevant for accumulators whose StateFieldsArgs consumers key off names rather than positional DataType.

Required restructuring two existing UDF impls so the codec can grab the Python callable directly:

* udaf.rs: replaces create_udaf + AccumulatorFactoryFunction closure with a named PythonFunctionAggregateUDF that stores the Py<PyAny> accumulator factory. Synthesizes state_{i} field names when the Python constructor passes only Vec<DataType>; from_parts preserves the full state schema on the decode side.
* udwf.rs: renames MultiColumnWindowUDF -> PythonFunctionWindowUDF, drops the PartitionEvaluatorFactory PtrEq wrapper, stores the Py<PyAny> evaluator directly. PartialEq and Hash get the same pointer-identity fast path + debug-log exception handling already on PythonFunctionScalarUDF.

User-facing surface:

* AggregateUDF.name and WindowUDF.name properties (parallel to the ScalarUDF.name shipped in PR1).
* Existing UDAF/UDWF construction paths are unchanged.

The per-session with_python_udf_inlining toggle, sender-side context, strict refusal, and user-guide docs land in PRs 3-4 of this series.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* feat: restore pub UDAF/UDWF helpers and document inline encoding

Re-export `to_rust_accumulator`, `to_rust_partition_evaluator`, and `PythonFunctionWindowUDF` (with a `MultiColumnWindowUDF` alias) by promoting `udaf` and `udwf` to `pub mod` so prior downstream Rust consumers keep their API surface after the inline-encoding refactor.

Adds an end-to-end window UDF pickle round-trip test that runs the decoded evaluator over a real session, mirroring the aggregate test.

Documents the cloudpickle-based shipping behavior of Python aggregate and window UDFs in the user-guide aggregations and windows pages.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix: address PR #1545 review feedback

- Fix CountAcc.merge in pickle test: sum over states[0] (partition counts), not over the list of state fields. The prior implementation only added partition 0's count when merging across partitions.
- Drive test_agg_udf_evaluates_after_roundtrip with a two-batch DataFrame so merge actually runs and the round-tripped state-field schema is exercised end-to-end.
- Correct PY_AGG_UDF_FAMILY / PY_WINDOW_UDF_FAMILY doc comments and the aggregate block comment to reference "return schema bytes" rather than "return type" / "return_type_bytes" so the docs match the actual on-wire layout.
- Keep `udaf` and `udwf` modules private (matching `udf`) and selectively re-export the helpers downstream Rust consumers rely on (`to_rust_accumulator`, `to_rust_partition_evaluator`, `PythonFunctionWindowUDF`, `MultiColumnWindowUDF`) instead of exposing the whole module surface.
- Rename codec helpers `*_agg_udf` -> `*_udaf` and `*_window_udf` -> `*_udwf` for naming consistency with the Python public aliases.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
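The wire framing named in the commit message (family magic, then a version byte, then the cloudpickle blob) can be sketched with the standard library alone. The family tags `DFPYUDA` and `DFPYUDW` come from the message; treating them as raw ASCII prefixes, the version value `1`, and the names `encode_frame`, `decode_frame`, and `FrameError` are assumptions made for illustration. The actual codec.rs diff is not rendered on this page, so this is not its implementation.

```rust
/// Family tags named in the commit message. Using them as raw ASCII
/// prefixes is an assumption; the real constants live in codec.rs.
const AGG_MAGIC: &[u8] = b"DFPYUDA";
const WINDOW_MAGIC: &[u8] = b"DFPYUDW";
/// Hypothetical format version; the real value is not shown in the diff.
const VERSION: u8 = 1;

#[derive(Debug, PartialEq)]
enum FrameError {
    /// The buffer does not start with the expected family magic.
    WrongFamily,
    /// The buffer ends before the version byte.
    Truncated,
    /// The version byte is not one this decoder understands.
    UnsupportedVersion(u8),
}

/// Prefix an opaque payload (for example a cloudpickle blob) with
/// the family magic and the version byte.
fn encode_frame(magic: &[u8], payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(magic.len() + 1 + payload.len());
    out.extend_from_slice(magic);
    out.push(VERSION);
    out.extend_from_slice(payload);
    out
}

/// Check the magic and the version, then return the payload slice.
fn decode_frame<'a>(magic: &[u8], buf: &'a [u8]) -> Result<&'a [u8], FrameError> {
    let rest = buf.strip_prefix(magic).ok_or(FrameError::WrongFamily)?;
    let (&version, payload) = rest.split_first().ok_or(FrameError::Truncated)?;
    if version != VERSION {
        return Err(FrameError::UnsupportedVersion(version));
    }
    Ok(payload)
}
```

Checking the magic before touching the payload lets a decoder reject a buffer from the wrong family (or from a non-Python UDF) without ever unpickling it, and the version byte leaves room to change the tuple shape later.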
1 parent afaeccb commit dac9ec6

10 files changed

Lines changed: 787 additions & 87 deletions

crates/core/src/codec.rs

Lines changed: 300 additions & 16 deletions
Large diffs are not rendered by default.

crates/core/src/lib.rs

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ mod udf;
 pub mod udtf;
 mod udwf;
 
+// Re-export helpers previously consumed by downstream Rust crates.
+// Modules stay private to keep the public Rust API surface small.
+pub use udaf::to_rust_accumulator;
+pub use udwf::{MultiColumnWindowUDF, PythonFunctionWindowUDF, to_rust_partition_evaluator};
+
 #[cfg(feature = "mimalloc")]
 #[global_allocator]
 static GLOBAL: MiMalloc = MiMalloc;
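The re-export pattern in this hunk (private module, selective `pub use`, old name kept alive after a rename) can be shown in isolation. Everything in the sketch is hypothetical: the module `window`, the types `NewWindowUdf` and `OldWindowUdf`, and the helper `make_window_udf` are stand-ins, and using a `pub type` alias is an assumption about how the old name is preserved, since the udwf.rs diff is not shown here.

```rust
// Private module: its path is not part of the crate's public API.
mod window {
    /// Stand-in for the renamed type.
    #[derive(Debug, PartialEq)]
    pub struct NewWindowUdf {
        pub name: String,
    }

    /// Alias so code written against the old name still compiles.
    pub type OldWindowUdf = NewWindowUdf;

    /// Stand-in helper that downstream crates are assumed to call.
    pub fn make_window_udf(name: &str) -> NewWindowUdf {
        NewWindowUdf { name: name.to_string() }
    }
}

// Selective re-exports: only these items are reachable from outside,
// and the module itself can be reorganized without breaking callers.
pub use window::{make_window_udf, NewWindowUdf, OldWindowUdf};
```

Downstream code can keep writing `OldWindowUdf` and receive the same type as `NewWindowUdf`, while anything else added to `window` later stays private by default.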

crates/core/src/udaf.rs

Lines changed: 175 additions & 15 deletions
@@ -19,12 +19,13 @@ use std::ptr::NonNull;
 use std::sync::Arc;
 
 use datafusion::arrow::array::ArrayRef;
-use datafusion::arrow::datatypes::DataType;
+use datafusion::arrow::datatypes::{DataType, Field, FieldRef};
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::common::ScalarValue;
 use datafusion::error::{DataFusionError, Result};
+use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion::logical_expr::{
-    Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
+    Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
 };
 use datafusion_ffi::udaf::FFI_AggregateUDF;
 use datafusion_python_util::parse_volatility;
@@ -144,15 +145,168 @@ impl Accumulator for RustAccumulator {
     }
 }
 
+fn instantiate_accumulator(accum: &Py<PyAny>) -> Result<Box<dyn Accumulator>> {
+    let instance = Python::attach(|py| {
+        accum
+            .call0(py)
+            .map_err(|e| DataFusionError::Execution(format!("{e}")))
+    })?;
+    Ok(Box::new(RustAccumulator::new(instance)))
+}
+
+/// Wrap a Python accumulator factory in an `AccumulatorFactoryFunction`.
+///
+/// Retained for downstream callers that previously consumed this
+/// helper to build a [`AccumulatorFactoryFunction`] for `create_udaf`
+/// or similar factory-based APIs. New in-crate code should construct
+/// a [`PythonFunctionAggregateUDF`] directly so the codec can downcast
+/// and ship it inline.
 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
-        let accum = Python::attach(|py| {
-            accum
-                .call0(py)
-                .map_err(|e| DataFusionError::Execution(format!("{e}")))
-        })?;
-        Ok(Box::new(RustAccumulator::new(accum)))
-    })
+    Arc::new(move |_args| instantiate_accumulator(&accum))
+}
+
+/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs.
+/// Holds the Python accumulator factory directly so the codec can
+/// downcast and cloudpickle it across process boundaries.
+#[derive(Debug)]
+pub(crate) struct PythonFunctionAggregateUDF {
+    name: String,
+    accumulator: Py<PyAny>,
+    signature: Signature,
+    return_type: DataType,
+    state_fields: Vec<FieldRef>,
+}
+
+impl PythonFunctionAggregateUDF {
+    fn new(
+        name: String,
+        accumulator: Py<PyAny>,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        state_types: Vec<DataType>,
+        volatility: Volatility,
+    ) -> Self {
+        let signature = Signature::exact(input_types, volatility);
+        let state_fields = state_types
+            .into_iter()
+            .enumerate()
+            .map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true)))
+            .collect();
+        Self {
+            name,
+            accumulator,
+            signature,
+            return_type,
+            state_fields,
+        }
+    }
+
+    /// Stored Python callable that returns a fresh accumulator instance
+    /// per partition. Consumed by the codec to cloudpickle the factory
+    /// across process boundaries.
+    pub(crate) fn accumulator(&self) -> &Py<PyAny> {
+        &self.accumulator
+    }
+
+    pub(crate) fn return_type(&self) -> &DataType {
+        &self.return_type
+    }
+
+    pub(crate) fn state_fields_ref(&self) -> &[FieldRef] {
+        &self.state_fields
+    }
+
+    /// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted
+    /// by the codec. `state_fields` carries the full state schema
+    /// (names, data types, nullability, metadata) — the codec extracts
+    /// it from the IPC payload, so the post-decode state schema is
+    /// identical to the pre-encode one. Use [`Self::new`] when only
+    /// `Vec<DataType>` is available (e.g. the Python constructor path,
+    /// where field names are synthesized).
+    pub(crate) fn from_parts(
+        name: String,
+        accumulator: Py<PyAny>,
+        input_types: Vec<DataType>,
+        return_type: DataType,
+        state_fields: Vec<FieldRef>,
+        volatility: Volatility,
+    ) -> Self {
+        Self {
+            name,
+            accumulator,
+            signature: Signature::exact(input_types, volatility),
+            return_type,
+            state_fields,
+        }
+    }
+}
+
+impl Eq for PythonFunctionAggregateUDF {}
+impl PartialEq for PythonFunctionAggregateUDF {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name
+            && self.signature == other.signature
+            && self.return_type == other.return_type
+            && self.state_fields == other.state_fields
+            // Pointer-identity fast path: `Arc`-shared clones of the
+            // same UDF skip the GIL roundtrip. Falls through to Python
+            // `__eq__` only for two distinct callables.
+            && (self.accumulator.as_ptr() == other.accumulator.as_ptr()
+                || Python::attach(|py| {
+                    // See `PythonFunctionScalarUDF::eq` for the
+                    // rationale on swallowing the exception as `false`
+                    // and logging at `debug`. FIXME: revisit if
+                    // upstream `AggregateUDFImpl` exposes a fallible
+                    // `PartialEq`.
+                    self.accumulator
+                        .bind(py)
+                        .eq(other.accumulator.bind(py))
+                        .unwrap_or_else(|e| {
+                            log::debug!(
+                                target: "datafusion_python::udaf",
+                                "PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}",
+                                self.name,
+                            );
+                            false
+                        })
+                }))
+    }
+}
+
+impl std::hash::Hash for PythonFunctionAggregateUDF {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        // See `PythonFunctionScalarUDF`'s `Hash` impl for the
+        // rationale: hash the identifying header only and let
+        // `PartialEq` disambiguate callables.
+        self.name.hash(state);
+        self.signature.hash(state);
+        self.return_type.hash(state);
+        for f in &self.state_fields {
+            f.hash(state);
+        }
+    }
+}
+
+impl AggregateUDFImpl for PythonFunctionAggregateUDF {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        instantiate_accumulator(&self.accumulator)
+    }
+
+    fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
+        Ok(self.state_fields.clone())
+    }
 }
 
 fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
@@ -190,14 +344,15 @@ impl PyAggregateUDF {
         state_type: PyArrowType<Vec<DataType>>,
         volatility: &str,
     ) -> PyResult<Self> {
-        let function = create_udaf(
-            name,
+        let py_udf = PythonFunctionAggregateUDF::new(
+            name.to_string(),
+            accumulator,
             input_type.0,
-            Arc::new(return_type.0),
+            return_type.0,
+            state_type.0,
             parse_volatility(volatility)?,
-            to_rust_accumulator(accumulator),
-            Arc::new(state_type.0),
         );
+        let function = AggregateUDF::new_from_impl(py_udf);
         Ok(Self { function })
     }
 
@@ -231,4 +386,9 @@ impl PyAggregateUDF {
     fn __repr__(&self) -> PyResult<String> {
         Ok(format!("AggregateUDF({})", self.function.name()))
     }
+
+    #[getter]
+    fn name(&self) -> &str {
+        self.function.name()
+    }
 }
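The equality and hashing scheme used by `PythonFunctionAggregateUDF` in the udaf.rs diff (compare the cheap header first, take a pointer-identity fast path, fall back to a deep comparison, and hash the header only) can be reproduced without PyO3. In this sketch the type `Udf`, the `Arc<Vec<u8>>` payload standing in for the Python callable, and the helper `hash_of` are all hypothetical; the deep comparison replaces the Python `__eq__` call, which the real code makes under the GIL.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

/// Stand-in for a UDF that owns an opaque payload which may be
/// expensive to compare. `Arc<Vec<u8>>` plays the Python callable.
#[derive(Debug, Clone)]
struct Udf {
    name: String,
    callable: Arc<Vec<u8>>,
}

impl PartialEq for Udf {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            // Fast path: clones sharing one allocation are equal without
            // inspecting the payload. Otherwise do the deep comparison.
            && (Arc::ptr_eq(&self.callable, &other.callable)
                || *self.callable == *other.callable)
    }
}
impl Eq for Udf {}

impl Hash for Udf {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Header only: equal values always share a name, so equal values
        // hash equally. Values that differ only in payload collide, and
        // `eq` tells them apart.
        self.name.hash(state);
    }
}

/// Hash a value with the standard library's default hasher.
fn hash_of<T: Hash>(value: &T) -> u64 {
    let mut hasher = DefaultHasher::new();
    value.hash(&mut hasher);
    hasher.finish()
}
```

Hashing less than `eq` compares is permitted by the `Hash`/`Eq` contract, which only requires that equal values produce equal hashes. The cost is extra collisions between UDFs that share a header but wrap different callables.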
