Skip to content

Commit 6bf1ab6

Browse files
authored
Initiailzer for PyCStructure (RustPython#6586)
1 parent ca1c4c1 commit 6bf1ab6

File tree

2 files changed

+94
-54
lines changed

2 files changed

+94
-54
lines changed

crates/vm/src/stdlib/ctypes/structure.rs

Lines changed: 89 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ impl PyCStructType {
368368

369369
// Store StgInfo with aligned size and total alignment
370370
let mut stg_info = StgInfo::new(aligned_size, total_align);
371+
stg_info.length = fields.len();
371372
stg_info.format = Some(format);
372373
stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL; // Mark as finalized
373374
if has_pointer {
@@ -511,7 +512,7 @@ impl Debug for PyCStructure {
511512
impl Constructor for PyCStructure {
512513
type Args = FuncArgs;
513514

514-
fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
515+
fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
515516
// Check for abstract class and extract values in a block to drop the borrow
516517
let (total_size, total_align, length) = {
517518
let stg_info = cls.stg_info(vm)?;
@@ -523,79 +524,116 @@ impl Constructor for PyCStructure {
523524
stg_info_mut.flags |= StgInfoFlags::DICTFLAG_FINAL;
524525
}
525526

526-
// Get _fields_ from the class using get_attr to properly search MRO
527-
let fields_attr = cls.as_object().get_attr("_fields_", vm).ok();
527+
// Initialize buffer with zeros using computed size
528+
let mut new_stg_info = StgInfo::new(total_size, total_align);
529+
new_stg_info.length = length;
530+
PyCStructure(PyCData::from_stg_info(&new_stg_info))
531+
.into_ref_with_type(vm, cls)
532+
.map(Into::into)
533+
}
534+
535+
fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
536+
unimplemented!("use slot_new")
537+
}
538+
}
539+
540+
impl PyCStructure {
541+
/// Recursively initialize positional arguments through inheritance chain
542+
/// Returns the number of arguments consumed
543+
fn init_pos_args(
544+
self_obj: &Py<Self>,
545+
type_obj: &Py<PyType>,
546+
args: &[PyObjectRef],
547+
kwargs: &indexmap::IndexMap<String, PyObjectRef>,
548+
index: usize,
549+
vm: &VirtualMachine,
550+
) -> PyResult<usize> {
551+
let mut current_index = index;
528552

529-
// Collect field names for initialization
530-
let mut field_names: Vec<String> = Vec::new();
531-
if let Some(fields_attr) = fields_attr {
532-
let fields: Vec<PyObjectRef> = if let Some(list) = fields_attr.downcast_ref::<PyList>()
553+
// 1. First process base class fields recursively
554+
let base_clone = {
555+
let bases = type_obj.bases.read();
556+
if let Some(base) = bases.first()
557+
&& base.stg_info_opt().is_some()
533558
{
534-
list.borrow_vec().to_vec()
535-
} else if let Some(tuple) = fields_attr.downcast_ref::<PyTuple>() {
536-
tuple.to_vec()
559+
Some(base.clone())
537560
} else {
538-
vec![]
539-
};
561+
None
562+
}
563+
};
564+
565+
if let Some(ref base) = base_clone {
566+
current_index = Self::init_pos_args(self_obj, base, args, kwargs, current_index, vm)?;
567+
}
568+
569+
// 2. Process this class's _fields_
570+
if let Some(fields_attr) = type_obj.get_direct_attr(vm.ctx.intern_str("_fields_")) {
571+
let fields: Vec<PyObjectRef> = fields_attr.try_to_value(vm)?;
540572

541573
for field in fields.iter() {
542-
let Some(field_tuple) = field.downcast_ref::<PyTuple>() else {
543-
continue;
544-
};
545-
if field_tuple.len() < 2 {
546-
continue;
574+
if current_index >= args.len() {
575+
break;
547576
}
548-
if let Some(name) = field_tuple.first().unwrap().downcast_ref::<PyStr>() {
549-
field_names.push(name.to_string());
577+
if let Some(tuple) = field.downcast_ref::<PyTuple>()
578+
&& let Some(name) = tuple.first()
579+
&& let Some(name_str) = name.downcast_ref::<PyStr>()
580+
{
581+
let field_name = name_str.as_str().to_owned();
582+
// Check for duplicate in kwargs
583+
if kwargs.contains_key(&field_name) {
584+
return Err(vm.new_type_error(format!(
585+
"duplicate values for field {:?}",
586+
field_name
587+
)));
588+
}
589+
self_obj.as_object().set_attr(
590+
vm.ctx.intern_str(field_name),
591+
args[current_index].clone(),
592+
vm,
593+
)?;
594+
current_index += 1;
550595
}
551596
}
552597
}
553598

554-
// Initialize buffer with zeros using computed size
555-
let mut stg_info = StgInfo::new(total_size, total_align);
556-
stg_info.length = if length > 0 {
557-
length
558-
} else {
559-
field_names.len()
560-
};
561-
stg_info.paramfunc = super::base::ParamFunc::Structure;
562-
let instance = PyCStructure(PyCData::from_stg_info(&stg_info));
599+
Ok(current_index)
600+
}
601+
}
563602

564-
// Handle keyword arguments for field initialization
565-
let py_instance = instance.into_ref_with_type(vm, cls.clone())?;
566-
let py_obj: PyObjectRef = py_instance.clone().into();
603+
impl Initializer for PyCStructure {
604+
type Args = FuncArgs;
567605

568-
// Set field values from kwargs using standard attribute setting
569-
for (key, value) in args.kwargs.iter() {
570-
if field_names.iter().any(|n| n == key.as_str()) {
571-
py_obj.set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?;
606+
fn init(zelf: crate::PyRef<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
607+
// Struct_init: handle positional and keyword arguments
608+
let cls = zelf.class().to_owned();
609+
610+
// 1. Process positional arguments recursively through inheritance chain
611+
if !args.args.is_empty() {
612+
let consumed =
613+
PyCStructure::init_pos_args(&zelf, &cls, &args.args, &args.kwargs, 0, vm)?;
614+
615+
if consumed < args.args.len() {
616+
return Err(vm.new_type_error("too many initializers"));
572617
}
573618
}
574619

575-
// Set field values from positional args
576-
if args.args.len() > field_names.len() {
577-
return Err(vm.new_type_error("too many initializers".to_string()));
578-
}
579-
for (i, value) in args.args.iter().enumerate() {
580-
py_obj.set_attr(
581-
vm.ctx.intern_str(field_names[i].as_str()),
582-
value.clone(),
583-
vm,
584-
)?;
620+
// 2. Process keyword arguments
621+
for (key, value) in args.kwargs.iter() {
622+
zelf.as_object()
623+
.set_attr(vm.ctx.intern_str(key.as_str()), value.clone(), vm)?;
585624
}
586625

587-
Ok(py_instance.into())
588-
}
589-
590-
fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
591-
unimplemented!("use slot_new")
626+
Ok(())
592627
}
593628
}
594629

595630
// Note: GetAttr and SetAttr are not implemented here.
596631
// Field access is handled by CField descriptors registered on the class.
597632

598-
#[pyclass(flags(BASETYPE, IMMUTABLETYPE), with(Constructor, AsBuffer))]
633+
#[pyclass(
634+
flags(BASETYPE, IMMUTABLETYPE),
635+
with(Constructor, Initializer, AsBuffer)
636+
)]
599637
impl PyCStructure {
600638
#[pygetset]
601639
fn _b0_(&self) -> Option<PyObjectRef> {

crates/vm/src/stdlib/ctypes/union.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ impl PyCUnionType {
273273

274274
// Store StgInfo with aligned size
275275
let mut stg_info = StgInfo::new(aligned_size, total_align);
276+
stg_info.length = fields.len();
276277
stg_info.flags |= StgInfoFlags::DICTFLAG_FINAL | StgInfoFlags::TYPEFLAG_HASUNION;
277278
// PEP 3118 doesn't support union. Use 'B' for bytes.
278279
stg_info.format = Some("B".to_string());
@@ -431,9 +432,9 @@ impl Constructor for PyCUnion {
431432

432433
fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
433434
// Check for abstract class and extract values in a block to drop the borrow
434-
let (total_size, total_align) = {
435+
let (total_size, total_align, length) = {
435436
let stg_info = cls.stg_info(vm)?;
436-
(stg_info.size, stg_info.align)
437+
(stg_info.size, stg_info.align, stg_info.length)
437438
};
438439

439440
// Mark the class as finalized (instance creation finalizes the type)
@@ -442,7 +443,8 @@ impl Constructor for PyCUnion {
442443
}
443444

444445
// Initialize buffer with zeros using computed size
445-
let new_stg_info = StgInfo::new(total_size, total_align);
446+
let mut new_stg_info = StgInfo::new(total_size, total_align);
447+
new_stg_info.length = length;
446448
PyCUnion(PyCData::from_stg_info(&new_stg_info))
447449
.into_ref_with_type(vm, cls)
448450
.map(Into::into)

0 commit comments

Comments
 (0)