From d49ade6e820d7d74101aed4a4e57683e3f8febe5 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Thu, 18 Dec 2025 15:30:15 +0100 Subject: [PATCH] Introspection: generate nested classes for complex enums --- newsfragments/5708.added.md | 1 + pyo3-introspection/src/introspection.rs | 11 +-- pyo3-introspection/src/model.rs | 1 + pyo3-introspection/src/stubs.rs | 8 +- pyo3-macros-backend/src/introspection.rs | 7 ++ pyo3-macros-backend/src/pyclass.rs | 108 +++++++++++++---------- pytests/stubs/enums.pyi | 45 +++++++++- 7 files changed, 122 insertions(+), 59 deletions(-) create mode 100644 newsfragments/5708.added.md diff --git a/newsfragments/5708.added.md b/newsfragments/5708.added.md new file mode 100644 index 00000000000..f5d9fd43ba0 --- /dev/null +++ b/newsfragments/5708.added.md @@ -0,0 +1 @@ +Introspection: generate nested classes for complex enums \ No newline at end of file diff --git a/pyo3-introspection/src/introspection.rs b/pyo3-introspection/src/introspection.rs index 65b7115e60d..fa465d66e71 100644 --- a/pyo3-introspection/src/introspection.rs +++ b/pyo3-introspection/src/introspection.rs @@ -32,7 +32,8 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { let mut chunks_by_parent = HashMap::<&str, Vec<&Chunk>>::new(); for chunk in chunks { let (id, parent) = match chunk { - Chunk::Module { id, .. } | Chunk::Class { id, .. } => (Some(id.as_str()), None), + Chunk::Module { id, .. } => (Some(id.as_str()), None), + Chunk::Class { id, parent, .. } => (Some(id.as_str()), parent.as_deref()), Chunk::Function { id, parent, .. } | Chunk::Attribute { id, parent, .. } => { (id.as_deref(), parent.as_deref()) } @@ -129,6 +130,7 @@ fn convert_members<'a>( id, bases, decorators, + parent: _, } => classes.push(convert_class( id, name, @@ -202,10 +204,6 @@ fn convert_class( nested_modules.is_empty(), "Classes cannot contain nested modules" ); - ensure!( - nested_classes.is_empty(), - "Nested classes are not supported yet" - ); Ok(Class { name: name.into(), bases: bases @@ -218,6 +216,7 @@ fn convert_class( .iter() .map(convert_python_identifier) .collect::>()?, + inner_classes: nested_classes, }) } @@ -472,6 +471,8 @@ enum Chunk { bases: Vec, #[serde(default)] decorators: Vec, + #[serde(default)] + parent: Option, }, Function { #[serde(default)] diff --git a/pyo3-introspection/src/model.rs b/pyo3-introspection/src/model.rs index 9fba19c5470..4930f40c56d 100644 --- a/pyo3-introspection/src/model.rs +++ b/pyo3-introspection/src/model.rs @@ -16,6 +16,7 @@ pub struct Class { pub attributes: Vec, /// decorator like 'typing.final' pub decorators: Vec, + pub inner_classes: Vec, } #[derive(Debug, Eq, PartialEq, Clone, Hash)] diff --git a/pyo3-introspection/src/stubs.rs b/pyo3-introspection/src/stubs.rs index f8b094039b3..0e6bb9c9aa1 100644 --- a/pyo3-introspection/src/stubs.rs +++ b/pyo3-introspection/src/stubs.rs @@ -138,7 +138,7 @@ fn class_stubs(class: &Class, imports: &Imports) -> String { buffer.push(')'); } buffer.push(':'); - if class.methods.is_empty() && class.attributes.is_empty() { + if class.methods.is_empty() && class.attributes.is_empty() && class.inner_classes.is_empty() { buffer.push_str(" ..."); return buffer; } @@ -152,6 +152,11 @@ fn class_stubs(class: &Class, imports: &Imports) -> String { buffer.push_str("\n "); buffer.push_str(&function_stubs(method, imports).replace('\n', "\n ")); } + for inner_class in &class.inner_classes { + // We do the indentation + buffer.push_str("\n "); + buffer.push_str(&class_stubs(inner_class, imports).replace('\n', "\n ")); + } buffer } @@ -690,6 +695,7 @@ mod tests { module: Some("typing".into()), name: "final".into(), }], + inner_classes: Vec::new(), }], functions: vec![Function { name: String::new(), diff --git a/pyo3-macros-backend/src/introspection.rs b/pyo3-macros-backend/src/introspection.rs index 4651a43ae54..6389e2ed365 100644 --- a/pyo3-macros-backend/src/introspection.rs +++ b/pyo3-macros-backend/src/introspection.rs @@ -62,6 +62,7 @@ pub fn class_introspection_code( name: &str, extends: Option, is_final: bool, + parent: Option<&Type>, ) -> TokenStream { let mut desc = HashMap::from([ ("type", IntrospectionNode::String("class".into())), @@ -80,6 +81,12 @@ pub fn class_introspection_code( IntrospectionNode::List(vec![PythonTypeHint::module_attr("typing", "final").into()]), ); } + if let Some(parent) = parent { + desc.insert( + "parent", + IntrospectionNode::IntrospectionId(Some(Cow::Borrowed(parent))), + ); + } IntrospectionNode::Map(desc).emit(pyo3_crate_path) } diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index e39e74fdeb3..073702772dd 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -456,7 +456,7 @@ fn impl_class( field_options: Vec<(&syn::Field, FieldPyO3Options)>, methods_type: PyClassMethodsType, ctx: &Ctx, -) -> syn::Result { +) -> Result { let Ctx { pyo3_path, .. } = ctx; let pytypeinfo_impl = impl_pytypeinfo(cls, args, ctx); @@ -499,9 +499,18 @@ fn impl_class( slots.extend(default_hash_slot); slots.extend(default_str_slot); - let py_class_impl = PyClassImplsBuilder::new(cls, args, methods_type, default_methods, slots) - .doc(doc) - .impl_all(ctx)?; + let impl_builder = + PyClassImplsBuilder::new(cls, cls, args, methods_type, default_methods, slots).doc(doc); + let py_class_impl: TokenStream = [ + impl_builder.impl_pyclass(ctx), + impl_builder.impl_into_py(ctx), + impl_builder.impl_pyclassimpl(ctx)?, + impl_builder.impl_add_to_module(ctx), + impl_builder.impl_freelist(ctx), + impl_builder.impl_introspection(ctx, None), + ] + .into_iter() + .collect(); Ok(quote! { impl #pyo3_path::types::DerefToPyAny for #cls {} @@ -1023,6 +1032,7 @@ fn impl_simple_enum( default_slots.extend(default_str_slot); let impl_builder = PyClassImplsBuilder::new( + cls, cls, args, methods_type, @@ -1038,11 +1048,7 @@ fn impl_simple_enum( .doc(doc); let enum_into_pyobject_impl = { - let output_type = if cfg!(feature = "experimental-inspect") { - quote!(const OUTPUT_TYPE: #pyo3_path::inspect::TypeHint = <#cls as #pyo3_path::PyTypeInfo>::TYPE_HINT;) - } else { - TokenStream::new() - }; + let output_type = get_conversion_type_hint(ctx, &format_ident!("OUTPUT_TYPE"), cls); let num = variants.len(); let i = (0..num).map(proc_macro2::Literal::usize_unsuffixed); @@ -1083,7 +1089,7 @@ fn impl_simple_enum( impl_builder.impl_pyclassimpl(ctx)?, impl_builder.impl_add_to_module(ctx), impl_builder.impl_freelist(ctx), - impl_builder.impl_introspection(ctx), + impl_builder.impl_introspection(ctx, None), ] .into_iter() .collect(); @@ -1144,6 +1150,7 @@ fn impl_complex_enum( default_slots.extend(default_str_slot); let impl_builder = PyClassImplsBuilder::new( + cls, cls, &args, methods_type, @@ -1197,7 +1204,7 @@ fn impl_complex_enum( impl_builder.impl_pyclassimpl(ctx)?, impl_builder.impl_add_to_module(ctx), impl_builder.impl_freelist(ctx), - impl_builder.impl_introspection(ctx), + impl_builder.impl_introspection(ctx, None), ] .into_iter() .collect(); @@ -1207,7 +1214,8 @@ fn impl_complex_enum( let mut variant_cls_pyclass_impls = vec![]; let mut variant_cls_impls = vec![]; for variant in variants { - let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.get_ident()); + let variant_name = variant.get_ident().clone(); + let variant_cls = gen_complex_enum_variant_class_ident(cls, &variant_name); let variant_cls_zst = quote! { #[doc(hidden)] @@ -1237,14 +1245,24 @@ fn impl_complex_enum( let variant_new = complex_enum_variant_new(cls, variant, ctx)?; slots.push(variant_new); - let pyclass_impl = PyClassImplsBuilder::new( + let impl_builder = PyClassImplsBuilder::new( &variant_cls, + &variant_name, &variant_args, methods_type, field_getters, slots, - ) - .impl_all(ctx)?; + ); + let pyclass_impl: TokenStream = [ + impl_builder.impl_pyclass(ctx), + impl_builder.impl_into_py(ctx), + impl_builder.impl_pyclassimpl(ctx)?, + impl_builder.impl_add_to_module(ctx), + impl_builder.impl_freelist(ctx), + impl_builder.impl_introspection(ctx, Some(cls)), + ] + .into_iter() + .collect(); variant_cls_pyclass_impls.push(pyclass_impl); } @@ -1559,7 +1577,7 @@ fn impl_complex_enum_tuple_variant_cls( Ok((cls_impl, field_getters, slots)) } -fn gen_complex_enum_variant_class_ident(enum_: &syn::Ident, variant: &syn::Ident) -> syn::Ident { +fn gen_complex_enum_variant_class_ident(enum_: &Ident, variant: &Ident) -> Ident { format_ident!("{}_{}", enum_, variant) } @@ -1713,7 +1731,7 @@ pub fn gen_complex_enum_variant_attr( let wrapper_ident = format_ident!("__pymethod_variant_cls_{}__", member); let python_name = spec.null_terminated_python_name(); - let variant_cls = format_ident!("{}_{}", cls, member); + let variant_cls = gen_complex_enum_variant_class_ident(cls, member); let associated_method = quote! { fn #wrapper_ident(py: #pyo3_path::Python<'_>) -> #pyo3_path::PyResult<#pyo3_path::Py<#pyo3_path::PyAny>> { ::std::result::Result::Ok(py.get_type::<#variant_cls>().into_any().unbind()) @@ -1756,7 +1774,7 @@ fn complex_enum_struct_variant_new<'a>( ctx: &Ctx, ) -> Result { let Ctx { pyo3_path, .. } = ctx; - let variant_cls = format_ident!("{}_{}", cls, variant.ident); + let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.ident); let variant_cls_type: syn::Type = parse_quote!(#variant_cls); let arg_py_ident: syn::Ident = parse_quote!(py); @@ -1817,7 +1835,7 @@ fn complex_enum_tuple_variant_new<'a>( ) -> Result { let Ctx { pyo3_path, .. } = ctx; - let variant_cls: Ident = format_ident!("{}_{}", cls, variant.ident); + let variant_cls = gen_complex_enum_variant_class_ident(cls, variant.ident); let variant_cls_type: syn::Type = parse_quote!(#variant_cls); let arg_py_ident: syn::Ident = parse_quote!(py); @@ -1995,7 +2013,7 @@ fn descriptors_to_items( Ok(items) } -fn impl_pytypeinfo(cls: &syn::Ident, attr: &PyClassArgs, ctx: &Ctx) -> TokenStream { +fn impl_pytypeinfo(cls: &Ident, attr: &PyClassArgs, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; #[cfg(feature = "experimental-inspect")] @@ -2320,7 +2338,10 @@ fn pyclass_class_getitem( /// and attributes of `#[pyclass]`, and docstrings. /// Therefore it doesn't implement traits that depends on struct fields and enum variants. struct PyClassImplsBuilder<'a> { - cls: &'a syn::Ident, + /// Identifier of the class Rust struct + cls_ident: &'a Ident, + /// Name of the class in Python + cls_name: &'a Ident, attr: &'a PyClassArgs, methods_type: PyClassMethodsType, default_methods: Vec, @@ -2330,14 +2351,16 @@ struct PyClassImplsBuilder<'a> { impl<'a> PyClassImplsBuilder<'a> { fn new( - cls: &'a syn::Ident, + cls_ident: &'a Ident, + cls_name: &'a Ident, attr: &'a PyClassArgs, methods_type: PyClassMethodsType, default_methods: Vec, default_slots: Vec, ) -> Self { Self { - cls, + cls_ident, + cls_name, attr, methods_type, default_methods, @@ -2353,24 +2376,11 @@ impl<'a> PyClassImplsBuilder<'a> { } } - fn impl_all(&self, ctx: &Ctx) -> Result { - Ok([ - self.impl_pyclass(ctx), - self.impl_into_py(ctx), - self.impl_pyclassimpl(ctx)?, - self.impl_add_to_module(ctx), - self.impl_freelist(ctx), - self.impl_introspection(ctx), - ] - .into_iter() - .collect()) - } - fn impl_pyclass(&self, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; - let cls = self.cls; + let cls = self.cls_ident; - let cls_name = get_class_python_name(cls, self.attr).to_string(); + let cls_name = get_class_python_name(self.cls_name, self.attr).to_string(); let frozen = if self.attr.options.frozen.is_some() { quote! { #pyo3_path::pyclass::boolean_struct::True } @@ -2388,7 +2398,7 @@ impl<'a> PyClassImplsBuilder<'a> { fn impl_into_py(&self, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; - let cls = self.cls; + let cls = self.cls_ident; let attr = self.attr; // If #cls is not extended type, we allow Self->PyObject conversion if attr.options.extends.is_none() { @@ -2414,7 +2424,7 @@ impl<'a> PyClassImplsBuilder<'a> { } fn impl_pyclassimpl(&self, ctx: &Ctx) -> Result { let Ctx { pyo3_path, .. } = ctx; - let cls = self.cls; + let cls = self.cls_ident; let doc = self .doc .as_ref() @@ -2438,7 +2448,7 @@ impl<'a> PyClassImplsBuilder<'a> { ensure_spanned!( !(is_mapping && is_sequence), - self.cls.span() => "a `#[pyclass]` cannot be both a `mapping` and a `sequence`" + cls.span() => "a `#[pyclass]` cannot be both a `mapping` and a `sequence`" ); let dict_offset = if self.attr.options.dict.is_some() { @@ -2515,7 +2525,6 @@ impl<'a> PyClassImplsBuilder<'a> { } }; - let cls = self.cls; let attr = self.attr; let dict = if attr.options.dict.is_some() { quote! { #pyo3_path::impl_::pyclass::PyClassDictSlot } @@ -2668,7 +2677,7 @@ impl<'a> PyClassImplsBuilder<'a> { fn impl_add_to_module(&self, ctx: &Ctx) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; - let cls = self.cls; + let cls = self.cls_ident; quote! { impl #cls { #[doc(hidden)] @@ -2678,7 +2687,7 @@ impl<'a> PyClassImplsBuilder<'a> { } fn impl_freelist(&self, ctx: &Ctx) -> TokenStream { - let cls = self.cls; + let cls = self.cls_ident; let Ctx { pyo3_path, .. } = ctx; self.attr.options.freelist.as_ref().map_or(quote! {}, |freelist| { @@ -2697,7 +2706,7 @@ impl<'a> PyClassImplsBuilder<'a> { fn freelist_slots(&self, ctx: &Ctx) -> Vec { let Ctx { pyo3_path, .. } = ctx; - let cls = self.cls; + let cls = self.cls_ident; if self.attr.options.freelist.is_some() { vec![ @@ -2720,10 +2729,10 @@ impl<'a> PyClassImplsBuilder<'a> { } #[cfg(feature = "experimental-inspect")] - fn impl_introspection(&self, ctx: &Ctx) -> TokenStream { + fn impl_introspection(&self, ctx: &Ctx, parent: Option<&Ident>) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; - let name = get_class_python_name(self.cls, self.attr).to_string(); - let ident = self.cls; + let name = get_class_python_name(self.cls_name, self.attr).to_string(); + let ident = self.cls_ident; let static_introspection = class_introspection_code( pyo3_path, ident, @@ -2739,6 +2748,7 @@ impl<'a> PyClassImplsBuilder<'a> { ) }), self.attr.options.subclass.is_none(), + parent.map(|p| parse_quote!(#p)).as_ref(), ); let introspection_id = introspection_id_const(); quote! { @@ -2750,7 +2760,7 @@ impl<'a> PyClassImplsBuilder<'a> { } #[cfg(not(feature = "experimental-inspect"))] - fn impl_introspection(&self, _ctx: &Ctx) -> TokenStream { + fn impl_introspection(&self, _ctx: &Ctx, _parent: Option<&Ident>) -> TokenStream { quote! {} } } diff --git a/pytests/stubs/enums.pyi b/pytests/stubs/enums.pyi index ec43dca52f8..f771332f20a 100644 --- a/pytests/stubs/enums.pyi +++ b/pytests/stubs/enums.pyi @@ -1,15 +1,52 @@ from typing import final -class ComplexEnum: ... -class MixedComplexEnum: ... +class ComplexEnum: + @final + class EmptyStruct(ComplexEnum): ... + + @final + class Float(ComplexEnum): ... + + @final + class Int(ComplexEnum): ... + + @final + class MultiFieldStruct(ComplexEnum): ... + + @final + class Str(ComplexEnum): ... + + @final + class VariantWithDefault(ComplexEnum): ... + +class MixedComplexEnum: + @final + class Empty(MixedComplexEnum): ... + + @final + class Nothing(MixedComplexEnum): ... @final class SimpleEnum: def __eq__(self, /, other: SimpleEnum | int) -> bool: ... def __ne__(self, /, other: SimpleEnum | int) -> bool: ... -class SimpleTupleEnum: ... -class TupleEnum: ... +class SimpleTupleEnum: + @final + class Int(SimpleTupleEnum): ... + + @final + class Str(SimpleTupleEnum): ... + +class TupleEnum: + @final + class EmptyTuple(TupleEnum): ... + + @final + class Full(TupleEnum): ... + + @final + class FullWithDefault(TupleEnum): ... def do_complex_stuff(thing: ComplexEnum) -> ComplexEnum: ... def do_mixed_complex_stuff(thing: MixedComplexEnum) -> MixedComplexEnum: ...