Skip to content

Commit b91fa87

Browse files
committed
Add __replace__ and fix __reduce__ for structseq
- Add __replace__ method to PyStructSequence trait - Move __reduce__ from #[pymethod] to extend_pyclass with contains_key guard, allowing per-type overrides - Fix repr: remove trailing comma for single-field sequences
1 parent 5bf13e8 commit b91fa87

File tree

1 file changed

+73
-16
lines changed

1 file changed

+73
-16
lines changed

crates/vm/src/types/structseq.rs

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::{
22
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
33
builtins::{PyBaseExceptionRef, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef},
44
class::{PyClassImpl, StaticType},
5-
function::{Either, PyComparisonValue},
5+
function::{Either, FuncArgs, PyComparisonValue, PyMethodDef, PyMethodFlags},
66
iter::PyExactSizeIterator,
77
protocol::{PyMappingMethods, PySequenceMethods},
88
sliceable::{SequenceIndex, SliceableSequenceOp},
@@ -11,6 +11,15 @@ use crate::{
1111
};
1212
use std::sync::LazyLock;
1313

14+
const DEFAULT_STRUCTSEQ_REDUCE: PyMethodDef = PyMethodDef::new_const(
15+
"__reduce__",
16+
|zelf: PyRef<PyTuple>, vm: &VirtualMachine| -> PyTupleRef {
17+
vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
18+
},
19+
PyMethodFlags::METHOD,
20+
None,
21+
);
22+
1423
/// Create a new struct sequence instance from a sequence.
1524
///
1625
/// The class must have `n_sequence_fields` and `n_fields` attributes set
@@ -206,19 +215,13 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
206215
};
207216
let (body, suffix) =
208217
if let Some(_guard) = rustpython_vm::recursion::ReprGuard::enter(vm, zelf.as_ref()) {
209-
if field_names.len() == 1 {
210-
let value = zelf.first().unwrap();
211-
let formatted = format_field((value, field_names[0]))?;
212-
(formatted, ",")
213-
} else {
214-
let fields: PyResult<Vec<_>> = zelf
215-
.iter()
216-
.map(|value| value.as_ref())
217-
.zip(field_names.iter().copied())
218-
.map(format_field)
219-
.collect();
220-
(fields?.join(", "), "")
221-
}
218+
let fields: PyResult<Vec<_>> = zelf
219+
.iter()
220+
.map(|value| value.as_ref())
221+
.zip(field_names.iter().copied())
222+
.map(format_field)
223+
.collect();
224+
(fields?.join(", "), "")
222225
} else {
223226
(String::new(), "...")
224227
};
@@ -232,8 +235,48 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
232235
}
233236

234237
#[pymethod]
235-
fn __reduce__(zelf: PyRef<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
236-
vm.new_tuple((zelf.class().to_owned(), (vm.ctx.new_tuple(zelf.to_vec()),)))
238+
fn __replace__(zelf: PyRef<PyTuple>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
239+
if !args.args.is_empty() {
240+
return Err(vm.new_type_error("__replace__() takes no positional arguments".to_owned()));
241+
}
242+
243+
if Self::Data::UNNAMED_FIELDS_LEN > 0 {
244+
return Err(vm.new_type_error(format!(
245+
"__replace__() is not supported for {} because it has unnamed field(s)",
246+
zelf.class().slot_name()
247+
)));
248+
}
249+
250+
let n_fields = Self::Data::REQUIRED_FIELD_NAMES.len()
251+
+ Self::Data::OPTIONAL_FIELD_NAMES.len();
252+
let mut items: Vec<PyObjectRef> = zelf.as_slice()[..n_fields].to_vec();
253+
254+
let mut kwargs = args.kwargs.clone();
255+
256+
// Replace fields from kwargs
257+
let all_field_names: Vec<&str> = Self::Data::REQUIRED_FIELD_NAMES
258+
.iter()
259+
.chain(Self::Data::OPTIONAL_FIELD_NAMES.iter())
260+
.copied()
261+
.collect();
262+
for (i, &name) in all_field_names.iter().enumerate() {
263+
if let Some(val) = kwargs.shift_remove(name) {
264+
items[i] = val;
265+
}
266+
}
267+
268+
// Check for unexpected keyword arguments
269+
if !kwargs.is_empty() {
270+
let names: Vec<&str> = kwargs.keys().map(|k| k.as_str()).collect();
271+
return Err(vm.new_type_error(format!(
272+
"Got unexpected field name(s): {:?}",
273+
names
274+
)));
275+
}
276+
277+
PyTuple::new_unchecked(items.into_boxed_slice())
278+
.into_ref_with_type(vm, zelf.class().to_owned())
279+
.map(Into::into)
237280
}
238281

239282
#[pymethod]
@@ -327,6 +370,20 @@ pub trait PyStructSequence: StaticType + PyClassImpl + Sized + 'static {
327370
.slots
328371
.richcompare
329372
.store(Some(struct_sequence_richcompare));
373+
374+
// Default __reduce__: only set if not already overridden by the impl's extend_class.
375+
// This allows struct sequences like sched_param to provide a custom __reduce__
376+
// (equivalent to METH_COEXIST in structseq.c).
377+
if !class
378+
.attributes
379+
.read()
380+
.contains_key(ctx.intern_str("__reduce__"))
381+
{
382+
class.set_attr(
383+
ctx.intern_str("__reduce__"),
384+
DEFAULT_STRUCTSEQ_REDUCE.to_proper_method(class, ctx),
385+
);
386+
}
330387
}
331388
}
332389

0 commit comments

Comments
 (0)