Skip to content

Commit 28da71a

Browse files
timsaucer and claude committed
fix: address PR #1545 review feedback
- Fix CountAcc.merge in pickle test: sum over states[0] (partition counts), not over the list of state fields. The prior implementation only added partition 0's count when merging across partitions.
- Drive test_agg_udf_evaluates_after_roundtrip with a two-batch DataFrame so merge actually runs and the round-tripped state-field schema is exercised end-to-end.
- Correct PY_AGG_UDF_FAMILY / PY_WINDOW_UDF_FAMILY doc comments and the aggregate block comment to reference "return schema bytes" rather than "return type" / "return_type_bytes" so the docs match the actual on-wire layout.
- Keep `udaf` and `udwf` modules private (matching `udf`) and selectively re-export the helpers downstream Rust consumers rely on (`to_rust_accumulator`, `to_rust_partition_evaluator`, `PythonFunctionWindowUDF`, `MultiColumnWindowUDF`) instead of exposing the whole module surface.
- Rename codec helpers `*_agg_udf` -> `*_udaf` and `*_window_udf` -> `*_udwf` for naming consistency with the Python public aliases.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4a3237f commit 28da71a

3 files changed

Lines changed: 40 additions & 30 deletions

File tree

crates/core/src/codec.rs

Lines changed: 26 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -132,13 +132,14 @@ use crate::udwf::PythonFunctionWindowUDF;
132132
pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";
133133

134134
/// Family prefix for an inlined Python aggregate UDF
135-
/// (cloudpickled tuple of name, accumulator factory, input schema,
136-
/// return type, state types schema, volatility).
135+
/// (cloudpickled tuple of name, accumulator factory, input schema bytes,
136+
/// return schema bytes (single-field IPC schema), state schema bytes,
137+
/// volatility).
137138
pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA";
138139

139140
/// Family prefix for an inlined Python window UDF
140-
/// (cloudpickled tuple of name, evaluator factory, input schema,
141-
/// return type, volatility).
141+
/// (cloudpickled tuple of name, evaluator factory, input schema bytes,
142+
/// return schema bytes (single-field IPC schema), volatility).
142143
pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW";
143144

144145
/// Wire-format version this build emits.
@@ -314,28 +315,28 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
314315
}
315316

316317
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
317-
if try_encode_python_agg_udf(node, buf)? {
318+
if try_encode_python_udaf(node, buf)? {
318319
return Ok(());
319320
}
320321
self.inner.try_encode_udaf(node, buf)
321322
}
322323

323324
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
324-
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
325+
if let Some(udaf) = try_decode_python_udaf(buf)? {
325326
return Ok(udaf);
326327
}
327328
self.inner.try_decode_udaf(name, buf)
328329
}
329330

330331
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
331-
if try_encode_python_window_udf(node, buf)? {
332+
if try_encode_python_udwf(node, buf)? {
332333
return Ok(());
333334
}
334335
self.inner.try_encode_udwf(node, buf)
335336
}
336337

337338
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
338-
if let Some(udwf) = try_decode_python_window_udf(buf)? {
339+
if let Some(udwf) = try_decode_python_udwf(buf)? {
339340
return Ok(udwf);
340341
}
341342
self.inner.try_decode_udwf(name, buf)
@@ -416,28 +417,28 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
416417
}
417418

418419
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
419-
if try_encode_python_agg_udf(node, buf)? {
420+
if try_encode_python_udaf(node, buf)? {
420421
return Ok(());
421422
}
422423
self.inner.try_encode_udaf(node, buf)
423424
}
424425

425426
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
426-
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
427+
if let Some(udaf) = try_decode_python_udaf(buf)? {
427428
return Ok(udaf);
428429
}
429430
self.inner.try_decode_udaf(name, buf)
430431
}
431432

432433
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
433-
if try_encode_python_window_udf(node, buf)? {
434+
if try_encode_python_udwf(node, buf)? {
434435
return Ok(());
435436
}
436437
self.inner.try_encode_udwf(node, buf)
437438
}
438439

439440
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
440-
if let Some(udwf) = try_decode_python_window_udf(buf)? {
441+
if let Some(udwf) = try_decode_python_udwf(buf)? {
441442
return Ok(udwf);
442443
}
443444
self.inner.try_decode_udwf(name, buf)
@@ -718,30 +719,30 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
718719
// Python callable that produces a new evaluator instance per partition.
719720
// =============================================================================
720721

721-
pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
722+
pub(crate) fn try_encode_python_udwf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
722723
let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionWindowUDF>() else {
723724
return Ok(false);
724725
};
725726

726727
Python::attach(|py| -> Result<bool> {
727-
let bytes = encode_python_window_udf(py, py_udf).map_err(to_datafusion_err)?;
728+
let bytes = encode_python_udwf(py, py_udf).map_err(to_datafusion_err)?;
728729
append_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, &bytes)?;
729730
Ok(true)
730731
})
731732
}
732733

733-
pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
734+
pub(crate) fn try_decode_python_udwf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
734735
Python::attach(|py| -> Result<Option<Arc<WindowUDF>>> {
735736
let Some(payload) = read_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, "window UDF")?
736737
else {
737738
return Ok(None);
738739
};
739-
let udf = decode_python_window_udf(py, payload).map_err(to_datafusion_err)?;
740+
let udf = decode_python_udwf(py, payload).map_err(to_datafusion_err)?;
740741
Ok(Some(Arc::new(WindowUDF::new_from_impl(udf))))
741742
})
742743
}
743744

744-
fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult<Vec<u8>> {
745+
fn encode_python_udwf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult<Vec<u8>> {
745746
let signature = WindowUDFImpl::signature(udf);
746747
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?;
747748
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
@@ -765,7 +766,7 @@ fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> Py
765766
.extract::<Vec<u8>>()
766767
}
767768

768-
fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionWindowUDF> {
769+
fn decode_python_udwf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionWindowUDF> {
769770
let tuple = cloudpickle(py)?
770771
.call_method1("loads", (PyBytes::new(py, payload),))?
771772
.cast_into::<PyTuple>()?;
@@ -795,35 +796,35 @@ fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFu
795796
// Shared Python aggregate UDF encode / decode helpers
796797
//
797798
// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
798-
// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator
799+
// return_schema_bytes, state_schema_bytes, volatility_str)`. The accumulator
799800
// factory is the Python callable that produces a new accumulator instance
800801
// per partition.
801802
// =============================================================================
802803

803-
pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<bool> {
804+
pub(crate) fn try_encode_python_udaf(node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<bool> {
804805
let Some(py_udf) = node.inner().downcast_ref::<PythonFunctionAggregateUDF>() else {
805806
return Ok(false);
806807
};
807808

808809
Python::attach(|py| -> Result<bool> {
809-
let bytes = encode_python_agg_udf(py, py_udf).map_err(to_datafusion_err)?;
810+
let bytes = encode_python_udaf(py, py_udf).map_err(to_datafusion_err)?;
810811
append_framed_payload(py, buf, PY_AGG_UDF_FAMILY, &bytes)?;
811812
Ok(true)
812813
})
813814
}
814815

815-
pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result<Option<Arc<AggregateUDF>>> {
816+
pub(crate) fn try_decode_python_udaf(buf: &[u8]) -> Result<Option<Arc<AggregateUDF>>> {
816817
Python::attach(|py| -> Result<Option<Arc<AggregateUDF>>> {
817818
let Some(payload) = read_framed_payload(py, buf, PY_AGG_UDF_FAMILY, "aggregate UDF")?
818819
else {
819820
return Ok(None);
820821
};
821-
let udf = decode_python_agg_udf(py, payload).map_err(to_datafusion_err)?;
822+
let udf = decode_python_udaf(py, payload).map_err(to_datafusion_err)?;
822823
Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf))))
823824
})
824825
}
825826

826-
fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult<Vec<u8>> {
827+
fn encode_python_udaf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult<Vec<u8>> {
827828
let signature = AggregateUDFImpl::signature(udf);
828829
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?;
829830
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
@@ -854,7 +855,7 @@ fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> Py
854855
.extract::<Vec<u8>>()
855856
}
856857

857-
fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionAggregateUDF> {
858+
fn decode_python_udaf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionAggregateUDF> {
858859
let tuple = cloudpickle(py)?
859860
.call_method1("loads", (PyBytes::new(py, payload),))?
860861
.cast_into::<PyTuple>()?;

crates/core/src/lib.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,16 @@ mod array;
5959
#[cfg(feature = "substrait")]
6060
pub mod substrait;
6161
#[allow(clippy::borrow_deref_ref)]
62-
pub mod udaf;
62+
mod udaf;
6363
#[allow(clippy::borrow_deref_ref)]
6464
mod udf;
6565
pub mod udtf;
66-
pub mod udwf;
66+
mod udwf;
67+
68+
// Re-export helpers previously consumed by downstream Rust crates.
69+
// Modules stay private to keep the public Rust API surface small.
70+
pub use udaf::to_rust_accumulator;
71+
pub use udwf::{MultiColumnWindowUDF, PythonFunctionWindowUDF, to_rust_partition_evaluator};
6772

6873
#[cfg(feature = "mimalloc")]
6974
#[global_allocator]

python/tests/test_pickle_expr.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -165,8 +165,9 @@ def update(self, values):
165165
self._count += len(values)
166166

167167
def merge(self, states):
168-
for s in states:
169-
self._count += s[0].as_py()
168+
partition_counts = states[0]
169+
for i in range(len(partition_counts)):
170+
self._count += partition_counts[i].as_py()
170171

171172
def evaluate(self):
172173
return pa.scalar(self._count, type=pa.int64())
@@ -209,7 +210,10 @@ def test_agg_udf_evaluates_after_roundtrip(self):
209210
decoded = pickle.loads(pickle.dumps(e)) # noqa: S301
210211

211212
ctx = SessionContext()
212-
df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]})
213+
schema = pa.schema([pa.field("a", pa.int64())])
214+
batch1 = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], schema=schema)
215+
batch2 = pa.record_batch([pa.array([4, 5], type=pa.int64())], schema=schema)
216+
df = ctx.create_dataframe([[batch1], [batch2]])
213217
out = df.aggregate([], [decoded.alias("n")]).to_pydict()
214218
assert out["n"] == [5]
215219

0 commit comments

Comments
 (0)