Skip to content

Commit d5055e1

Browse files
timsaucerclaude
andcommitted
docs(example): multiprocessing pool with distinct expressions per worker
The Ray example fans the same expression out to every actor — it demonstrates pickling works but doesn't show why expression pickling is necessary (a by-name registration on each worker would suffice for that pattern). This example builds a list of parametric expressions in the driver, each closing over a different threshold value, and ships one per worker via `multiprocessing.Pool`. The closure state forces the cloudpickle path: a by-name registration on the worker would collapse every threshold into the same callable and lose the per-task value. Workers return results plus their PID so the driver output makes the cross-process distribution visible. Also mixes a scalar UDF in an aggregate and a pure UDAF so both kinds round-trip through pickle. README and the distributing-work "See also" section now link the example. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1300076 commit d5055e1

3 files changed

Lines changed: 169 additions & 0 deletions

File tree

docs/source/user-guide/io/distributing_work.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,4 +354,7 @@ See also
354354
--------
355355

356356
* :py:mod:`datafusion.ipc` — worker context API.
357+
* ``examples/multiprocessing_pickle_expr.py`` — runnable
358+
``multiprocessing.Pool`` example that ships a different parametric
359+
expression to each worker and collects results back.
357360
* ``examples/ray_pickle_expr.py`` — runnable Ray actor example.

examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Here is a direct link to the file used in the examples:
4646

4747
### Distributing DataFusion expressions
4848

49+
- [Fan out distinct expressions to a multiprocessing pool](./multiprocessing_pickle_expr.py)
4950
- [Distribute expression evaluation across Ray actors](./ray_pickle_expr.py)
5051

5152
### Substrait Support
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Distribute different DataFusion expressions to worker processes.
19+
20+
Builds a list of parametric expressions in the driver — each closing
21+
over a different threshold value — ships one per worker via
22+
``multiprocessing.Pool``, and collects the results back. The closure
23+
state forces the cloudpickle path (a by-name registration would lose
24+
the captured threshold), so this is a real test of the expression-
25+
pickling story rather than a same-expression fan-out.
26+
27+
Worker layout:
28+
29+
* Each worker receives a different ``(expr, label)`` task.
30+
* Each worker materializes the shared dataset locally and runs its
31+
own expression against it.
32+
* The result and the worker's PID travel back to the driver, so the
33+
output makes it visible that the work was spread across processes.
34+
35+
Run:
36+
python examples/multiprocessing_pickle_expr.py
37+
"""
38+
39+
from __future__ import annotations
40+
41+
import multiprocessing as mp
42+
import os
43+
44+
import pyarrow as pa
45+
from datafusion import Expr, SessionContext, col, udaf, udf
46+
from datafusion import functions as F
47+
from datafusion.user_defined import Accumulator, AggregateUDF, ScalarUDF
48+
49+
# A shared input dataset. In a production pipeline this would live on
50+
# object storage; here we hand-roll a small batch so the example runs
51+
# without any I/O setup.
52+
DATASET = {
53+
"value": [3, 17, 42, 5, 88, 21, 9, 56, 4, 73, 12, 31],
54+
}
55+
56+
57+
def make_above_threshold_udf(threshold: int) -> ScalarUDF:
58+
"""Build a scalar UDF that returns 1 where ``value > threshold`` else 0.
59+
60+
The threshold is captured in the closure, so cloudpickle has to
61+
walk into the function body to ship the value across processes —
62+
a by-name registration on the worker would collapse every
63+
threshold into the same callable and lose the per-task state.
64+
"""
65+
66+
def above(arr: pa.Array) -> pa.Array:
67+
return pa.array([1 if (v.as_py() or 0) > threshold else 0 for v in arr])
68+
69+
return udf(
70+
above,
71+
[pa.int64()],
72+
pa.int64(),
73+
volatility="immutable",
74+
name=f"above_{threshold}",
75+
)
76+
77+
78+
class _SumAccumulator(Accumulator):
79+
"""Tiny aggregate UDF state used to demonstrate UDAFs travel too."""
80+
81+
def __init__(self) -> None:
82+
self._total = 0
83+
84+
def state(self) -> list[pa.Scalar]:
85+
return [pa.scalar(self._total, type=pa.int64())]
86+
87+
def update(self, values: pa.Array) -> None:
88+
for v in values:
89+
self._total += v.as_py() or 0
90+
91+
def merge(self, states: list[pa.Array]) -> None:
92+
for s in states:
93+
self._total += s[0].as_py()
94+
95+
def evaluate(self) -> pa.Scalar:
96+
return pa.scalar(self._total, type=pa.int64())
97+
98+
99+
def _build_sum_udaf() -> AggregateUDF:
100+
return udaf(
101+
_SumAccumulator,
102+
[pa.int64()],
103+
pa.int64(),
104+
[pa.int64()],
105+
"immutable",
106+
name="my_sum",
107+
)
108+
109+
110+
def evaluate_in_worker(task: tuple[str, Expr]) -> tuple[str, int, int]:
111+
"""Run one expression against the shared dataset.
112+
113+
``task`` arrived here via the pool's automatic pickling. The Python
114+
callable inside the expression (including its captured threshold)
115+
was reconstructed by the codec — the worker did not have to
116+
register anything before this call.
117+
"""
118+
label, expr = task
119+
ctx = SessionContext()
120+
df = ctx.from_pydict(DATASET)
121+
# ``expr`` is an aggregate over the whole batch; ``aggregate`` keeps
122+
# a single row of output, which we read as a Python int.
123+
result_df = df.aggregate([], [expr.alias("result")])
124+
result = result_df.to_pydict()["result"][0]
125+
return label, result, os.getpid()
126+
127+
128+
def build_tasks() -> list[tuple[str, Expr]]:
129+
"""Return ``(label, expr)`` pairs — one task per worker invocation.
130+
131+
Mixes scalar-UDF-in-aggregate and pure-aggregate work to show both
132+
UDF kinds round-tripping through pickle.
133+
"""
134+
sum_udaf = _build_sum_udaf()
135+
tasks: list[tuple[str, Expr]] = []
136+
137+
# Three "count values strictly above threshold T" tasks built from
138+
# closure-capturing scalar UDFs.
139+
for threshold in (10, 30, 60):
140+
above_udf = make_above_threshold_udf(threshold)
141+
tasks.append((f"count_above_{threshold}", F.sum(above_udf(col("value")))))
142+
143+
# One pure aggregate UDF task.
144+
tasks.append(("custom_sum", sum_udaf(col("value"))))
145+
146+
return tasks
147+
148+
149+
def main() -> None:
150+
tasks = build_tasks()
151+
152+
# ``forkserver`` works on every POSIX platform and is the Python 3.14
153+
# default for POSIX. ``spawn`` would also work; ``fork`` is unsafe
154+
# with pyarrow/tokio on macOS.
155+
mp_ctx = mp.get_context("forkserver")
156+
with mp_ctx.Pool(processes=min(4, len(tasks))) as pool:
157+
results = pool.map(evaluate_in_worker, tasks)
158+
159+
print(f"driver pid: {os.getpid()}")
160+
for label, value, worker_pid in results:
161+
print(f" [{label:>16}] = {value:>6} (worker pid: {worker_pid})")
162+
163+
164+
if __name__ == "__main__":
165+
main()

0 commit comments

Comments
 (0)