Skip to content

Commit 4a3237f

Browse files
timsaucer and claude committed
feat: restore pub UDAF/UDWF helpers and document inline encoding
Re-export `to_rust_accumulator`, `to_rust_partition_evaluator`, and `PythonFunctionWindowUDF` (with a `MultiColumnWindowUDF` alias) by promoting `udaf` and `udwf` to `pub mod` so prior downstream Rust consumers keep their API surface after the inline-encoding refactor. Adds an end-to-end window UDF pickle round-trip test that runs the decoded evaluator over a real session, mirroring the aggregate test. Documents the cloudpickle-based shipping behavior of Python aggregate and window UDFs in the user-guide aggregations and windows pages. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d5bb146 commit 4a3237f

6 files changed

Lines changed: 91 additions & 6 deletions

File tree

crates/core/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ mod array;
5959
#[cfg(feature = "substrait")]
6060
pub mod substrait;
6161
#[allow(clippy::borrow_deref_ref)]
62-
mod udaf;
62+
pub mod udaf;
6363
#[allow(clippy::borrow_deref_ref)]
6464
mod udf;
6565
pub mod udtf;
66-
mod udwf;
66+
pub mod udwf;
6767

6868
#[cfg(feature = "mimalloc")]
6969
#[global_allocator]

crates/core/src/udaf.rs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use datafusion::common::ScalarValue;
2525
use datafusion::error::{DataFusionError, Result};
2626
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
2727
use datafusion::logical_expr::{
28-
Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
28+
Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, Signature, Volatility,
2929
};
3030
use datafusion_ffi::udaf::FFI_AggregateUDF;
3131
use datafusion_python_util::parse_volatility;
@@ -154,6 +154,17 @@ fn instantiate_accumulator(accum: &Py<PyAny>) -> Result<Box<dyn Accumulator>> {
154154
Ok(Box::new(RustAccumulator::new(instance)))
155155
}
156156

157+
/// Wrap a Python accumulator factory in an `AccumulatorFactoryFunction`.
158+
///
159+
/// Retained for downstream callers that previously consumed this
160+
/// helper to build an [`AccumulatorFactoryFunction`] for `create_udaf`
161+
/// or similar factory-based APIs. New in-crate code should construct
162+
/// a [`PythonFunctionAggregateUDF`] directly so the codec can downcast
163+
/// and ship it inline.
164+
pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
165+
Arc::new(move |_args| instantiate_accumulator(&accum))
166+
}
167+
157168
/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs.
158169
/// Holds the Python accumulator factory directly so the codec can
159170
/// downcast and cloudpickle it across process boundaries.

crates/core/src/udwf.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use datafusion::error::{DataFusionError, Result};
2626
use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs};
2727
use datafusion::logical_expr::window_state::WindowAggState;
2828
use datafusion::logical_expr::{
29-
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
29+
PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl,
3030
};
3131
use datafusion::scalar::ScalarValue;
3232
use datafusion_ffi::udwf::FFI_WindowUDF;
@@ -205,6 +205,17 @@ fn instantiate_partition_evaluator(evaluator: &Py<PyAny>) -> Result<Box<dyn Part
205205
Ok(Box::new(RustPartitionEvaluator::new(instance)))
206206
}
207207

208+
/// Wrap a Python evaluator factory in a `PartitionEvaluatorFactory`.
209+
///
210+
/// Retained for downstream callers that previously consumed this
211+
/// helper to build a [`PartitionEvaluatorFactory`] for factory-based
212+
/// APIs. New in-crate code should construct a
213+
/// [`PythonFunctionWindowUDF`] directly so the codec can downcast and
214+
/// ship it inline.
215+
pub fn to_rust_partition_evaluator(evaluator: Py<PyAny>) -> PartitionEvaluatorFactory {
216+
Arc::new(move || instantiate_partition_evaluator(&evaluator))
217+
}
218+
208219
/// Represents a WindowUDF
209220
#[pyclass(
210221
from_py_object,
@@ -279,16 +290,26 @@ impl PyWindowUDF {
279290
}
280291
}
281292

293+
/// `WindowUDFImpl` for Python-defined window UDFs.
294+
///
295+
/// Holds the Python evaluator factory directly so the codec can
296+
/// downcast and cloudpickle it across process boundaries. Replaces
297+
/// the prior factory-erased `MultiColumnWindowUDF`; the old name is
298+
/// kept as a type alias below for backward compatibility.
282299
#[derive(Debug)]
283-
pub(crate) struct PythonFunctionWindowUDF {
300+
pub struct PythonFunctionWindowUDF {
284301
name: String,
285302
evaluator: Py<PyAny>,
286303
signature: Signature,
287304
return_type: DataType,
288305
}
289306

307+
/// Backward-compatible alias for downstream crates that referenced the
308+
/// previous struct name. New code should use [`PythonFunctionWindowUDF`].
309+
pub type MultiColumnWindowUDF = PythonFunctionWindowUDF;
310+
290311
impl PythonFunctionWindowUDF {
291-
pub(crate) fn new(
312+
pub fn new(
292313
name: impl Into<String>,
293314
evaluator: Py<PyAny>,
294315
input_types: Vec<DataType>,

docs/source/user-guide/common-operations/aggregations.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,3 +434,21 @@ The available aggregate functions are:
434434
- :py:meth:`datafusion.expr.GroupingSet.cube`
435435
- :py:meth:`datafusion.expr.GroupingSet.grouping_sets`
436436

437+
User-Defined Aggregate Functions
438+
--------------------------------
439+
440+
You can ship custom aggregations to the engine by subclassing
441+
:py:class:`~datafusion.user_defined.Accumulator` and registering it via
442+
:py:func:`~datafusion.udaf`. See :py:mod:`datafusion.user_defined` for
443+
the accumulator interface and worked examples.
444+
445+
.. note:: Serialization
446+
447+
Python aggregate UDFs travel inline inside pickled or
448+
:py:meth:`~datafusion.expr.Expr.to_bytes`-serialized expressions —
449+
the accumulator class is captured by value via :mod:`cloudpickle`,
450+
so worker processes do not need to pre-register the UDF. Any names
451+
the accumulator resolves via ``import`` are captured **by reference**
452+
and must be importable on the receiving worker. See
453+
:py:mod:`datafusion.ipc` for the full IPC model and security caveats.
454+

docs/source/user-guide/common-operations/windows.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,21 @@ The possible window functions are:
213213

214214
3. Aggregate Functions
215215
- All :ref:`Aggregation Functions<aggregation>` can be used as window functions.
216+
217+
User-Defined Window Functions
218+
-----------------------------
219+
220+
You can ship custom window functions to the engine by subclassing
221+
:py:class:`~datafusion.user_defined.WindowEvaluator` and registering it
222+
via :py:func:`~datafusion.udwf`. See :py:mod:`datafusion.user_defined`
223+
for the evaluator interface and worked examples.
224+
225+
.. note:: Serialization
226+
227+
Python window UDFs travel inline inside pickled or
228+
:py:meth:`~datafusion.expr.Expr.to_bytes`-serialized expressions —
229+
the evaluator class is captured by value via :mod:`cloudpickle`, so
230+
worker processes do not need to pre-register the UDF. Any names the
231+
evaluator resolves via ``import`` are captured **by reference** and
232+
must be importable on the receiving worker. See
233+
:py:mod:`datafusion.ipc` for the full IPC model and security caveats.

python/tests/test_pickle_expr.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,23 @@ def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self):
254254
decoded = pickle.loads(blob) # noqa: S301
255255
assert "count_up" in decoded.canonical_name()
256256

257+
def test_window_udf_evaluates_after_roundtrip(self):
258+
"""End-to-end: decoded window UDF runs and emits per-row values
259+
produced by the round-tripped evaluator factory."""
260+
from datafusion.expr import WindowFrame
261+
262+
u = self._build_window_udf()
263+
e = u(col("a"))
264+
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
265+
266+
ctx = SessionContext()
267+
df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]})
268+
framed = (
269+
decoded.window_frame(WindowFrame("rows", None, None)).build().alias("c")
270+
)
271+
out = df.select(framed).to_pydict()
272+
assert out["c"] == [0, 1, 2, 3, 4]
273+
257274

258275
class TestErrorPaths:
259276
def test_from_bytes_rejects_garbage(self):

0 commit comments

Comments
 (0)