diff --git a/sea-orm-macros/src/derives/model.rs b/sea-orm-macros/src/derives/model.rs index 31e7e66c11..d29a69b062 100644 --- a/sea-orm-macros/src/derives/model.rs +++ b/sea-orm-macros/src/derives/model.rs @@ -7,7 +7,7 @@ use itertools::izip; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use std::iter::FromIterator; -use syn::{Attribute, Data, Expr, Ident, LitStr}; +use syn::{Attribute, Data, Expr, Ident, LitStr, Type}; pub(crate) struct DeriveModel { column_idents: Vec, @@ -153,12 +153,12 @@ impl DeriveModel { // In that case we interpret it as "no nested row" (i.e., Option::None). // This check detects that condition by testing if all non-ignored fields are NULL. let all_null_check = { - let checks: Vec<_> = izip!(field_idents, ignore_attrs) - .filter_map(|(field_ident, &ignore)| { + let checks: Vec<_> = izip!(field_idents, field_types, ignore_attrs) + .filter_map(|(field_ident, field_type, &ignore)| { if ignore { None } else { - Some(quote! { #field_ident.is_none() }) + Some(create_is_null_expr(field_ident, field_type)) } }) .collect(); @@ -258,3 +258,67 @@ pub fn expand_derive_model( ) -> syn::Result { DeriveModel::new(ident, data, attrs)?.expand() } + +/// Get the total nesting depth of `Option`. +/// +/// For example: +/// - `Option` => `1` +/// - `Option>` => `2` +/// - `Option>>` => `3` +fn option_nesting_depth(ty: &Type) -> usize { + match ty { + Type::Path(type_path) if type_path.qself.is_none() => type_path + .path + .segments + .last() + .and_then(|segment| { + if segment.ident != "Option" { + return None; + } + + match &segment.arguments { + syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => { + args.args.first().map(|arg| match arg { + syn::GenericArgument::Type(inner) => 1 + option_nesting_depth(inner), + _ => 1, + }) + } + _ => Some(1), + } + }) + .unwrap_or(0), + _ => 0, + } +} + +/// Generate an expr that checks whether an optional field is nullish. +/// +/// For a nested `Option`, the generated expression treats every partially +/// unwrapped `None` as null. +/// +/// For example, for `Option>>`, it will generate: +/// ```rust,ignore +/// matches!( +/// field, +/// None | Some(None) | Some(Some(None)) | Some(Some(Some(None))) +/// ) +/// ``` +fn create_is_null_expr(field_ident: &Ident, field_type: &Type) -> TokenStream { + let depth = option_nesting_depth(field_type); + + if depth == 0 { + return quote! { #field_ident.is_none() }; + } + + let patterns: Vec<_> = (0..=depth) + .map(|depth| { + let mut pattern = quote! { None }; + for _ in 0..depth { + pattern = quote! { Some(#pattern) }; + } + pattern + }) + .collect(); + + quote! { matches!(#field_ident, #( #patterns )|* ) } +} diff --git a/sea-orm-sync/tests/partial_model_tests.rs b/sea-orm-sync/tests/partial_model_tests.rs index 359586b308..5d5de7c65d 100644 --- a/sea-orm-sync/tests/partial_model_tests.rs +++ b/sea-orm-sync/tests/partial_model_tests.rs @@ -130,6 +130,69 @@ fn partial_model_left_join_does_not_exist() { ctx.delete(); } +#[sea_orm_macros::test] +fn partial_model_left_join_with_optional_nested_model_optional_fields_does_not_exist() { + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "baker::Entity")] + struct BakerDetails { + id: i32, + name: String, + bakery_id: Option, + } + + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "baker::Entity")] + struct NestedBaker { + #[sea_orm(nested)] + details: BakerDetails, + } + + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "cake::Entity")] + struct CakeWithOptionalBakerModel { + id: i32, + name: String, + #[sea_orm(nested)] + baker: Option, + } + + let ctx = TestContext::new("partial_model_left_join_deep_baker"); + create_tables(&ctx.db).unwrap(); + + seed_data::init_1(&ctx, true); + + let cakes: Vec = cake::Entity::find() + .left_join(baker::Entity) + .order_by_asc(cake::Column::Id) + .into_partial_model() + .all(&ctx.db) + .expect("succeeds to get the result"); + + assert_eq!( + cakes, + [ + CakeWithOptionalBakerModel { + id: 13, + name: "Cheesecake".to_owned(), + baker: Some(NestedBaker { + details: BakerDetails { + id: 22, + name: "Master Baker".to_owned(), + bakery_id: Some(42), + }, + }), + }, + CakeWithOptionalBakerModel { + id: 15, + name: "Chocolate".to_owned(), + baker: None, + }, + ] + ); + + ctx.delete(); +} + #[sea_orm_macros::test] fn partial_model_left_join_exists() { let ctx = TestContext::new("partial_model_left_join_exists"); diff --git a/tests/partial_model_tests.rs b/tests/partial_model_tests.rs index 87be6d7e08..3a3f629049 100644 --- a/tests/partial_model_tests.rs +++ b/tests/partial_model_tests.rs @@ -131,6 +131,70 @@ async fn partial_model_left_join_does_not_exist() { ctx.delete().await; } +#[sea_orm_macros::test] +async fn partial_model_left_join_with_optional_nested_model_optional_fields_does_not_exist() { + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "baker::Entity")] + struct BakerDetails { + id: i32, + name: String, + bakery_id: Option, + } + + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "baker::Entity")] + struct NestedBaker { + #[sea_orm(nested)] + details: BakerDetails, + } + + #[derive(Debug, DerivePartialModel, PartialEq)] + #[sea_orm(entity = "cake::Entity")] + struct CakeWithOptionalBakerModel { + id: i32, + name: String, + #[sea_orm(nested)] + baker: Option, + } + + let ctx = TestContext::new("partial_model_left_join_deep_baker").await; + create_tables(&ctx.db).await.unwrap(); + + seed_data::init_1(&ctx, true).await; + + let cakes: Vec = cake::Entity::find() + .left_join(baker::Entity) + .order_by_asc(cake::Column::Id) + .into_partial_model() + .all(&ctx.db) + .await + .expect("succeeds to get the result"); + + assert_eq!( + cakes, + [ + CakeWithOptionalBakerModel { + id: 13, + name: "Cheesecake".to_owned(), + baker: Some(NestedBaker { + details: BakerDetails { + id: 22, + name: "Master Baker".to_owned(), + bakery_id: Some(42), + }, + }), + }, + CakeWithOptionalBakerModel { + id: 15, + name: "Chocolate".to_owned(), + baker: None, + }, + ] + ); + + ctx.delete().await; +} + #[sea_orm_macros::test] async fn partial_model_left_join_exists() { let ctx = TestContext::new("partial_model_left_join_exists").await;