@@ -2,7 +2,7 @@ use crate::{
22 AsObject , Py , PyObject , PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine , atomic_func,
33 builtins:: { PyBaseExceptionRef , PyStrRef , PyTuple , PyTupleRef , PyType , PyTypeRef } ,
44 class:: { PyClassImpl , StaticType } ,
5- function:: { Either , PyComparisonValue } ,
5+ function:: { Either , FuncArgs , PyComparisonValue , PyMethodDef , PyMethodFlags } ,
66 iter:: PyExactSizeIterator ,
77 protocol:: { PyMappingMethods , PySequenceMethods } ,
88 sliceable:: { SequenceIndex , SliceableSequenceOp } ,
@@ -11,6 +11,15 @@ use crate::{
1111} ;
1212use std:: sync:: LazyLock ;
1313
14+ const DEFAULT_STRUCTSEQ_REDUCE : PyMethodDef = PyMethodDef :: new_const (
15+ "__reduce__" ,
16+ |zelf : PyRef < PyTuple > , vm : & VirtualMachine | -> PyTupleRef {
17+ vm. new_tuple ( ( zelf. class ( ) . to_owned ( ) , ( vm. ctx . new_tuple ( zelf. to_vec ( ) ) , ) ) )
18+ } ,
19+ PyMethodFlags :: METHOD ,
20+ None ,
21+ ) ;
22+
1423/// Create a new struct sequence instance from a sequence.
1524///
1625/// The class must have `n_sequence_fields` and `n_fields` attributes set
@@ -206,19 +215,13 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
206215 } ;
207216 let ( body, suffix) =
208217 if let Some ( _guard) = rustpython_vm:: recursion:: ReprGuard :: enter ( vm, zelf. as_ref ( ) ) {
209- if field_names. len ( ) == 1 {
210- let value = zelf. first ( ) . unwrap ( ) ;
211- let formatted = format_field ( ( value, field_names[ 0 ] ) ) ?;
212- ( formatted, "," )
213- } else {
214- let fields: PyResult < Vec < _ > > = zelf
215- . iter ( )
216- . map ( |value| value. as_ref ( ) )
217- . zip ( field_names. iter ( ) . copied ( ) )
218- . map ( format_field)
219- . collect ( ) ;
220- ( fields?. join ( ", " ) , "" )
221- }
218+ let fields: PyResult < Vec < _ > > = zelf
219+ . iter ( )
220+ . map ( |value| value. as_ref ( ) )
221+ . zip ( field_names. iter ( ) . copied ( ) )
222+ . map ( format_field)
223+ . collect ( ) ;
224+ ( fields?. join ( ", " ) , "" )
222225 } else {
223226 ( String :: new ( ) , "..." )
224227 } ;
@@ -232,8 +235,48 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
232235 }
233236
234237 #[ pymethod]
235- fn __reduce__ ( zelf : PyRef < PyTuple > , vm : & VirtualMachine ) -> PyTupleRef {
236- vm. new_tuple ( ( zelf. class ( ) . to_owned ( ) , ( vm. ctx . new_tuple ( zelf. to_vec ( ) ) , ) ) )
238+ fn __replace__ ( zelf : PyRef < PyTuple > , args : FuncArgs , vm : & VirtualMachine ) -> PyResult {
239+ if !args. args . is_empty ( ) {
240+ return Err ( vm. new_type_error ( "__replace__() takes no positional arguments" . to_owned ( ) ) ) ;
241+ }
242+
243+ if Self :: Data :: UNNAMED_FIELDS_LEN > 0 {
244+ return Err ( vm. new_type_error ( format ! (
245+ "__replace__() is not supported for {} because it has unnamed field(s)" ,
246+ zelf. class( ) . slot_name( )
247+ ) ) ) ;
248+ }
249+
250+ let n_fields = Self :: Data :: REQUIRED_FIELD_NAMES . len ( )
251+ + Self :: Data :: OPTIONAL_FIELD_NAMES . len ( ) ;
252+ let mut items: Vec < PyObjectRef > = zelf. as_slice ( ) [ ..n_fields] . to_vec ( ) ;
253+
254+ let mut kwargs = args. kwargs . clone ( ) ;
255+
256+ // Replace fields from kwargs
257+ let all_field_names: Vec < & str > = Self :: Data :: REQUIRED_FIELD_NAMES
258+ . iter ( )
259+ . chain ( Self :: Data :: OPTIONAL_FIELD_NAMES . iter ( ) )
260+ . copied ( )
261+ . collect ( ) ;
262+ for ( i, & name) in all_field_names. iter ( ) . enumerate ( ) {
263+ if let Some ( val) = kwargs. shift_remove ( name) {
264+ items[ i] = val;
265+ }
266+ }
267+
268+ // Check for unexpected keyword arguments
269+ if !kwargs. is_empty ( ) {
270+ let names: Vec < & str > = kwargs. keys ( ) . map ( |k| k. as_str ( ) ) . collect ( ) ;
271+ return Err ( vm. new_type_error ( format ! (
272+ "Got unexpected field name(s): {:?}" ,
273+ names
274+ ) ) ) ;
275+ }
276+
277+ PyTuple :: new_unchecked ( items. into_boxed_slice ( ) )
278+ . into_ref_with_type ( vm, zelf. class ( ) . to_owned ( ) )
279+ . map ( Into :: into)
237280 }
238281
239282 #[ pymethod]
@@ -327,6 +370,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
327370 . slots
328371 . richcompare
329372 . store ( Some ( struct_sequence_richcompare) ) ;
373+
374+ // Default __reduce__: only set if not already overridden by the impl's extend_class.
375+ // This allows struct sequences like sched_param to provide a custom __reduce__
376+ // (equivalent to METH_COEXIST in structseq.c).
377+ if !class
378+ . attributes
379+ . read ( )
380+ . contains_key ( ctx. intern_str ( "__reduce__" ) )
381+ {
382+ class. set_attr (
383+ ctx. intern_str ( "__reduce__" ) ,
384+ DEFAULT_STRUCTSEQ_REDUCE . to_proper_method ( class, ctx) ,
385+ ) ;
386+ }
330387 }
331388}
332389
0 commit comments