6262//! | Layer + kind | Family prefix |
6363//! | ----------------------------- | ------------- |
6464//! | `PythonLogicalCodec` scalar | `DFPYUDF` |
65+ //! | `PythonLogicalCodec` agg | `DFPYUDA` |
66+ //! | `PythonLogicalCodec` window | `DFPYUDW` |
6567//! | `PythonPhysicalCodec` scalar | `DFPYUDF` |
68+ //! | `PythonPhysicalCodec` agg | `DFPYUDA` |
69+ //! | `PythonPhysicalCodec` window | `DFPYUDW` |
6670//! | User FFI extension codec | user-chosen |
6771//! | Default codec | (none) |
6872//!
69- //! Aggregate and window UDF families are reserved for follow-on work.
70- //!
7173//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported
7274//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`.
7375//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape
@@ -90,8 +92,8 @@ use datafusion::datasource::TableProvider;
9092use datafusion:: datasource:: file_format:: FileFormatFactory ;
9193use datafusion:: execution:: TaskContext ;
9294use datafusion:: logical_expr:: {
93- AggregateUDF , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl , Signature , TypeSignature ,
94- Volatility , WindowUDF ,
95+ AggregateUDF , AggregateUDFImpl , Extension , LogicalPlan , ScalarUDF , ScalarUDFImpl , Signature ,
96+ TypeSignature , Volatility , WindowUDF , WindowUDFImpl ,
9597} ;
9698use datafusion:: physical_expr:: PhysicalExpr ;
9799use datafusion:: physical_plan:: ExecutionPlan ;
@@ -101,7 +103,9 @@ use pyo3::prelude::*;
101103use pyo3:: sync:: PyOnceLock ;
102104use pyo3:: types:: { PyBytes , PyTuple } ;
103105
106+ use crate :: udaf:: PythonFunctionAggregateUDF ;
104107use crate :: udf:: PythonFunctionScalarUDF ;
108+ use crate :: udwf:: PythonFunctionWindowUDF ;
105109
106110// Wire-format framing for inlined Python UDF payloads.
107111//
@@ -118,6 +122,16 @@ use crate::udf::PythonFunctionScalarUDF;
118122/// volatility).
119123pub ( crate ) const PY_SCALAR_UDF_FAMILY : & [ u8 ] = b"DFPYUDF" ;
120124
125+ /// Family prefix for an inlined Python aggregate UDF
126+ /// (cloudpickled tuple of name, accumulator factory, input schema,
127+ /// return type, state types schema, volatility).
128+ pub ( crate ) const PY_AGG_UDF_FAMILY : & [ u8 ] = b"DFPYUDA" ;
129+
130+ /// Family prefix for an inlined Python window UDF
131+ /// (cloudpickled tuple of name, evaluator factory, input schema,
132+ /// return type, volatility).
133+ pub ( crate ) const PY_WINDOW_UDF_FAMILY : & [ u8 ] = b"DFPYUDW" ;
134+
121135/// Wire-format version this build emits.
122136pub ( crate ) const WIRE_VERSION_CURRENT : u8 = 1 ;
123137
@@ -260,18 +274,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec {
260274 }
261275
262276 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
277+ if try_encode_python_agg_udf ( node, buf) ? {
278+ return Ok ( ( ) ) ;
279+ }
263280 self . inner . try_encode_udaf ( node, buf)
264281 }
265282
266283 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
284+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
285+ return Ok ( udaf) ;
286+ }
267287 self . inner . try_decode_udaf ( name, buf)
268288 }
269289
270290 fn try_encode_udwf ( & self , node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
291+ if try_encode_python_window_udf ( node, buf) ? {
292+ return Ok ( ( ) ) ;
293+ }
271294 self . inner . try_encode_udwf ( node, buf)
272295 }
273296
274297 fn try_decode_udwf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < WindowUDF > > {
298+ if let Some ( udwf) = try_decode_python_window_udf ( buf) ? {
299+ return Ok ( udwf) ;
300+ }
275301 self . inner . try_decode_udwf ( name, buf)
276302 }
277303}
@@ -350,18 +376,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec {
350376 }
351377
352378 fn try_encode_udaf ( & self , node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
379+ if try_encode_python_agg_udf ( node, buf) ? {
380+ return Ok ( ( ) ) ;
381+ }
353382 self . inner . try_encode_udaf ( node, buf)
354383 }
355384
356385 fn try_decode_udaf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < AggregateUDF > > {
386+ if let Some ( udaf) = try_decode_python_agg_udf ( buf) ? {
387+ return Ok ( udaf) ;
388+ }
357389 self . inner . try_decode_udaf ( name, buf)
358390 }
359391
360392 fn try_encode_udwf ( & self , node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < ( ) > {
393+ if try_encode_python_window_udf ( node, buf) ? {
394+ return Ok ( ( ) ) ;
395+ }
361396 self . inner . try_encode_udwf ( node, buf)
362397 }
363398
364399 fn try_decode_udwf ( & self , name : & str , buf : & [ u8 ] ) -> Result < Arc < WindowUDF > > {
400+ if let Some ( udwf) = try_decode_python_window_udf ( buf) ? {
401+ return Ok ( udwf) ;
402+ }
365403 self . inner . try_decode_udwf ( name, buf)
366404 }
367405}
@@ -525,6 +563,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult<Vec<u8>> {
525563 schema_to_ipc_bytes ( & Schema :: new ( vec ! [ field. clone( ) ] ) ) . map_err ( arrow_to_py_err)
526564}
527565
566+ /// Emit a multi-field IPC schema blob.
567+ fn build_schema_bytes ( fields : Vec < Field > ) -> PyResult < Vec < u8 > > {
568+ schema_to_ipc_bytes ( & Schema :: new ( fields) ) . map_err ( arrow_to_py_err)
569+ }
570+
528571/// Decode the per-arg `DataType`s the encoder wrote via
529572/// [`build_input_schema_bytes`].
530573fn read_input_dtypes ( bytes : & [ u8 ] ) -> PyResult < Vec < DataType > > {
@@ -589,6 +632,200 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
589632 . map ( |cached| cached. bind ( py) . clone ( ) )
590633}
591634
635+ // =============================================================================
636+ // Shared Python window UDF encode / decode helpers
637+ //
638+ // Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes,
639+ // return_schema_bytes, volatility_str)`. The evaluator factory is the
640+ // Python callable that produces a new evaluator instance per partition.
641+ // =============================================================================
642+
643+ pub ( crate ) fn try_encode_python_window_udf ( node : & WindowUDF , buf : & mut Vec < u8 > ) -> Result < bool > {
644+ let Some ( py_udf) = node
645+ . inner ( )
646+ . as_any ( )
647+ . downcast_ref :: < PythonFunctionWindowUDF > ( )
648+ else {
649+ return Ok ( false ) ;
650+ } ;
651+
652+ Python :: attach ( |py| -> Result < bool > {
653+ let bytes = encode_python_window_udf ( py, py_udf)
654+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
655+ write_wire_header ( buf, PY_WINDOW_UDF_FAMILY ) ;
656+ buf. extend_from_slice ( & bytes) ;
657+ Ok ( true )
658+ } )
659+ }
660+
661+ pub ( crate ) fn try_decode_python_window_udf ( buf : & [ u8 ] ) -> Result < Option < Arc < WindowUDF > > > {
662+ let Some ( payload) = strip_wire_header ( buf, PY_WINDOW_UDF_FAMILY , "window UDF" ) ? else {
663+ return Ok ( None ) ;
664+ } ;
665+
666+ Python :: attach ( |py| -> Result < Option < Arc < WindowUDF > > > {
667+ let udf = decode_python_window_udf ( py, payload)
668+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
669+ Ok ( Some ( Arc :: new ( WindowUDF :: new_from_impl ( udf) ) ) )
670+ } )
671+ }
672+
673+ fn encode_python_window_udf ( py : Python < ' _ > , udf : & PythonFunctionWindowUDF ) -> PyResult < Vec < u8 > > {
674+ let signature = WindowUDFImpl :: signature ( udf) ;
675+ let input_dtypes = signature_input_dtypes ( signature, "PythonFunctionWindowUDF" ) ?;
676+ let input_schema_bytes = build_input_schema_bytes ( & input_dtypes) ?;
677+ let return_field = Field :: new ( "result" , udf. return_type ( ) . clone ( ) , true ) ;
678+ let return_schema_bytes = build_single_field_schema_bytes ( & return_field) ?;
679+ let volatility = volatility_wire_str ( signature. volatility ) ;
680+
681+ let payload = PyTuple :: new (
682+ py,
683+ [
684+ WindowUDFImpl :: name ( udf) . into_pyobject ( py) ?. into_any ( ) ,
685+ udf. evaluator ( ) . bind ( py) . clone ( ) . into_any ( ) ,
686+ PyBytes :: new ( py, & input_schema_bytes) . into_any ( ) ,
687+ PyBytes :: new ( py, & return_schema_bytes) . into_any ( ) ,
688+ volatility. into_pyobject ( py) ?. into_any ( ) ,
689+ ] ,
690+ ) ?;
691+
692+ cloudpickle ( py) ?
693+ . call_method1 ( "dumps" , ( payload, ) ) ?
694+ . extract :: < Vec < u8 > > ( )
695+ }
696+
697+ fn decode_python_window_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionWindowUDF > {
698+ let tuple = cloudpickle ( py) ?
699+ . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
700+ . cast_into :: < PyTuple > ( ) ?;
701+
702+ let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
703+ let evaluator: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
704+ let input_schema_bytes: Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
705+ let return_schema_bytes: Vec < u8 > = tuple. get_item ( 3 ) ?. extract ( ) ?;
706+ let volatility_str: String = tuple. get_item ( 4 ) ?. extract ( ) ?;
707+
708+ let input_types = read_input_dtypes ( & input_schema_bytes) ?;
709+ let return_type = read_single_return_field ( & return_schema_bytes, "PythonFunctionWindowUDF" ) ?
710+ . data_type ( )
711+ . clone ( ) ;
712+ let volatility = parse_volatility_str ( & volatility_str) ?;
713+
714+ Ok ( PythonFunctionWindowUDF :: new (
715+ name,
716+ evaluator,
717+ input_types,
718+ return_type,
719+ volatility,
720+ ) )
721+ }
722+
723+ // =============================================================================
724+ // Shared Python aggregate UDF encode / decode helpers
725+ //
726+ // Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes,
727+ // return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator
728+ // factory is the Python callable that produces a new accumulator instance
729+ // per partition.
730+ // =============================================================================
731+
732+ pub ( crate ) fn try_encode_python_agg_udf ( node : & AggregateUDF , buf : & mut Vec < u8 > ) -> Result < bool > {
733+ let Some ( py_udf) = node
734+ . inner ( )
735+ . as_any ( )
736+ . downcast_ref :: < PythonFunctionAggregateUDF > ( )
737+ else {
738+ return Ok ( false ) ;
739+ } ;
740+
741+ Python :: attach ( |py| -> Result < bool > {
742+ let bytes = encode_python_agg_udf ( py, py_udf)
743+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
744+ write_wire_header ( buf, PY_AGG_UDF_FAMILY ) ;
745+ buf. extend_from_slice ( & bytes) ;
746+ Ok ( true )
747+ } )
748+ }
749+
750+ pub ( crate ) fn try_decode_python_agg_udf ( buf : & [ u8 ] ) -> Result < Option < Arc < AggregateUDF > > > {
751+ let Some ( payload) = strip_wire_header ( buf, PY_AGG_UDF_FAMILY , "aggregate UDF" ) ? else {
752+ return Ok ( None ) ;
753+ } ;
754+
755+ Python :: attach ( |py| -> Result < Option < Arc < AggregateUDF > > > {
756+ let udf = decode_python_agg_udf ( py, payload)
757+ . map_err ( |e| datafusion:: error:: DataFusionError :: External ( Box :: new ( e) ) ) ?;
758+ Ok ( Some ( Arc :: new ( AggregateUDF :: new_from_impl ( udf) ) ) )
759+ } )
760+ }
761+
762+ fn encode_python_agg_udf ( py : Python < ' _ > , udf : & PythonFunctionAggregateUDF ) -> PyResult < Vec < u8 > > {
763+ let signature = AggregateUDFImpl :: signature ( udf) ;
764+ let input_dtypes = signature_input_dtypes ( signature, "PythonFunctionAggregateUDF" ) ?;
765+ let input_schema_bytes = build_input_schema_bytes ( & input_dtypes) ?;
766+ let return_field = Field :: new ( "result" , udf. return_type ( ) . clone ( ) , true ) ;
767+ let return_schema_bytes = build_single_field_schema_bytes ( & return_field) ?;
768+ let state_fields: Vec < Field > = udf
769+ . state_fields_ref ( )
770+ . iter ( )
771+ . map ( |f| f. as_ref ( ) . clone ( ) )
772+ . collect ( ) ;
773+ let state_schema_bytes = build_schema_bytes ( state_fields) ?;
774+ let volatility = volatility_wire_str ( signature. volatility ) ;
775+
776+ let payload = PyTuple :: new (
777+ py,
778+ [
779+ AggregateUDFImpl :: name ( udf) . into_pyobject ( py) ?. into_any ( ) ,
780+ udf. accumulator ( ) . bind ( py) . clone ( ) . into_any ( ) ,
781+ PyBytes :: new ( py, & input_schema_bytes) . into_any ( ) ,
782+ PyBytes :: new ( py, & return_schema_bytes) . into_any ( ) ,
783+ PyBytes :: new ( py, & state_schema_bytes) . into_any ( ) ,
784+ volatility. into_pyobject ( py) ?. into_any ( ) ,
785+ ] ,
786+ ) ?;
787+
788+ cloudpickle ( py) ?
789+ . call_method1 ( "dumps" , ( payload, ) ) ?
790+ . extract :: < Vec < u8 > > ( )
791+ }
792+
793+ fn decode_python_agg_udf ( py : Python < ' _ > , payload : & [ u8 ] ) -> PyResult < PythonFunctionAggregateUDF > {
794+ let tuple = cloudpickle ( py) ?
795+ . call_method1 ( "loads" , ( PyBytes :: new ( py, payload) , ) ) ?
796+ . cast_into :: < PyTuple > ( ) ?;
797+
798+ let name: String = tuple. get_item ( 0 ) ?. extract ( ) ?;
799+ let accumulator: Py < PyAny > = tuple. get_item ( 1 ) ?. unbind ( ) ;
800+ let input_schema_bytes: Vec < u8 > = tuple. get_item ( 2 ) ?. extract ( ) ?;
801+ let return_schema_bytes: Vec < u8 > = tuple. get_item ( 3 ) ?. extract ( ) ?;
802+ let state_schema_bytes: Vec < u8 > = tuple. get_item ( 4 ) ?. extract ( ) ?;
803+ let volatility_str: String = tuple. get_item ( 5 ) ?. extract ( ) ?;
804+
805+ let input_types = read_input_dtypes ( & input_schema_bytes) ?;
806+ let return_type = read_single_return_field ( & return_schema_bytes, "PythonFunctionAggregateUDF" ) ?
807+ . data_type ( )
808+ . clone ( ) ;
809+ // Preserve the encoded state field metadata (names, nullability,
810+ // arbitrary key/value attributes) so the post-decode UDF reports
811+ // the same state schema as the sender's instance — important for
812+ // accumulators whose `StateFieldsArgs` consumers key off names or
813+ // nullability rather than positional `DataType`.
814+ let state_schema = schema_from_ipc_bytes ( & state_schema_bytes) . map_err ( arrow_to_py_err) ?;
815+ let state_fields: Vec < arrow:: datatypes:: FieldRef > =
816+ state_schema. fields ( ) . iter ( ) . cloned ( ) . collect ( ) ;
817+ let volatility = parse_volatility_str ( & volatility_str) ?;
818+
819+ Ok ( PythonFunctionAggregateUDF :: from_parts (
820+ name,
821+ accumulator,
822+ input_types,
823+ return_type,
824+ state_fields,
825+ volatility,
826+ ) )
827+ }
828+
592829#[ cfg( test) ]
593830mod wire_header_tests {
594831 use super :: * ;
@@ -635,12 +872,23 @@ mod wire_header_tests {
635872 #[ test]
636873 fn write_then_strip_round_trips_payload ( ) {
637874 let mut buf = Vec :: new ( ) ;
638- write_wire_header ( & mut buf, PY_SCALAR_UDF_FAMILY ) ;
639- buf. extend_from_slice ( b"scalar -payload" ) ;
875+ write_wire_header ( & mut buf, PY_AGG_UDF_FAMILY ) ;
876+ buf. extend_from_slice ( b"agg -payload" ) ;
640877
641- let payload = strip_wire_header ( & buf, PY_SCALAR_UDF_FAMILY , "scalar UDF" )
878+ let payload = strip_wire_header ( & buf, PY_AGG_UDF_FAMILY , "aggregate UDF" )
642879 . unwrap ( )
643880 . unwrap ( ) ;
644- assert_eq ! ( payload, b"scalar-payload" ) ;
881+ assert_eq ! ( payload, b"agg-payload" ) ;
882+ }
883+
884+ #[ test]
885+ fn strip_does_not_match_a_different_family ( ) {
886+ let mut buf = Vec :: new ( ) ;
887+ write_wire_header ( & mut buf, PY_SCALAR_UDF_FAMILY ) ;
888+ buf. extend_from_slice ( b"payload" ) ;
889+ assert ! ( matches!(
890+ strip_wire_header( & buf, PY_WINDOW_UDF_FAMILY , "window UDF" ) ,
891+ Ok ( None )
892+ ) ) ;
645893 }
646894}
0 commit comments