Skip to content

Commit 061f3ab

Browse files
timsaucerclaude
andcommitted
feat: pickle support for Expr via inline scalar UDF encoding
Adds Python-aware encoding to PythonLogicalCodec/PythonPhysicalCodec so a ScalarUDF defined in Python travels inside the serialized expression (cloudpickled into fun_definition) instead of needing a matching registration on the receiver. With that in place, Expr gains __reduce__ + classmethod from_bytes(buf, ctx=None) so pickle.dumps / pickle.loads work end-to-end on expressions built from col, lit, built-in functions, and Python scalar UDFs. Wire format is framed as <DFPYUDF magic, version byte, cloudpickle tuple>; the version byte lets a too-new/too-old payload surface a clean Execution error instead of an opaque cloudpickle unpack failure. Schema serde is via arrow-rs's native IPC (no pyarrow round-trip). Cloudpickle module handle is cached per-interpreter through PyOnceLock. Worker-side context resolution lives in a new datafusion.ipc module: set_worker_ctx / get_worker_ctx / clear_worker_ctx plus a private _resolve_ctx helper consulted by Expr.from_bytes. Priority is explicit ctx > worker ctx > global SessionContext. FFI UDFs still travel by name and require the matching registration on the receiver's context. Aggregate and window UDF inline encoding, the per-session with_python_udf_inlining toggle, sender-side context, and the user-guide docs land in follow-on PRs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent baef8f0 commit 061f3ab

10 files changed

Lines changed: 788 additions & 57 deletions

File tree

crates/core/src/codec.rs

Lines changed: 394 additions & 34 deletions
Large diffs are not rendered by default.

crates/core/src/udf.rs

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use crate::expr::PyExpr;
4343
/// This struct holds the Python written function that is a
4444
/// ScalarUDF.
4545
#[derive(Debug)]
46-
struct PythonFunctionScalarUDF {
46+
pub(crate) struct PythonFunctionScalarUDF {
4747
name: String,
4848
func: Py<PyAny>,
4949
signature: Signature,
@@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF {
6767
return_field: Arc::new(return_field),
6868
}
6969
}
70+
71+
/// Stored Python callable. Consumed by the codec to cloudpickle
72+
/// the function body across process boundaries.
73+
pub(crate) fn func(&self) -> &Py<PyAny> {
74+
&self.func
75+
}
76+
77+
pub(crate) fn return_field(&self) -> &FieldRef {
78+
&self.return_field
79+
}
80+
81+
/// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
82+
/// by the codec. Inputs collapse to `Vec<DataType>` because
83+
/// `Signature::exact` cannot carry per-input nullability or
84+
/// metadata — the encoder is free to discard that side of the
85+
/// schema. `return_field` is kept as a `Field` so the post-decode
86+
/// nullability and metadata match the sender's instance.
87+
pub(crate) fn from_parts(
88+
name: String,
89+
func: Py<PyAny>,
90+
input_types: Vec<DataType>,
91+
return_field: Field,
92+
volatility: Volatility,
93+
) -> Self {
94+
Self {
95+
name,
96+
func,
97+
signature: Signature::exact(input_types, volatility),
98+
return_field: Arc::new(return_field),
99+
}
100+
}
70101
}
71102

72103
impl Eq for PythonFunctionScalarUDF {}
@@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF {
75106
self.name == other.name
76107
&& self.signature == other.signature
77108
&& self.return_field == other.return_field
78-
&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
109+
// Identical pointers ⇒ same Python object. Most equality
110+
// checks compare `Arc`-shared clones of the same UDF
111+
// (e.g. expression rewriting), so the pointer match short-
112+
// circuits before touching the GIL.
113+
&& (self.func.as_ptr() == other.func.as_ptr()
114+
|| Python::attach(|py| {
115+
// Rust's `PartialEq` cannot return `Result`, so we
116+
// have to pick a side when Python `__eq__` raises.
117+
// `false` is the conservative choice — better to
118+
// report two UDFs as distinct than to wrongly
119+
// merge them — but the silent miss can still
120+
// surface as expression-dedup or cache-lookup
121+
// anomalies. Log at `debug` so the failure is
122+
// observable without flooding production logs.
123+
// FIXME: revisit if upstream `ScalarUDFImpl`
124+
// exposes a fallible `PartialEq`.
125+
self.func
126+
.bind(py)
127+
.eq(other.func.bind(py))
128+
.unwrap_or_else(|e| {
129+
log::debug!(
130+
target: "datafusion_python::udf",
131+
"PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
132+
self.name,
133+
);
134+
false
135+
})
136+
}))
79137
}
80138
}
81139

82140
impl Hash for PythonFunctionScalarUDF {
83141
fn hash<H: Hasher>(&self, state: &mut H) {
142+
// Hash only the identifying header (name + signature + return
143+
// field). Skipping `func` is intentional: the Rust `Hash`
144+
// contract requires `a == b ⇒ hash(a) == hash(b)`, not the
145+
// converse, so a coarser hash is sound — `PartialEq` still
146+
// disambiguates two UDFs with the same header but distinct
147+
// callables. Falling back to a sentinel on `py_hash` failure
148+
// (as a prior revision did) silently mapped every unhashable
149+
// closure to the same bucket; that is the worst case for a
150+
// hashmap and is what this rewrite avoids.
84151
self.name.hash(state);
85152
self.signature.hash(state);
86153
self.return_field.hash(state);
87-
88-
Python::attach(|py| {
89-
let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects
90-
91-
state.write_isize(py_hash);
92-
});
93154
}
94155
}
95156

@@ -220,4 +281,9 @@ impl PyScalarUDF {
220281
fn __repr__(&self) -> PyResult<String> {
221282
Ok(format!("ScalarUDF({})", self.function.name()))
222283
}
284+
285+
#[getter]
286+
fn name(&self) -> &str {
287+
self.function.name()
288+
}
223289
}

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ classifiers = [
4444
"Programming Language :: Rust",
4545
]
4646
dependencies = [
47+
# cloudpickle is invoked by the Rust-side PythonLogicalCodec /
48+
# PythonPhysicalCodec via pyo3 to serialize Python UDF callables —
49+
# scalar, aggregate, and window — into the proto wire format.
50+
# Lazy-imported on the encode / decode hot paths (and cached after
51+
# the first import), so users who never serialize a plan or
52+
# expression incur no runtime cost beyond the install footprint.
53+
"cloudpickle>=2.0",
4754
"pyarrow>=16.0.0;python_version<'3.14'",
4855
"pyarrow>=22.0.0;python_version>='3.14'",
4956
"typing-extensions;python_version<'3.13'",

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
import importlib_metadata # type: ignore[import]
6666

6767
# Public submodules
68-
from . import functions, object_store, substrait, unparser
68+
from . import functions, ipc, object_store, substrait, unparser
6969

7070
# The following imports are okay to remain as opaque to the user.
7171
from ._internal import Config
@@ -142,6 +142,7 @@
142142
"configure_formatter",
143143
"expr",
144144
"functions",
145+
"ipc",
145146
"lit",
146147
"literal",
147148
"object_store",

python/datafusion/expr.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -434,23 +434,59 @@ def variant_name(self) -> str:
434434
return self.expr.variant_name()
435435

436436
def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
437-
"""Serialize this expression to protobuf bytes.
437+
"""Serialize this expression to bytes for shipping to another process.
438438
439-
When ``ctx`` is supplied, encoding routes through the session's
440-
installed :class:`LogicalExtensionCodec`. Without ``ctx`` a
441-
default codec is used.
439+
Use this — or :func:`pickle.dumps` — to send an expression to a
440+
worker process for distributed evaluation.
441+
442+
When ``ctx`` is supplied, encoding routes through that session's
443+
installed :class:`LogicalExtensionCodec`. When ``ctx`` is
444+
``None``, the default codec is used.
445+
446+
Built-in functions and Python scalar UDFs travel inside the
447+
returned bytes; the worker does not need to pre-register them.
448+
UDFs imported via the FFI capsule protocol travel by name only
449+
and must be registered on the worker.
442450
"""
443451
ctx_arg = ctx.ctx if ctx is not None else None
444452
return self.expr.to_bytes(ctx_arg)
445453

446-
@staticmethod
447-
def from_bytes(ctx: SessionContext, data: bytes) -> Expr:
448-
"""Decode an expression from serialized protobuf bytes.
449-
450-
``ctx`` provides the function registry for resolving UDF
451-
references and the logical codec for in-band Python payloads.
454+
@classmethod
455+
def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
456+
"""Reconstruct an expression from serialized bytes.
457+
458+
Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`.
459+
``ctx`` is the :class:`SessionContext` used to resolve any
460+
function references that travel by name (e.g. FFI UDFs). When
461+
``ctx`` is ``None`` the worker context installed via
462+
:func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker
463+
context is installed, the global :class:`SessionContext` is used
464+
(sufficient for built-ins and Python scalar UDFs, plus any UDFs
465+
registered on the global context).
466+
"""
467+
from datafusion.ipc import _resolve_ctx
468+
469+
resolved = _resolve_ctx(ctx)
470+
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))
471+
472+
def __reduce__(self) -> tuple:
473+
"""Pickle protocol hook.
474+
475+
Lets expressions be shipped to worker processes via
476+
:func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions
477+
and Python scalar UDFs travel inside the pickle bytes; only
478+
FFI-capsule UDFs require pre-registration on the worker. The
479+
worker's :class:`SessionContext` for resolving those references
480+
is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling
481+
back to the global :class:`SessionContext` if none has been
482+
installed on the worker.
452483
"""
453-
return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
484+
return (Expr._reconstruct, (self.to_bytes(),))
485+
486+
@classmethod
487+
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
488+
"""Internal entry point used by :meth:`__reduce__` on unpickle."""
489+
return cls.from_bytes(proto_bytes)
454490

455491
def __richcmp__(self, other: Expr, op: int) -> Expr:
456492
"""Comparison operator."""

python/datafusion/ipc.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Worker-side setup for distributing DataFusion expressions.
19+
20+
When a :class:`Expr` is shipped to a worker process (e.g. through
21+
:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the
22+
expression against a :class:`SessionContext`. If the expression references
23+
UDFs imported via the FFI capsule protocol — or any UDF the worker would
24+
otherwise resolve from its registered functions rather than from inside
25+
the shipped expression — install a configured :class:`SessionContext`
26+
once per worker:
27+
28+
.. code-block:: python
29+
30+
from datafusion import SessionContext
31+
from datafusion.ipc import set_worker_ctx
32+
33+
def init_worker():
34+
ctx = SessionContext()
35+
ctx.register_udaf(my_ffi_aggregate)
36+
set_worker_ctx(ctx)
37+
38+
Built-in functions and Python scalar UDFs travel inside the shipped
39+
expression itself and do not need pre-registration on the worker.
40+
"""
41+
42+
from __future__ import annotations
43+
44+
import threading
45+
from typing import TYPE_CHECKING
46+
47+
if TYPE_CHECKING:
48+
from datafusion.context import SessionContext
49+
50+
51+
__all__ = [
52+
"clear_worker_ctx",
53+
"get_worker_ctx",
54+
"set_worker_ctx",
55+
]
56+
57+
58+
_local = threading.local()
59+
60+
61+
def set_worker_ctx(ctx: SessionContext) -> None:
62+
"""Install this worker's :class:`SessionContext` for shipped expressions.
63+
64+
Call once per worker — typically from a ``multiprocessing.Pool``
65+
initializer or a Ray actor ``__init__``. Idempotent: overwrites any
66+
previous value. Stored in a thread-local slot, so each thread within a
67+
worker may install its own context independently.
68+
"""
69+
_local.ctx = ctx
70+
71+
72+
def clear_worker_ctx() -> None:
73+
"""Remove this worker's installed :class:`SessionContext`.
74+
75+
After clearing, expressions reconstructed in this worker fall back to
76+
the global :class:`SessionContext` — adequate for built-ins and Python
77+
scalar UDFs, but anything imported via the FFI capsule protocol must
78+
be registered on the global context to resolve.
79+
"""
80+
if hasattr(_local, "ctx"):
81+
del _local.ctx
82+
83+
84+
def get_worker_ctx() -> SessionContext | None:
85+
"""Return this worker's installed :class:`SessionContext`, or ``None``."""
86+
return getattr(_local, "ctx", None)
87+
88+
89+
def _resolve_ctx(
90+
explicit_ctx: SessionContext | None = None,
91+
) -> SessionContext:
92+
"""Resolve a context for Expr reconstruction.
93+
94+
Priority: explicit argument > worker context > global context.
95+
Falling back to the global :class:`SessionContext` (instead of a
96+
freshly constructed one) preserves any registrations the user has
97+
installed on it.
98+
"""
99+
if explicit_ctx is not None:
100+
return explicit_ctx
101+
worker = get_worker_ctx()
102+
if worker is not None:
103+
return worker
104+
# Lazy import: `datafusion/__init__.py` imports `datafusion.ipc`
105+
# before `datafusion.context`, so a module-top import would force
106+
# `datafusion.context` to load mid-init of `datafusion.ipc`. The
107+
# cycle is benign today (context.py only pulls expr.py at module
108+
# scope, neither pulls ipc.py back), but a single new import in
109+
# context.py's transitive deps could turn it into a real cycle.
110+
# Deferring keeps `datafusion.ipc` import-order-independent.
111+
from datafusion.context import SessionContext # noqa: PLC0415
112+
113+
return SessionContext.global_ctx()

python/datafusion/user_defined.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,16 @@ def __init__(
141141
name, func, input_fields, return_field, str(volatility)
142142
)
143143

144+
@property
145+
def name(self) -> str:
146+
"""Return the registered name of this UDF.
147+
148+
For UDFs imported via the FFI capsule protocol, this is the
149+
name the capsule itself reports — not the ``name`` argument
150+
passed to the constructor (which is ignored on the FFI path).
151+
"""
152+
return self._udf.name
153+
144154
def __repr__(self) -> str:
145155
"""Print a string representation of the Scalar UDF."""
146156
return self._udf.__repr__()

python/tests/test_expr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None:
11861186

11871187
original = col("a") + lit(1)
11881188
blob = original.to_bytes(ctx)
1189-
restored = Expr.from_bytes(ctx, blob)
1189+
restored = Expr.from_bytes(blob, ctx=ctx)
11901190

11911191
# Canonical name preserves the structure of the expression even
11921192
# though the underlying PyExpr instances are different.
@@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None:
12011201
fresh = SessionContext()
12021202
original = col("a") * lit(2)
12031203
blob = original.to_bytes() # encode side: default codec
1204-
restored = Expr.from_bytes(fresh, blob)
1204+
restored = Expr.from_bytes(blob, ctx=fresh)
12051205

12061206
assert restored.canonical_name() == original.canonical_name()

0 commit comments

Comments
 (0)