Skip to content

Commit d0baeb6

Browse files
timsaucerclaude
andcommitted
feat: inline encoding for Python aggregate and window UDFs
Extends the PythonLogicalCodec / PythonPhysicalCodec inline encoding introduced for scalar UDFs to also cover Python-defined aggregate and window UDFs. The cloudpickle tuple shape per family is: DFPYUDA (agg) (name, accumulator_factory, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str) DFPYUDW (window) (name, evaluator_factory, input_schema_bytes, return_schema_bytes, volatility_str) Same wire-framing as scalar (family magic + version byte + cloudpickle blob), same schema serde (arrow-rs native IPC), same cached cloudpickle handle. The agg state schema is encoded as a full IPC schema so the post-decode UDF reports the same names + nullability + metadata as the sender — relevant for accumulators whose StateFieldsArgs consumers key off names rather than positional DataType. Required restructuring two existing UDF impls so the codec can grab the Python callable directly: * udaf.rs: replaces create_udaf + AccumulatorFactoryFunction closure with a named PythonFunctionAggregateUDF that stores the Py<PyAny> accumulator factory. Synthesizes state_{i} field names when the Python constructor passes only Vec<DataType>; from_parts preserves the full state schema on the decode side. * udwf.rs: renames MultiColumnWindowUDF -> PythonFunctionWindowUDF, drops the PartitionEvaluatorFactory PtrEq wrapper, stores the Py<PyAny> evaluator directly. PartialEq and Hash get the same pointer-identity fast path + debug-log exception handling already on PythonFunctionScalarUDF. User-facing surface: * AggregateUDF.name and WindowUDF.name properties (parallel to the ScalarUDF.name shipped in PR1). * Existing UDAF/UDWF construction paths are unchanged. The per-session with_python_udf_inlining toggle, sender-side context, strict refusal, and user-guide docs land in PRs 3-4 of this series. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 061f3ab commit d0baeb6

7 files changed

Lines changed: 653 additions & 78 deletions

File tree

crates/core/src/codec.rs

Lines changed: 256 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,14 @@
6262
//! | Layer + kind | Family prefix |
6363
//! | ----------------------------- | ------------- |
6464
//! | `PythonLogicalCodec` scalar | `DFPYUDF` |
65+
//! | `PythonLogicalCodec` agg | `DFPYUDA` |
66+
//! | `PythonLogicalCodec` window | `DFPYUDW` |
6567
//! | `PythonPhysicalCodec` scalar | `DFPYUDF` |
68+
//! | `PythonPhysicalCodec` agg | `DFPYUDA` |
69+
//! | `PythonPhysicalCodec` window | `DFPYUDW` |
6670
//! | User FFI extension codec | user-chosen |
6771
//! | Default codec | (none) |
6872
//!
69-
//! Aggregate and window UDF families are reserved for follow-on work.
70-
//!
7173
//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported
7274
//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`.
7375
//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape
@@ -90,8 +92,8 @@ use datafusion::datasource::TableProvider;
9092
use datafusion::datasource::file_format::FileFormatFactory;
9193
use datafusion::execution::TaskContext;
9294
use datafusion::logical_expr::{
93-
AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature,
94-
Volatility, WindowUDF,
95+
AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature,
96+
TypeSignature, Volatility, WindowUDF, WindowUDFImpl,
9597
};
9698
use datafusion::physical_expr::PhysicalExpr;
9799
use datafusion::physical_plan::ExecutionPlan;
@@ -101,7 +103,9 @@ use pyo3::prelude::*;
101103
use pyo3::sync::PyOnceLock;
102104
use pyo3::types::{PyBytes, PyTuple};
103105

106+
use crate::udaf::PythonFunctionAggregateUDF;
104107
use crate::udf::PythonFunctionScalarUDF;
108+
use crate::udwf::PythonFunctionWindowUDF;
105109

106110
// Wire-format framing for inlined Python UDF payloads.
107111
//
@@ -118,6 +122,16 @@ use crate::udf::PythonFunctionScalarUDF;
118122
/// volatility).
119123
pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF";
120124

125+
/// Family prefix for an inlined Python aggregate UDF
126+
/// (cloudpickled tuple of name, accumulator factory, input schema,
127+
/// return type, state schema, volatility).
128+
pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA";
129+
130+
/// Family prefix for an inlined Python window UDF
131+
/// (cloudpickled tuple of name, evaluator factory, input schema,
132+
/// return type, volatility).
133+
pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW";
134+
121135
/// Wire-format version this build emits.
122136
pub(crate) const WIRE_VERSION_CURRENT: u8 = 1;
123137

@@ -260,18 +274,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
260274
}
261275

262276
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
277+
if try_encode_python_agg_udf(node, buf)? {
278+
return Ok(());
279+
}
263280
self.inner.try_encode_udaf(node, buf)
264281
}
265282

266283
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
284+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
285+
return Ok(udaf);
286+
}
267287
self.inner.try_decode_udaf(name, buf)
268288
}
269289

270290
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
291+
if try_encode_python_window_udf(node, buf)? {
292+
return Ok(());
293+
}
271294
self.inner.try_encode_udwf(node, buf)
272295
}
273296

274297
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
298+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
299+
return Ok(udwf);
300+
}
275301
self.inner.try_decode_udwf(name, buf)
276302
}
277303
}
@@ -350,18 +376,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
350376
}
351377

352378
fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<()> {
379+
if try_encode_python_agg_udf(node, buf)? {
380+
return Ok(());
381+
}
353382
self.inner.try_encode_udaf(node, buf)
354383
}
355384

356385
fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result<Arc<AggregateUDF>> {
386+
if let Some(udaf) = try_decode_python_agg_udf(buf)? {
387+
return Ok(udaf);
388+
}
357389
self.inner.try_decode_udaf(name, buf)
358390
}
359391

360392
fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec<u8>) -> Result<()> {
393+
if try_encode_python_window_udf(node, buf)? {
394+
return Ok(());
395+
}
361396
self.inner.try_encode_udwf(node, buf)
362397
}
363398

364399
fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result<Arc<WindowUDF>> {
400+
if let Some(udwf) = try_decode_python_window_udf(buf)? {
401+
return Ok(udwf);
402+
}
365403
self.inner.try_decode_udwf(name, buf)
366404
}
367405
}
@@ -525,6 +563,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> {
525563
schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err)
526564
}
527565

566+
/// Emit a multi-field IPC schema blob.
567+
fn build_schema_bytes(fields: Vec<Field>) -> PyResult<Vec<u8>> {
568+
schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err)
569+
}
570+
528571
/// Decode the per-arg `DataType`s the encoder wrote via
529572
/// [`build_input_schema_bytes`].
530573
fn read_input_dtypes(bytes: &[u8]) -> PyResult<Vec<DataType>> {
@@ -589,6 +632,200 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
589632
.map(|cached| cached.bind(py).clone())
590633
}
591634

635+
// =============================================================================
636+
// Shared Python window UDF encode / decode helpers
637+
//
638+
// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
639+
// return_schema_bytes, volatility_str)`. The evaluator factory is the
640+
// Python callable that produces a new evaluator instance per partition.
641+
// =============================================================================
642+
643+
pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec<u8>) -> Result<bool> {
644+
let Some(py_udf) = node
645+
.inner()
646+
.as_any()
647+
.downcast_ref::<PythonFunctionWindowUDF>()
648+
else {
649+
return Ok(false);
650+
};
651+
652+
Python::attach(|py| -> Result<bool> {
653+
let bytes = encode_python_window_udf(py, py_udf)
654+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
655+
write_wire_header(buf, PY_WINDOW_UDF_FAMILY);
656+
buf.extend_from_slice(&bytes);
657+
Ok(true)
658+
})
659+
}
660+
661+
pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result<Option<Arc<WindowUDF>>> {
662+
let Some(payload) = strip_wire_header(buf, PY_WINDOW_UDF_FAMILY, "window UDF")? else {
663+
return Ok(None);
664+
};
665+
666+
Python::attach(|py| -> Result<Option<Arc<WindowUDF>>> {
667+
let udf = decode_python_window_udf(py, payload)
668+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
669+
Ok(Some(Arc::new(WindowUDF::new_from_impl(udf))))
670+
})
671+
}
672+
673+
fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult<Vec<u8>> {
674+
let signature = WindowUDFImpl::signature(udf);
675+
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?;
676+
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
677+
let return_field = Field::new("result", udf.return_type().clone(), true);
678+
let return_schema_bytes = build_single_field_schema_bytes(&return_field)?;
679+
let volatility = volatility_wire_str(signature.volatility);
680+
681+
let payload = PyTuple::new(
682+
py,
683+
[
684+
WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(),
685+
udf.evaluator().bind(py).clone().into_any(),
686+
PyBytes::new(py, &input_schema_bytes).into_any(),
687+
PyBytes::new(py, &return_schema_bytes).into_any(),
688+
volatility.into_pyobject(py)?.into_any(),
689+
],
690+
)?;
691+
692+
cloudpickle(py)?
693+
.call_method1("dumps", (payload,))?
694+
.extract::<Vec<u8>>()
695+
}
696+
697+
fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionWindowUDF> {
698+
let tuple = cloudpickle(py)?
699+
.call_method1("loads", (PyBytes::new(py, payload),))?
700+
.cast_into::<PyTuple>()?;
701+
702+
let name: String = tuple.get_item(0)?.extract()?;
703+
let evaluator: Py<PyAny> = tuple.get_item(1)?.unbind();
704+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
705+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
706+
let volatility_str: String = tuple.get_item(4)?.extract()?;
707+
708+
let input_types = read_input_dtypes(&input_schema_bytes)?;
709+
let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")?
710+
.data_type()
711+
.clone();
712+
let volatility = parse_volatility_str(&volatility_str)?;
713+
714+
Ok(PythonFunctionWindowUDF::new(
715+
name,
716+
evaluator,
717+
input_types,
718+
return_type,
719+
volatility,
720+
))
721+
}
722+
723+
// =============================================================================
724+
// Shared Python aggregate UDF encode / decode helpers
725+
//
726+
// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
727+
// return_schema_bytes, state_schema_bytes, volatility_str)`. The accumulator
728+
// factory is the Python callable that produces a new accumulator instance
729+
// per partition.
730+
// =============================================================================
731+
732+
pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec<u8>) -> Result<bool> {
733+
let Some(py_udf) = node
734+
.inner()
735+
.as_any()
736+
.downcast_ref::<PythonFunctionAggregateUDF>()
737+
else {
738+
return Ok(false);
739+
};
740+
741+
Python::attach(|py| -> Result<bool> {
742+
let bytes = encode_python_agg_udf(py, py_udf)
743+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
744+
write_wire_header(buf, PY_AGG_UDF_FAMILY);
745+
buf.extend_from_slice(&bytes);
746+
Ok(true)
747+
})
748+
}
749+
750+
pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result<Option<Arc<AggregateUDF>>> {
751+
let Some(payload) = strip_wire_header(buf, PY_AGG_UDF_FAMILY, "aggregate UDF")? else {
752+
return Ok(None);
753+
};
754+
755+
Python::attach(|py| -> Result<Option<Arc<AggregateUDF>>> {
756+
let udf = decode_python_agg_udf(py, payload)
757+
.map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?;
758+
Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf))))
759+
})
760+
}
761+
762+
fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult<Vec<u8>> {
763+
let signature = AggregateUDFImpl::signature(udf);
764+
let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?;
765+
let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?;
766+
let return_field = Field::new("result", udf.return_type().clone(), true);
767+
let return_schema_bytes = build_single_field_schema_bytes(&return_field)?;
768+
let state_fields: Vec<Field> = udf
769+
.state_fields_ref()
770+
.iter()
771+
.map(|f| f.as_ref().clone())
772+
.collect();
773+
let state_schema_bytes = build_schema_bytes(state_fields)?;
774+
let volatility = volatility_wire_str(signature.volatility);
775+
776+
let payload = PyTuple::new(
777+
py,
778+
[
779+
AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(),
780+
udf.accumulator().bind(py).clone().into_any(),
781+
PyBytes::new(py, &input_schema_bytes).into_any(),
782+
PyBytes::new(py, &return_schema_bytes).into_any(),
783+
PyBytes::new(py, &state_schema_bytes).into_any(),
784+
volatility.into_pyobject(py)?.into_any(),
785+
],
786+
)?;
787+
788+
cloudpickle(py)?
789+
.call_method1("dumps", (payload,))?
790+
.extract::<Vec<u8>>()
791+
}
792+
793+
fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult<PythonFunctionAggregateUDF> {
794+
let tuple = cloudpickle(py)?
795+
.call_method1("loads", (PyBytes::new(py, payload),))?
796+
.cast_into::<PyTuple>()?;
797+
798+
let name: String = tuple.get_item(0)?.extract()?;
799+
let accumulator: Py<PyAny> = tuple.get_item(1)?.unbind();
800+
let input_schema_bytes: Vec<u8> = tuple.get_item(2)?.extract()?;
801+
let return_schema_bytes: Vec<u8> = tuple.get_item(3)?.extract()?;
802+
let state_schema_bytes: Vec<u8> = tuple.get_item(4)?.extract()?;
803+
let volatility_str: String = tuple.get_item(5)?.extract()?;
804+
805+
let input_types = read_input_dtypes(&input_schema_bytes)?;
806+
let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionAggregateUDF")?
807+
.data_type()
808+
.clone();
809+
// Preserve the encoded state field metadata (names, nullability,
810+
// arbitrary key/value attributes) so the post-decode UDF reports
811+
// the same state schema as the sender's instance — important for
812+
// accumulators whose `StateFieldsArgs` consumers key off names or
813+
// nullability rather than positional `DataType`.
814+
let state_schema = schema_from_ipc_bytes(&state_schema_bytes).map_err(arrow_to_py_err)?;
815+
let state_fields: Vec<arrow::datatypes::FieldRef> =
816+
state_schema.fields().iter().cloned().collect();
817+
let volatility = parse_volatility_str(&volatility_str)?;
818+
819+
Ok(PythonFunctionAggregateUDF::from_parts(
820+
name,
821+
accumulator,
822+
input_types,
823+
return_type,
824+
state_fields,
825+
volatility,
826+
))
827+
}
828+
592829
#[cfg(test)]
593830
mod wire_header_tests {
594831
use super::*;
@@ -635,12 +872,23 @@ mod wire_header_tests {
635872
/// Writing a header for a family and stripping it with the same family
/// must hand back exactly the bytes appended after the header. Exercised
/// for every Python UDF family so no family silently loses round-trip
/// coverage.
#[test]
fn write_then_strip_round_trips_payload() {
    let families = [
        (PY_SCALAR_UDF_FAMILY, "scalar UDF"),
        (PY_AGG_UDF_FAMILY, "aggregate UDF"),
        (PY_WINDOW_UDF_FAMILY, "window UDF"),
    ];

    for (family, kind) in families {
        let mut buf = Vec::new();
        write_wire_header(&mut buf, family);
        buf.extend_from_slice(b"payload");

        let payload = strip_wire_header(&buf, family, kind).unwrap().unwrap();
        assert_eq!(payload, b"payload", "round trip failed for {kind}");
    }
}
883+
884+
/// A buffer framed for the scalar family must be reported as "not mine"
/// (`Ok(None)`) when stripped with the window family, rather than as an
/// error or a payload.
#[test]
fn strip_does_not_match_a_different_family() {
    let mut framed = Vec::new();
    write_wire_header(&mut framed, PY_SCALAR_UDF_FAMILY);
    framed.extend_from_slice(b"payload");

    let stripped = strip_wire_header(&framed, PY_WINDOW_UDF_FAMILY, "window UDF");
    assert!(matches!(stripped, Ok(None)));
}
646894
}

0 commit comments

Comments
 (0)