diff --git a/impl/src/expand.rs b/impl/src/expand.rs index c693921..cc66736 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -174,12 +174,23 @@ fn impl_struct(input: Struct) -> TokenStream { #ty #body } }; - let from_impl = quote_spanned! {span=> + let mut from_impl = quote_spanned! {span=> #[automatically_derived] impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { #from_function } }; + if let Some(from) = type_parameter_of_box(from_field.ty) { + let body = from_some_source(from_field, backtrace_field, quote!(::thiserror::__private::Box::new(#source_var))); + from_impl.extend(quote_spanned! {span=> + #[automatically_derived] + impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { + fn from(#source_var: #from) -> Self { + #ty #body + } + } + }); + } Some(quote! { #[allow( deprecated, @@ -449,12 +460,23 @@ fn impl_enum(input: Enum) -> TokenStream { #ty::#variant #body } }; - let from_impl = quote_spanned! {span=> + let mut from_impl = quote_spanned! {span=> #[automatically_derived] impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause { #from_function } }; + if let Some(boxed) = type_parameter_of_box(from_field.ty) { + let body = from_some_source(from_field, backtrace_field, quote!(::thiserror::__private::Box::new(#source_var))); + from_impl.extend(quote_spanned! {span=> + #[automatically_derived] + impl #impl_generics ::core::convert::From<#boxed> for #ty #ty_generics #where_clause { + fn from(#source_var: #boxed) -> Self { + #ty::#variant #body + } + } + }); + } Some(quote! { #[allow( deprecated, @@ -523,12 +545,20 @@ fn from_initializer( backtrace_field: Option<&Field>, source_var: &Ident, ) -> TokenStream { - let from_member = &from_field.member; let some_source = if type_is_option(from_field.ty) { quote!(::core::option::Option::Some(#source_var)) } else { quote!(#source_var) }; + from_some_source(from_field, backtrace_field, some_source) +} + +fn from_some_source( + from_field: &Field, + backtrace_field: Option<&Field<'_>>, + some_source: TokenStream, +) -> TokenStream { + let from_member = &from_field.member; let backtrace = backtrace_field.map(|backtrace_field| { let backtrace_member = &backtrace_field.member; if type_is_option(backtrace_field.ty) { @@ -581,3 +611,29 @@ fn type_parameter_of_option(ty: &Type) -> Option<&Type> { _ => None, } } + +fn type_parameter_of_box(ty: &Type) -> Option<&Type> { + let path = match ty { + Type::Path(ty) => &ty.path, + _ => return None, + }; + + let last = path.segments.last().unwrap(); + if last.ident != "Box" { + return None; + } + + let bracketed = match &last.arguments { + PathArguments::AngleBracketed(bracketed) => bracketed, + _ => return None, + }; + + if bracketed.args.len() != 1 { + return None; + } + + match &bracketed.args[0] { + GenericArgument::Type(arg) => Some(arg), + _ => None, + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 638eddd..fd069b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -273,6 +273,8 @@ #[cfg(all(thiserror_nightly_testing, not(error_generic_member_access)))] compile_error!("Build script probe failed to compile."); +#[cfg(not(feature = "std"))] +extern crate alloc; #[cfg(feature = "std")] extern crate std; #[cfg(feature = "std")] @@ -299,8 +301,14 @@ pub mod __private { #[doc(hidden)] pub use crate::var::Var; #[doc(hidden)] + #[cfg(not(feature = "std"))] + pub use alloc::boxed::Box; + #[doc(hidden)] pub use core::error::Error; #[cfg(all(feature = "std", not(thiserror_no_backtrace_type)))] #[doc(hidden)] pub use std::backtrace::Backtrace; + #[doc(hidden)] + #[cfg(feature = "std")] + pub use std::boxed::Box; } diff --git a/tests/test_from.rs b/tests/test_from.rs index 51af40b..22f47da 100644 --- a/tests/test_from.rs +++ b/tests/test_from.rs @@ -43,6 +43,15 @@ pub enum ErrorEnumOptional { }, } +#[derive(Error, Debug)] +#[error("...")] +pub enum ErrorEnumBox { + Test { + #[from] + source: Box, + }, +} + #[derive(Error, Debug)] #[error("...")] pub enum Many { @@ -51,6 +60,7 @@ pub enum Many { } fn assert_impl>() {} +fn assert_impl_box>>() {} #[test] fn test_from() { @@ -60,5 +70,7 @@ fn test_from() { assert_impl::(); assert_impl::(); assert_impl::(); + assert_impl::(); + assert_impl_box::(); assert_impl::(); }