Skip to content

Commit afaeccb

Browse files
timsaucerclaude
andauthored
feat: enable pickling of most Expr except udaf and udwf (#1544)
* feat: pickle support for Expr via inline scalar UDF encoding Adds Python-aware encoding to PythonLogicalCodec/PythonPhysicalCodec so a ScalarUDF defined in Python travels inside the serialized expression (cloudpickled into fun_definition) instead of needing a matching registration on the receiver. With that in place, Expr gains __reduce__ + classmethod from_bytes(buf, ctx=None) so pickle.dumps / pickle.loads work end-to-end on expressions built from col, lit, built-in functions, and Python scalar UDFs. Wire format is framed as <DFPYUDF magic, version byte, cloudpickle tuple>; the version byte lets a too-new/too-old payload surface a clean Execution error instead of an opaque cloudpickle unpack failure. Schema serde is via arrow-rs's native IPC (no pyarrow round-trip). Cloudpickle module handle is cached per-interpreter through PyOnceLock. Worker-side context resolution lives in a new datafusion.ipc module: set_worker_ctx / get_worker_ctx / clear_worker_ctx plus a private _resolve_ctx helper consulted by Expr.from_bytes. Priority is explicit ctx > worker ctx > global SessionContext. FFI UDFs still travel by name and require the matching registration on the receiver's context. Aggregate and window UDF inline encoding, the per-session with_python_udf_inlining toggle, sender-side context, and the user-guide docs land in follow-on PRs. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * docs(pickle): add cloudpickle security warnings, docstring examples, edge-case tests Inline `.. warning::` blocks on `Expr.to_bytes`, `Expr.from_bytes`, and `Expr.__reduce__` so the cloudpickle / arbitrary-code-execution caveat is visible at the public API surface in advance of the user-guide page that lands in PR 4. Add doctest-style `Examples:` blocks to `datafusion.ipc` functions (`set_worker_ctx`, `clear_worker_ctx`, `get_worker_ctx`, `_resolve_ctx`), `ScalarUDF.name`, and the new `Expr` pickle methods, per CLAUDE.md. Tighten `Expr.__reduce__` return annotation to `tuple[Callable[[bytes], Expr], tuple[bytes]]`. Tests: multi-arg UDF round-trip (covers synthetic `arg_{i}` schema-field loop in the codec) plus malformed-bytes paths through `Expr.from_bytes`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * as_any no longer in api * feat(pickle): stamp Python (major, minor) in UDF wire header cloudpickle bytecode is not portable across Python minor versions — a payload produced on 3.11 fails to load on 3.12 with an opaque marshal/unpickle error. Embed the sender's (major, minor) in the DFPYUDF wire header and reject mismatches at decode time with an actionable error that names both versions, instead of letting the failure surface from inside cloudpickle.loads. Header layout becomes: DFPYUDF (7) | version (1) | py_major (1) | py_minor (1) | cloudpickle Extend the Security warnings on Expr.to_bytes / from_bytes / __reduce__ with a Portability section covering the cross-version constraint and cloudpickle's by-value/by-reference behavior (the callable inlines bytecode and closure cells, but imported names travel by reference and must be importable on the receiver). Add a matching Serialization model note to the datafusion.ipc module docstring. New tests: - codec::wire_header_tests: py-major/minor mismatch, truncated py-version bytes, round-trip with py-version - test_pickle_expr::test_cross_version_error_message: patches the py_minor byte inside an emitted payload and asserts the error message identifies the version mismatch Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 8ba06e4 commit afaeccb

10 files changed

Lines changed: 1108 additions & 57 deletions

File tree

crates/core/src/codec.rs

Lines changed: 490 additions & 34 deletions
Large diffs are not rendered by default.

crates/core/src/udf.rs

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use crate::expr::PyExpr;
4242
/// This struct holds the Python written function that is a
4343
/// ScalarUDF.
4444
#[derive(Debug)]
45-
struct PythonFunctionScalarUDF {
45+
pub(crate) struct PythonFunctionScalarUDF {
4646
name: String,
4747
func: Py<PyAny>,
4848
signature: Signature,
@@ -66,6 +66,37 @@ impl PythonFunctionScalarUDF {
6666
return_field: Arc::new(return_field),
6767
}
6868
}
69+
70+
/// Stored Python callable. Consumed by the codec to cloudpickle
71+
/// the function body across process boundaries.
72+
pub(crate) fn func(&self) -> &Py<PyAny> {
73+
&self.func
74+
}
75+
76+
pub(crate) fn return_field(&self) -> &FieldRef {
77+
&self.return_field
78+
}
79+
80+
/// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
81+
/// by the codec. Inputs collapse to `Vec<DataType>` because
82+
/// `Signature::exact` cannot carry per-input nullability or
83+
/// metadata — the encoder is free to discard that side of the
84+
/// schema. `return_field` is kept as a `Field` so the post-decode
85+
/// nullability and metadata match the sender's instance.
86+
pub(crate) fn from_parts(
87+
name: String,
88+
func: Py<PyAny>,
89+
input_types: Vec<DataType>,
90+
return_field: Field,
91+
volatility: Volatility,
92+
) -> Self {
93+
Self {
94+
name,
95+
func,
96+
signature: Signature::exact(input_types, volatility),
97+
return_field: Arc::new(return_field),
98+
}
99+
}
69100
}
70101

71102
impl Eq for PythonFunctionScalarUDF {}
@@ -74,21 +105,51 @@ impl PartialEq for PythonFunctionScalarUDF {
74105
self.name == other.name
75106
&& self.signature == other.signature
76107
&& self.return_field == other.return_field
77-
&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
108+
// Identical pointers ⇒ same Python object. Most equality
109+
// checks compare `Arc`-shared clones of the same UDF
110+
// (e.g. expression rewriting), so the pointer match short-
111+
// circuits before touching the GIL.
112+
&& (self.func.as_ptr() == other.func.as_ptr()
113+
|| Python::attach(|py| {
114+
// Rust's `PartialEq` cannot return `Result`, so we
115+
// have to pick a side when Python `__eq__` raises.
116+
// `false` is the conservative choice — better to
117+
// report two UDFs as distinct than to wrongly
118+
// merge them — but the silent miss can still
119+
// surface as expression-dedup or cache-lookup
120+
// anomalies. Log at `debug` so the failure is
121+
// observable without flooding production logs.
122+
// FIXME: revisit if upstream `ScalarUDFImpl`
123+
// exposes a fallible `PartialEq`.
124+
self.func
125+
.bind(py)
126+
.eq(other.func.bind(py))
127+
.unwrap_or_else(|e| {
128+
log::debug!(
129+
target: "datafusion_python::udf",
130+
"PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
131+
self.name,
132+
);
133+
false
134+
})
135+
}))
78136
}
79137
}
80138

81139
impl Hash for PythonFunctionScalarUDF {
82140
fn hash<H: Hasher>(&self, state: &mut H) {
141+
// Hash only the identifying header (name + signature + return
142+
// field). Skipping `func` is intentional: the Rust `Hash`
143+
// contract requires `a == b ⇒ hash(a) == hash(b)`, not the
144+
// converse, so a coarser hash is sound — `PartialEq` still
145+
// disambiguates two UDFs with the same header but distinct
146+
// callables. Falling back to a sentinel on `py_hash` failure
147+
// (as a prior revision did) silently mapped every unhashable
148+
// closure to the same bucket; that is the worst case for a
149+
// hashmap and is what this rewrite avoids.
83150
self.name.hash(state);
84151
self.signature.hash(state);
85152
self.return_field.hash(state);
86-
87-
Python::attach(|py| {
88-
let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects
89-
90-
state.write_isize(py_hash);
91-
});
92153
}
93154
}
94155

@@ -215,4 +276,9 @@ impl PyScalarUDF {
215276
fn __repr__(&self) -> PyResult<String> {
216277
Ok(format!("ScalarUDF({})", self.function.name()))
217278
}
279+
280+
#[getter]
281+
fn name(&self) -> &str {
282+
self.function.name()
283+
}
218284
}

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ classifiers = [
4444
"Programming Language :: Rust",
4545
]
4646
dependencies = [
47+
# cloudpickle is invoked by the Rust-side PythonLogicalCodec /
48+
# PythonPhysicalCodec via pyo3 to serialize Python UDF callables —
49+
# scalar, aggregate, and window — into the proto wire format.
50+
# Lazy-imported on the encode / decode hot paths (and cached after
51+
# the first import), so users who never serialize a plan or
52+
# expression incur no runtime cost beyond the install footprint.
53+
"cloudpickle>=2.0",
4754
"pyarrow>=16.0.0;python_version<'3.14'",
4855
"pyarrow>=22.0.0;python_version>='3.14'",
4956
"typing-extensions;python_version<'3.13'",

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
import importlib_metadata # type: ignore[import]
6666

6767
# Public submodules
68-
from . import functions, object_store, substrait, unparser
68+
from . import functions, ipc, object_store, substrait, unparser
6969

7070
# The following imports are okay to remain as opaque to the user.
7171
from ._internal import Config
@@ -149,6 +149,7 @@
149149
"configure_formatter",
150150
"expr",
151151
"functions",
152+
"ipc",
152153
"lit",
153154
"literal",
154155
"object_store",

python/datafusion/expr.py

Lines changed: 153 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
from __future__ import annotations
4848

49-
from collections.abc import Iterable, Sequence
49+
from collections.abc import Callable, Iterable, Sequence
5050
from typing import TYPE_CHECKING, Any, ClassVar
5151

5252
import pyarrow as pa
@@ -440,23 +440,165 @@ def variant_name(self) -> str:
440440
return self.expr.variant_name()
441441

442442
def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
443-
"""Serialize this expression to protobuf bytes.
443+
"""Serialize this expression to bytes for shipping to another process.
444444
445-
When ``ctx`` is supplied, encoding routes through the session's
446-
installed :class:`LogicalExtensionCodec`. Without ``ctx`` a
447-
default codec is used.
445+
Use this — or :func:`pickle.dumps` — to send an expression to a
446+
worker process for distributed evaluation.
447+
448+
When ``ctx`` is supplied, encoding routes through that session's
449+
installed :class:`LogicalExtensionCodec`. When ``ctx`` is
450+
``None``, the default codec is used.
451+
452+
Built-in functions and Python scalar UDFs travel inside the
453+
returned bytes; the worker does not need to pre-register them.
454+
UDFs imported via the FFI capsule protocol travel by name only
455+
and must be registered on the worker.
456+
457+
.. warning:: Security
458+
Bytes returned here may embed a cloudpickled Python
459+
callable (when the expression carries a Python scalar UDF).
460+
Reconstructing them via :meth:`from_bytes` or
461+
:func:`pickle.loads` executes arbitrary Python on the
462+
receiver. Only accept payloads from trusted sources.
463+
464+
.. warning:: Portability
465+
cloudpickle serializes Python bytecode, which is **not
466+
stable across Python minor versions**. A payload produced
467+
on Python 3.11 will fail to load on Python 3.12. The
468+
wire format stamps the sender's ``(major, minor)``;
469+
:meth:`from_bytes` raises a :class:`ValueError` naming
470+
both versions on mismatch.
471+
472+
cloudpickle captures the UDF callable **by value** —
473+
bytecode and closure cells inlined — but names the
474+
callable resolves via ``import`` are captured **by
475+
reference** (module path only) and must be importable on
476+
the receiver.
477+
478+
**Self-contained — works anywhere:**
479+
480+
.. code-block:: python
481+
482+
# Lambda: bytecode captured inline
483+
udf(lambda x: x * 2, [pa.int64()], pa.int64(),
484+
volatility="immutable")
485+
486+
# Locally-defined function: bytecode captured inline
487+
def double(x):
488+
return x * 2
489+
udf(double, [pa.int64()], pa.int64(), volatility="immutable")
490+
491+
# Closure over a local variable: value captured inline
492+
factor = 3
493+
udf(lambda x: x * factor, [pa.int64()], pa.int64(),
494+
volatility="immutable")
495+
496+
**Requires matching environment on receiver:**
497+
498+
.. code-block:: python
499+
500+
# Top-level import: `foo` must be installed on receiver
501+
from foo import double
502+
udf(double, [pa.int64()], pa.int64(), volatility="immutable")
503+
504+
# Bound method of an imported class: same caveat
505+
from mylib import Transformer
506+
t = Transformer()
507+
udf(t.transform, [pa.int64()], pa.int64(),
508+
volatility="immutable")
509+
510+
Examples:
511+
>>> from datafusion import col, lit
512+
>>> blob = (col("a") + lit(1)).to_bytes()
513+
>>> isinstance(blob, bytes)
514+
True
448515
"""
449516
ctx_arg = ctx.ctx if ctx is not None else None
450517
return self.expr.to_bytes(ctx_arg)
451518

452-
@staticmethod
453-
def from_bytes(ctx: SessionContext, data: bytes) -> Expr:
454-
"""Decode an expression from serialized protobuf bytes.
519+
@classmethod
520+
def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
521+
"""Reconstruct an expression from serialized bytes.
522+
523+
Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`.
524+
``ctx`` is the :class:`SessionContext` used to resolve any
525+
function references that travel by name (e.g. FFI UDFs). When
526+
``ctx`` is ``None`` the worker context installed via
527+
:func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker
528+
context is installed, the global :class:`SessionContext` is used
529+
(sufficient for built-ins and Python scalar UDFs, plus any UDFs
530+
registered on the global context).
531+
532+
.. warning:: Security
533+
Decoding may invoke ``cloudpickle.loads`` on bytes embedded
534+
in the payload, which executes arbitrary Python code. Treat
535+
``buf`` as code, not data — only decode bytes you produced
536+
yourself or received from a trusted sender.
537+
538+
.. warning:: Portability
539+
cloudpickle payloads are **not portable across Python
540+
minor versions**. The wire format stamps the sender's
541+
``(major, minor)``; if it does not match the current
542+
interpreter, this method raises :class:`ValueError`
543+
naming both versions. Modules the UDF imports must also
544+
be importable on the receiver — see :meth:`to_bytes` for
545+
by-value vs. by-reference details.
546+
547+
Examples:
548+
>>> from datafusion import Expr, col, lit
549+
>>> blob = (col("a") + lit(1)).to_bytes()
550+
>>> Expr.from_bytes(blob).canonical_name()
551+
'a + Int64(1)'
552+
"""
553+
from datafusion.ipc import _resolve_ctx
554+
555+
resolved = _resolve_ctx(ctx)
556+
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))
557+
558+
def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]:
559+
"""Pickle protocol hook.
560+
561+
Lets expressions be shipped to worker processes via
562+
:func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions
563+
and Python scalar UDFs travel inside the pickle bytes; only
564+
FFI-capsule UDFs require pre-registration on the worker. The
565+
worker's :class:`SessionContext` for resolving those references
566+
is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling
567+
back to the global :class:`SessionContext` if none has been
568+
installed on the worker.
569+
570+
.. warning:: Security
571+
:func:`pickle.loads` on the returned tuple executes
572+
arbitrary Python on the receiver, including any
573+
cloudpickled UDF callable embedded in the payload. Only
574+
unpickle expressions from trusted sources.
575+
576+
.. warning:: Portability
577+
Sender and receiver must run the same Python
578+
``(major, minor)`` version; cloudpickle bytecode is not
579+
portable across minor versions. See :meth:`to_bytes` for
580+
details on what travels by value vs. by reference.
581+
582+
Examples:
583+
>>> import pickle
584+
>>> from datafusion import col, lit
585+
>>> e = col("a") * lit(2)
586+
>>> pickle.loads(pickle.dumps(e)).canonical_name()
587+
'a * Int64(2)'
588+
"""
589+
return (Expr._reconstruct, (self.to_bytes(),))
590+
591+
@classmethod
592+
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
593+
"""Internal entry point used by :meth:`__reduce__` on unpickle.
455594
456-
``ctx`` provides the function registry for resolving UDF
457-
references and the logical codec for in-band Python payloads.
595+
Examples:
596+
>>> from datafusion import Expr, col, lit
597+
>>> blob = (col("a") + lit(1)).to_bytes()
598+
>>> Expr._reconstruct(blob).canonical_name()
599+
'a + Int64(1)'
458600
"""
459-
return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
601+
return cls.from_bytes(proto_bytes)
460602

461603
def __richcmp__(self, other: Expr, op: int) -> Expr:
462604
"""Comparison operator."""

0 commit comments

Comments
 (0)