Skip to content

Commit ca6849e

Browse files
timsaucerclaude
andcommitted
docs(pickle): add cloudpickle security warnings, docstring examples, edge-case tests
Inline `.. warning::` blocks on `Expr.to_bytes`, `Expr.from_bytes`, and `Expr.__reduce__` so the cloudpickle / arbitrary-code-execution caveat is visible at the public API surface in advance of the user-guide page that lands in PR 4. Add doctest-style `Examples:` blocks to `datafusion.ipc` functions (`set_worker_ctx`, `clear_worker_ctx`, `get_worker_ctx`, `_resolve_ctx`), `ScalarUDF.name`, and the new `Expr` pickle methods, per CLAUDE.md. Tighten `Expr.__reduce__` return annotation to `tuple[Callable[[bytes], Expr], tuple[bytes]]`. Tests: multi-arg UDF round-trip (covers synthetic `arg_{i}` schema-field loop in the codec) plus malformed-bytes paths through `Expr.from_bytes`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 061f3ab commit ca6849e

4 files changed

Lines changed: 125 additions & 4 deletions

File tree

python/datafusion/expr.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
from __future__ import annotations
4848

49-
from collections.abc import Iterable, Sequence
49+
from collections.abc import Callable, Iterable, Sequence
5050
from typing import TYPE_CHECKING, Any, ClassVar
5151

5252
import pyarrow as pa
@@ -447,6 +447,19 @@ def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
447447
returned bytes; the worker does not need to pre-register them.
448448
UDFs imported via the FFI capsule protocol travel by name only
449449
and must be registered on the worker.
450+
451+
.. warning::
452+
Bytes returned here may embed a cloudpickled Python
453+
callable (when the expression carries a Python scalar UDF).
454+
Reconstructing them via :meth:`from_bytes` or
455+
:func:`pickle.loads` executes arbitrary Python on the
456+
receiver. Only accept payloads from trusted sources.
457+
458+
Examples:
459+
>>> from datafusion import col, lit
460+
>>> blob = (col("a") + lit(1)).to_bytes()
461+
>>> isinstance(blob, bytes)
462+
True
450463
"""
451464
ctx_arg = ctx.ctx if ctx is not None else None
452465
return self.expr.to_bytes(ctx_arg)
@@ -463,13 +476,25 @@ def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
463476
context is installed, the global :class:`SessionContext` is used
464477
(sufficient for built-ins and Python scalar UDFs, plus any UDFs
465478
registered on the global context).
479+
480+
.. warning::
481+
Decoding may invoke ``cloudpickle.loads`` on bytes embedded
482+
in the payload, which executes arbitrary Python code. Treat
483+
``buf`` as code, not data — only decode bytes you produced
484+
yourself or received from a trusted sender.
485+
486+
Examples:
487+
>>> from datafusion import Expr, col, lit
488+
>>> blob = (col("a") + lit(1)).to_bytes()
489+
>>> Expr.from_bytes(blob).canonical_name()
490+
'a + Int64(1)'
466491
"""
467492
from datafusion.ipc import _resolve_ctx
468493

469494
resolved = _resolve_ctx(ctx)
470495
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))
471496

472-
def __reduce__(self) -> tuple:
497+
def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]:
473498
"""Pickle protocol hook.
474499
475500
Lets expressions be shipped to worker processes via
@@ -480,12 +505,32 @@ def __reduce__(self) -> tuple:
480505
is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling
481506
back to the global :class:`SessionContext` if none has been
482507
installed on the worker.
508+
509+
.. warning::
510+
:func:`pickle.loads` on the returned tuple executes
511+
arbitrary Python on the receiver, including any
512+
cloudpickled UDF callable embedded in the payload. Only
513+
unpickle expressions from trusted sources.
514+
515+
Examples:
516+
>>> import pickle
517+
>>> from datafusion import col, lit
518+
>>> e = col("a") * lit(2)
519+
>>> pickle.loads(pickle.dumps(e)).canonical_name()
520+
'a * Int64(2)'
483521
"""
484522
return (Expr._reconstruct, (self.to_bytes(),))
485523

486524
@classmethod
487525
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
488-
"""Internal entry point used by :meth:`__reduce__` on unpickle."""
526+
"""Internal entry point used by :meth:`__reduce__` on unpickle.
527+
528+
Examples:
529+
>>> from datafusion import Expr, col, lit
530+
>>> blob = (col("a") + lit(1)).to_bytes()
531+
>>> Expr._reconstruct(blob).canonical_name()
532+
'a + Int64(1)'
533+
"""
489534
return cls.from_bytes(proto_bytes)
490535

491536
def __richcmp__(self, other: Expr, op: int) -> Expr:

python/datafusion/ipc.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ def set_worker_ctx(ctx: SessionContext) -> None:
6565
initializer or a Ray actor ``__init__``. Idempotent: overwrites any
6666
previous value. Stored in a thread-local slot, so each thread within a
6767
worker may install its own context independently.
68+
69+
Examples:
70+
>>> from datafusion import SessionContext
71+
>>> from datafusion.ipc import set_worker_ctx, get_worker_ctx, clear_worker_ctx
72+
>>> set_worker_ctx(SessionContext())
73+
>>> get_worker_ctx() is not None
74+
True
75+
>>> clear_worker_ctx()
6876
"""
6977
_local.ctx = ctx
7078

@@ -76,13 +84,28 @@ def clear_worker_ctx() -> None:
7684
the global :class:`SessionContext` — adequate for built-ins and Python
7785
scalar UDFs, but anything imported via the FFI capsule protocol must
7886
be registered on the global context to resolve.
87+
88+
Examples:
89+
>>> from datafusion import SessionContext
90+
>>> from datafusion.ipc import set_worker_ctx, clear_worker_ctx, get_worker_ctx
91+
>>> set_worker_ctx(SessionContext())
92+
>>> clear_worker_ctx()
93+
>>> get_worker_ctx() is None
94+
True
7995
"""
8096
if hasattr(_local, "ctx"):
8197
del _local.ctx
8298

8399

84100
def get_worker_ctx() -> SessionContext | None:
85-
"""Return this worker's installed :class:`SessionContext`, or ``None``."""
101+
"""Return this worker's installed :class:`SessionContext`, or ``None``.
102+
103+
Examples:
104+
>>> from datafusion.ipc import get_worker_ctx, clear_worker_ctx
105+
>>> clear_worker_ctx()
106+
>>> get_worker_ctx() is None
107+
True
108+
"""
86109
return getattr(_local, "ctx", None)
87110

88111

@@ -95,6 +118,16 @@ def _resolve_ctx(
95118
Falling back to the global :class:`SessionContext` (instead of a
96119
freshly constructed one) preserves any registrations the user has
97120
installed on it.
121+
122+
Examples:
123+
>>> from datafusion import SessionContext
124+
>>> from datafusion.ipc import _resolve_ctx, clear_worker_ctx
125+
>>> clear_worker_ctx()
126+
>>> isinstance(_resolve_ctx(), SessionContext)
127+
True
128+
>>> ctx = SessionContext()
129+
>>> _resolve_ctx(ctx) is ctx
130+
True
98131
"""
99132
if explicit_ctx is not None:
100133
return explicit_ctx

python/datafusion/user_defined.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,19 @@ def name(self) -> str:
148148
For UDFs imported via the FFI capsule protocol, this is the
149149
name the capsule itself reports — not the ``name`` argument
150150
passed to the constructor (which is ignored on the FFI path).
151+
152+
Examples:
153+
>>> import pyarrow as pa
154+
>>> from datafusion import udf
155+
>>> double = udf(
156+
... lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]),
157+
... [pa.int64()],
158+
... pa.int64(),
159+
... volatility="immutable",
160+
... name="double",
161+
... )
162+
>>> double.name
163+
'double'
151164
"""
152165
return self._udf.name
153166

python/tests/test_pickle_expr.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,33 @@ def fn(arr):
125125
blob = pickle.dumps(e)
126126
decoded = pickle.loads(blob) # noqa: S301
127127
assert decoded.canonical_name() == e.canonical_name()
128+
129+
def test_multi_arg_udf_round_trip(self):
130+
"""Wire format builds synthetic `arg_{i}` fields per input — exercise
131+
with a 2-arg UDF spanning two distinct DataTypes."""
132+
add_scaled = udf(
133+
lambda a, b: pa.array(
134+
[
135+
(x.as_py() or 0) + (y.as_py() or 0.0)
136+
for x, y in zip(a, b, strict=False)
137+
]
138+
),
139+
[pa.int64(), pa.float64()],
140+
pa.float64(),
141+
volatility="immutable",
142+
name="add_scaled",
143+
)
144+
e = add_scaled(col("a"), col("b"))
145+
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
146+
assert decoded.canonical_name() == e.canonical_name()
147+
assert "add_scaled" in decoded.canonical_name()
148+
149+
150+
class TestErrorPaths:
151+
def test_from_bytes_rejects_garbage(self):
152+
with pytest.raises(Exception): # noqa: B017
153+
Expr.from_bytes(b"not a valid protobuf payload")
154+
155+
def test_from_bytes_rejects_empty(self):
156+
with pytest.raises(Exception): # noqa: B017
157+
Expr.from_bytes(b"")

0 commit comments

Comments
 (0)