diff --git a/crates/emmylua_code_analysis/resources/std/builtin.lua b/crates/emmylua_code_analysis/resources/std/builtin.lua index da7557002..af7194301 100644 --- a/crates/emmylua_code_analysis/resources/std/builtin.lua +++ b/crates/emmylua_code_analysis/resources/std/builtin.lua @@ -167,6 +167,13 @@ --- Extract from T those types that are assignable to U --- @alias Extract T extends U and T or never +--- +--- From T, pick a set of properties whose keys are in the union K +--- @alias Pick {[P in K]: T[P]; } + +--- +--- Construct a type with the properties of T except for those in type K. +--- @alias Omit Pick> --- attribute diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs index e3ad351d2..5f26acc62 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/file_generic_index.rs @@ -9,12 +9,24 @@ use crate::{GenericParam, GenericTpl, GenericTplId, LuaType}; pub trait GenericIndex: std::fmt::Debug { fn add_generic_scope(&mut self, ranges: Vec, is_func: bool) -> GenericScopeId; - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam); - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option; + + fn append_generic_params( + &mut self, + scope_id: GenericScopeId, + params: Vec, + ) -> Vec { + let mut appended = Vec::new(); for param in params { - self.append_generic_param(scope_id, param); + if let Some(tpl_id) = self.append_generic_param(scope_id, param.clone()) { + appended.push(param.with_tpl_id(Some(tpl_id))); + } } + appended } fn find_generic( @@ -63,16 +75,15 @@ impl GenericIndex for FileGenericIndex { scope_id } - fn append_generic_param(&mut self, scope_id: GenericScopeId, param: GenericParam) { + fn append_generic_param( + &mut self, + scope_id: GenericScopeId, + param: GenericParam, + ) -> Option { if let Some(scope) = self.scopes.get_mut(scope_id.id) { - scope.insert_param(param); - } - } - - fn append_generic_params(&mut self, scope_id: GenericScopeId, params: Vec) { - for param in params { - self.append_generic_param(scope_id, param); + return Some(scope.insert_param(param)); } + None } /// Find generic parameter by position and name. @@ -131,10 +142,12 @@ impl FileGenericScope { self.next_tpl_id.is_func() } - fn insert_param(&mut self, param: GenericParam) { - let tpl_id = self.next_tpl_id; - self.next_tpl_id = self.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); + fn insert_param(&mut self, param: GenericParam) -> GenericTplId { + let tpl_id = param.tpl_id.unwrap_or(self.next_tpl_id); + let next_idx = self.next_tpl_id.get_idx().max(tpl_id.get_idx() + 1) as u32; + self.next_tpl_id = self.next_tpl_id.with_idx(next_idx); self.params.insert(param.name.to_string(), (tpl_id, param)); + tpl_id } fn contains(&self, position: TextSize) -> bool { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs index d63ed8290..0fc9ad2b9 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_def_tags.rs @@ -142,9 +142,10 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< alias_decl.get_id() }; + let type_node = tag.get_type()?; if tag.get_generic_decl_list().is_some() { let generic_params = get_type_generic_params(analyzer, &alias_decl_id); - let range = analyzer.comment.get_range(); + let range = type_node.get_range(); let scope_id = analyzer .type_context .generic_index @@ -155,7 +156,7 @@ pub fn analyze_alias(analyzer: &mut DocAnalyzer, tag: LuaDocTagAlias) -> Option< .append_generic_params(scope_id, generic_params); } - let mut origin_type = infer_type(&mut analyzer.type_context, tag.get_type()?); + let mut origin_type = infer_type(&mut analyzer.type_context, type_node); if alias_origin_reaches(analyzer.get_db(), &origin_type, &alias_decl_id) { origin_type = LuaType::Any; } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs index ed8601db4..dcd26943f 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/type_generic_header.rs @@ -68,6 +68,7 @@ fn normalize_generic_params(db: &DbIndex, params: &[GenericParam]) -> Vec Option { + let scope = self.scopes.get_mut(scope_id.id)?; let tpl_id = scope.next_tpl_id; scope.next_tpl_id = scope.next_tpl_id.with_idx((tpl_id.get_idx() + 1) as u32); scope.params.push((tpl_id, param)); + Some(tpl_id) } fn find_generic( diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs index d6605830b..8a3565802 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/for_range_stat.rs @@ -3,7 +3,7 @@ use emmylua_parser::{LuaAstToken, LuaExpr, LuaForRangeStat}; use crate::{ DbIndex, InferFailReason, LuaDeclId, LuaInferCache, LuaOperatorMetaMethod, LuaType, LuaTypeCache, TplContext, TypeOps, TypeSubstitutor, VariadicType, - compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_doc_function, + compilation::analyzer::unresolve::UnResolveIterVar, infer_expr, instantiate_type_generic, tpl_pattern_match_args, }; @@ -145,22 +145,18 @@ pub fn infer_for_range_iter_expr_func( return Ok(doc_function.get_variadic_ret()); }; let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: None, - }; let params = doc_function .get_params() .iter() .map(|(_, opt_ty)| opt_ty.clone().unwrap_or(LuaType::Any)) .collect::>(); + let mut context = TplContext::new(db, cache, &mut substitutor, None); tpl_pattern_match_args(&mut context, ¶ms, &[status_param])?; + let doc_function_ty = LuaType::DocFunction(doc_function.clone()); let instantiate_func = if let LuaType::DocFunction(f) = - instantiate_doc_function(db, &doc_function, &substitutor) + instantiate_type_generic(db, &doc_function_ty, &substitutor) { f } else { diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 51d7983fd..99ced2d99 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -47,11 +47,13 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) break; }; + pre_analyze_call_arg_table_fields(analyzer, &expr); + match analyzer.infer_expr(&expr) { Ok(expr_type) => { let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); let decl_id = LuaDeclId::new(analyzer.file_id, position); - // 当`call`参数包含表时, 表可能未被分析, 需要延迟 + // 当表达式中存在带表参数的调用时, 表可能尚未完成预分析, 需要延迟 if let LuaType::Instance(instance) = &expr_type && instance.get_base().is_unknown() && call_expr_has_effect_table_arg(&expr).is_some() @@ -162,17 +164,7 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) } fn call_expr_has_effect_table_arg(expr: &LuaExpr) -> Option<()> { - if let LuaExpr::CallExpr(call_expr) = expr { - let args_list = call_expr.get_args_list()?; - for arg in args_list.get_args() { - if let LuaExpr::TableExpr(table_expr) = arg - && !table_expr.is_empty() - { - return Some(()); - } - } - } - None + expr_has_effect_table_call_arg(expr.clone()) } fn get_var_owner(analyzer: &mut LuaAnalyzer, var: LuaVarExpr) -> LuaTypeOwner { @@ -309,6 +301,8 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta continue; } + pre_analyze_call_arg_table_fields(analyzer, expr); + let expr_type = match analyzer.infer_expr(expr) { Ok(expr_type) => expr_type.get_result_slot_type(0).unwrap_or(expr_type), Err(InferFailReason::None) => LuaType::Unknown, @@ -529,8 +523,17 @@ pub fn analyze_table_field(analyzer: &mut LuaAnalyzer, field: LuaTableField) -> } } - let value_expr = field.get_value_expr()?; let member_id = LuaMemberId::new(field.get_syntax_id(), analyzer.file_id); + if analyzer + .db + .get_type_index() + .get_type_cache(&member_id.into()) + .is_some() + { + return Some(()); + } + let value_expr = field.get_value_expr()?; + let value_type = match analyzer.infer_expr(&value_expr.clone()) { Ok(value_type) => match value_type { LuaType::Def(ref_id) => LuaType::Ref(ref_id), @@ -620,3 +623,124 @@ fn get_delayed_definition_decl_id( } Some(decl_id) } + +fn pre_analyze_call_arg_table_fields(analyzer: &mut LuaAnalyzer, expr: &LuaExpr) { + pre_analyze_nested_table_fields(analyzer, expr.clone()); +} + +fn pre_analyze_nested_table_fields(analyzer: &mut LuaAnalyzer, expr: LuaExpr) { + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr); + } + + if let Some(args_list) = call_expr.get_args_list() { + for arg in args_list.get_args() { + pre_analyze_nested_table_fields(analyzer, arg); + } + } + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() { + pre_analyze_nested_table_fields(analyzer, key_expr); + } + + if let Some(value_expr) = field.get_value_expr() { + pre_analyze_nested_table_fields(analyzer, value_expr); + } + + analyze_table_field(analyzer, field.clone()); + } + } + LuaExpr::BinaryExpr(binary_expr) => { + if let Some((left, right)) = binary_expr.get_exprs() { + pre_analyze_nested_table_fields(analyzer, left); + pre_analyze_nested_table_fields(analyzer, right); + } + } + LuaExpr::UnaryExpr(unary_expr) => { + if let Some(inner_expr) = unary_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr); + } + } + LuaExpr::ParenExpr(paren_expr) => { + if let Some(inner_expr) = paren_expr.get_expr() { + pre_analyze_nested_table_fields(analyzer, inner_expr); + } + } + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() { + pre_analyze_nested_table_fields(analyzer, prefix_expr); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + pre_analyze_nested_table_fields(analyzer, key_expr); + } + } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => {} + } +} + +fn expr_has_effect_table_call_arg(expr: LuaExpr) -> Option<()> { + match expr { + LuaExpr::CallExpr(call_expr) => { + if let Some(prefix_expr) = call_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + let args_list = call_expr.get_args_list()?; + for arg in args_list.get_args() { + if let LuaExpr::TableExpr(table_expr) = &arg + && !table_expr.is_empty() + { + return Some(()); + } + + if expr_has_effect_table_call_arg(arg).is_some() { + return Some(()); + } + } + None + } + LuaExpr::TableExpr(table_expr) => { + for field in table_expr.get_fields() { + if let Some(LuaIndexKey::Expr(key_expr)) = field.get_field_key() + && expr_has_effect_table_call_arg(key_expr).is_some() + { + return Some(()); + } + + if let Some(value_expr) = field.get_value_expr() + && expr_has_effect_table_call_arg(value_expr).is_some() + { + return Some(()); + } + } + None + } + LuaExpr::BinaryExpr(binary_expr) => { + let (left, right) = binary_expr.get_exprs()?; + expr_has_effect_table_call_arg(left).or_else(|| expr_has_effect_table_call_arg(right)) + } + LuaExpr::UnaryExpr(unary_expr) => expr_has_effect_table_call_arg(unary_expr.get_expr()?), + LuaExpr::ParenExpr(paren_expr) => expr_has_effect_table_call_arg(paren_expr.get_expr()?), + LuaExpr::IndexExpr(index_expr) => { + if let Some(prefix_expr) = index_expr.get_prefix_expr() + && expr_has_effect_table_call_arg(prefix_expr).is_some() + { + return Some(()); + } + + if let Some(LuaIndexKey::Expr(key_expr)) = index_expr.get_index_key() { + return expr_has_effect_table_call_arg(key_expr); + } + + None + } + LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) | LuaExpr::NameExpr(_) => None, + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs new file mode 100644 index 000000000..798c5e667 --- /dev/null +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_builtin.rs @@ -0,0 +1,237 @@ +#[cfg(test)] +mod test { + use crate::{DiagnosticCode, VirtualWorkspace}; + + #[test] + fn test_builtin_pick_preserves_selected_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Pick + local picked + PickedName = picked.name + PickedAge = picked.age + PickedNickname = picked.nickname + + ---@type Pick + local pickedAll + PickedAllEmail = pickedAll.email + + ---@type Pick<{id: integer, enabled: boolean, label?: string}, "id" | "label"> + local pickedLiteral + PickedLiteralId = pickedLiteral.id + PickedLiteralLabel = pickedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("PickedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("PickedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedAllEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("PickedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("PickedLiteralLabel"), ws.ty("string?")); + } + + #[test] + fn test_builtin_pick_matches_ts6_key_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinPickConstraintUser + ---@field name string + ---@field age number + "#, + ); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick + local picked + local name = picked.name + "# + )); + } + + #[test] + fn test_builtin_pick_empty_keyof_domain_converges() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinEmptyPickClass + + ---@type Pick<{}, keyof {}> + local pickedEmptyObject + PickedEmptyObjectMissing = pickedEmptyObject.missing + + ---@type Pick + local pickedEmptyClass + PickedEmptyClassMissing = pickedEmptyClass.missing + "#, + ); + + assert_eq!(ws.expr_ty("PickedEmptyObjectMissing"), ws.ty("nil")); + assert_eq!(ws.expr_ty("PickedEmptyClassMissing"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Pick<{}, keyof {}> + local picked + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Pick<{}, keyof {}> + local picked + local missing = picked.missing + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@class BuiltinEmptyPickDiagnosticClass + + ---@type Pick + local picked + local missing = picked.missing + "# + )); + } + + #[test] + fn test_builtin_omit_removes_properties() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitUser + ---@field name string + ---@field age number + ---@field email string + ---@field nickname? string + + ---@type Omit + local omitted + OmittedName = omitted.name + OmittedAge = omitted.age + OmittedNickname = omitted.nickname + OmittedEmail = omitted.email + + ---@type Pick> + local pickedWithoutEmail + PickedWithoutEmailEmail = pickedWithoutEmail.email + + ---@type Omit<{id: integer, enabled: boolean, label?: string}, "enabled"> + local omittedLiteral + OmittedLiteralId = omittedLiteral.id + OmittedLiteralLabel = omittedLiteral.label + "#, + ); + + assert_eq!(ws.expr_ty("OmittedName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmittedAge"), ws.ty("number")); + assert_eq!(ws.expr_ty("OmittedNickname"), ws.ty("string?")); + assert_eq!(ws.expr_ty("PickedWithoutEmailEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedEmail"), ws.ty("nil")); + assert_eq!(ws.expr_ty("OmittedLiteralId"), ws.ty("integer")); + assert_eq!(ws.expr_ty("OmittedLiteralLabel"), ws.ty("string?")); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local email = omitted.email + "# + )); + } + + #[test] + fn test_builtin_omit_matches_ts6_keyof_any_behavior() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@class BuiltinOmitKeyUser + ---@field name string + ---@field age number + ---@field email string + + ---@type Omit + local omitMissing + OmitMissingName = omitMissing.name + OmitMissingEmail = omitMissing.email + + ---@type Omit + local omitNever + OmitNeverName = omitNever.name + OmitNeverEmail = omitNever.email + + ---@type Omit + local omitAll + OmitAllName = omitAll.name + "#, + ); + + assert_eq!(ws.expr_ty("OmitMissingName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitMissingEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverName"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitNeverEmail"), ws.ty("string")); + assert_eq!(ws.expr_ty("OmitAllName"), ws.ty("nil")); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@type Omit + local omitted + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::UndefinedField, + r#" + ---@type Omit + local omitted + local name = omitted.name + "# + )); + } +} diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs index 5eef7210c..226513f3e 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_infer_test.rs @@ -47,6 +47,57 @@ mod test { assert_eq!(ws.humanize_type(a_ty), "string"); } + #[test] + fn test_object_literal_infer_nested_call_argument() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return T + function identity(value) end + + A = identity(extractX({ x = 1 })) + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "integer"); + } + + #[test] + fn test_object_literal_infer_nested_call_inside_table_field() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias ExtractX T extends { x: infer X } and X or never + ---@alias ExtractInner T extends { inner: infer I } and I or never + + ---@generic T + ---@param value T + ---@return ExtractX + function extractX(value) end + + ---@generic T + ---@param value T + ---@return ExtractInner + function extractInner(value) end + + B = extractInner({ inner = extractX({ x = 1 }) }) + "#, + ); + + let b_ty = ws.expr_ty("B"); + assert_eq!(ws.humanize_type(b_ty), "integer"); + } + #[test] fn test_object_literal_infer_from_class() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 3a1b462b9..3c99430a6 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -3,8 +3,8 @@ mod test { use emmylua_parser::LuaClosureExpr; use crate::{ - DiagnosticCode, LuaSignatureId, LuaType, LuaTypeDeclId, VirtualWorkspace, - complete_type_generic_args, + DiagnosticCode, GenericTplId, LuaSignatureId, LuaType, LuaTypeDeclId, TypeSubstitutor, + VirtualWorkspace, complete_type_generic_args, instantiate_type_generic, }; #[test] @@ -412,6 +412,172 @@ mod test { )); } + #[test] + fn test_keyof_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Keys keyof T + ---@alias ForwardKeys Keys + + ---@param key "a" | "b" + function accept(key) end + "#, + ); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type ForwardKeys<{ a: string, b: number }> + local key + accept(key) + "# + )); + } + + #[test] + fn test_mapped_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + ---@alias ForwardCopy Copy + + ---@type ForwardCopy<{ a: string, b: number }> + local copy + + A = copy.a + B = copy.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_mapped_unresolved_key_domain_preserves_residual() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Copy { [K in keyof T]: T[K]; } + + ---@generic T + ---@param value Copy + ---@return Copy + function keep(value) end + + ---@type Copy<{ a: string }> + local concrete + + A = keep(concrete).a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_function_generic() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Box fun(x: T): T + + ---@type Box + local f + + Result = f(1) + "#, + ); + + let result_ty = ws.expr_ty("Result"); + assert_eq!(ws.humanize_type(result_ty), "1"); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ParamTypeMismatch, + r#" + ---@type Box + local f + + f(1) + "# + )); + } + + #[test] + fn test_alias_argument_binding_ignores_shadowing_mapped_key() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Shadow { [T in keyof { a: string }]: T; } + + ---@type Shadow + local value + + A = value.a + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty(r#""a""#)); + } + + #[test] + fn test_conditional_alias_residual_resolves_after_forwarding() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Extract T extends U and T or never + ---@alias KeepA Extract + ---@alias Forward KeepA + "#, + ); + + let generic_ty = ws.ty(r#"Forward<"a" | "b">"#); + let instantiated = + instantiate_type_generic(ws.get_db_mut(), &generic_ty, &TypeSubstitutor::new()); + assert_eq!(instantiated, ws.ty(r#""a""#)); + } + + #[test] + fn test_nested_mapped_conditional_alias_residual_resolves() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Wrapper + ---@alias UnwrapFields { [K in keyof T]: T[K] extends Wrapper and U or T[K]; } + ---@alias Forward UnwrapFields + + ---@type Forward<{ a: Wrapper, b: number }> + local value + + A = value.a + B = value.b + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + assert_eq!(ws.expr_ty("B"), ws.ty("number")); + } + + #[test] + fn test_recursive_alias_instantiation_budget_falls_back_safely() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Loop Loop + ---@alias Forward Loop + + ---@type Forward + local value + + Value = value + "#, + ); + + let value_ty = ws.expr_ty("Value"); + assert_eq!(ws.humanize_type(value_ty), "Forward"); + } + #[test] fn test_issue_787() { let mut ws = VirtualWorkspace::new(); @@ -494,9 +660,12 @@ mod test { A, B, C = f1(1, "2", true) "#, )); - assert_eq!(ws.expr_ty("A"), ws.ty("integer")); - assert_eq!(ws.expr_ty("B"), ws.ty("string")); - assert_eq!(ws.expr_ty("C"), ws.ty("boolean")); + let a_ty = ws.expr_ty("A"); + let b_ty = ws.expr_ty("B"); + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(a_ty), "1"); + assert_eq!(ws.humanize_type(b_ty), "\"2\""); + assert_eq!(ws.humanize_type(c_ty), "true"); } { ws.def( @@ -533,7 +702,8 @@ mod test { G, H = f3(1, "2") "#, )); - assert_eq!(ws.expr_ty("G"), ws.ty("integer")); + let g_ty = ws.expr_ty("G"); + assert_eq!(ws.humanize_type(g_ty), "1"); assert_eq!(ws.expr_ty("H"), ws.ty("any")); } @@ -681,7 +851,7 @@ mod test { "#, )); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "string"); + assert_eq!(ws.humanize_type(result_ty), "\"\""); } #[test] @@ -699,7 +869,7 @@ mod test { ); let result_ty = ws.expr_ty("result"); - assert_eq!(ws.humanize_type(result_ty), "integer"); + assert_eq!(ws.humanize_type(result_ty), "1"); } #[test] @@ -764,6 +934,7 @@ mod test { .expect("Box generic params"); assert_eq!(box_params.len(), 1); assert_eq!(box_params[0].name.as_str(), "T"); + assert_eq!(box_params[0].tpl_id, Some(GenericTplId::Type(0))); let box_default = box_params[0] .default_type .clone() @@ -779,6 +950,7 @@ mod test { .expect("Optional generic params"); assert_eq!(optional_params.len(), 1); assert_eq!(optional_params[0].name.as_str(), "T"); + assert_eq!(optional_params[0].tpl_id, Some(GenericTplId::Type(0))); let optional_default = optional_params[0] .default_type .clone() @@ -1075,7 +1247,7 @@ mod test { let explicit_result = ws.expr_ty("ExplicitResult"); assert_eq!(ws.humanize_type(explicit_result), "number"); let inferred_result = ws.expr_ty("InferredResult"); - assert_eq!(ws.humanize_type(inferred_result), "integer"); + assert_eq!(ws.humanize_type(inferred_result), "1"); } #[test] @@ -1217,7 +1389,7 @@ mod test { } #[test] - fn test_constant_decay() { + fn test_plain_tpl_literal_key_inference_widens_through_finalize() { let mut ws = VirtualWorkspace::new(); ws.def( r#" @@ -1250,6 +1422,207 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "integer"); } + #[test] + fn test_const_tpl_candidate_preserves_literal_through_plain_return() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias std.ConstTpl unknown + + ---@generic T + ---@param value std.ConstTpl + ---@return T + function keep_const(value) + end + + result = keep_const("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_transparent_alias_top_level_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id + function id(value) + end + + result = id("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_transparent_alias_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@alias Id T + + ---@generic T + ---@param value T + ---@return Id|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + + #[test] + fn test_plain_tpl_root_union_return_preserves_primitive_literal() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T|nil + function maybe(value) + end + + result = maybe("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), r#""mode"?"#); + } + + #[test] + fn test_plain_tpl_top_level_return_preserves_primitive_literal_union() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + ---@alias Choice "left" | "right" + + ---@type Choice + local choice + + result = id(choice) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("\"left\" | \"right\"")); + } + + #[test] + fn test_primitive_constraint_preserves_literal_candidate() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T: string + ---@param value T + ---@return T + function constrained(value) + end + + result = constrained("mode") + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(ws.humanize_type(result_ty), "\"mode\""); + } + + #[test] + fn test_finalized_table_const_self_reference_widens_without_recursing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + function id(value) + end + + local t = { kind = "mode" } + t.self = t + + result = id(t) + "#, + ); + + let result_ty = ws.expr_ty("result"); + assert_eq!(result_ty, ws.ty("{ kind: string, self: table }")); + } + + #[test] + fn test_contextual_widening_keeps_bare_literal_but_widens_nested_literals() { + use crate::{ + LuaMemberKey, LuaObjectType, WideningContext, WideningGuard, widen_type_with_context, + }; + use smol_str::SmolStr; + + let mut ws = VirtualWorkspace::new(); + let bare = LuaType::StringConst(SmolStr::new("mode").into()); + assert_eq!( + widen_type_with_context( + bare.clone(), + WideningContext::Root, + &mut WideningGuard::default() + ), + bare + ); + + let object = LuaType::Object( + LuaObjectType::new_with_fields( + [( + LuaMemberKey::Name("kind".into()), + LuaType::StringConst(SmolStr::new("mode").into()), + )] + .into_iter() + .collect(), + Vec::new(), + ) + .into(), + ); + let widened = + widen_type_with_context(object, WideningContext::Root, &mut WideningGuard::default()); + assert_eq!(widened, ws.ty("{ kind: string }")); + } + #[test] fn test_extends_true() { let mut ws = VirtualWorkspace::new(); @@ -1475,6 +1848,26 @@ mod test { assert_eq!(ws.humanize_type(result_ty), "number"); } + #[test] + fn test_distributed_function_generic_conditional_return_filters_union_members() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param value T + ---@return T extends string and T or never + function extractString(value) end + + ---@type string|integer + local value + + A = extractString(value) + "#, + ); + + assert_eq!(ws.expr_ty("A"), ws.ty("string")); + } + #[test] fn test_union_never() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/mod.rs b/crates/emmylua_code_analysis/src/compilation/test/mod.rs index a6c8629fd..3b6807bde 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/mod.rs @@ -10,6 +10,7 @@ mod decl_test; mod diagnostic_disable_test; mod flow; mod for_range_var_infer_test; +mod generic_builtin; mod generic_infer_test; mod generic_test; mod infer_str_tpl_test; diff --git a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs index 47ef6dfb4..606957d47 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/unpack_test.rs @@ -130,6 +130,44 @@ mod test { assert_eq!(ws.humanize_type(b_ty), "number"); } + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_default() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T = [string, number] + ---@return std.Unpack + function f() + end + + a, b = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + let b_ty = ws.expr_ty("b"); + assert_eq!(ws.humanize_type(a_ty), "string"); + assert_eq!(ws.humanize_type(b_ty), "number"); + } + + #[test] + fn test_unpack_alias_call_uses_uninferred_generic_constraint() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + ws.def( + r#" + ---@generic T: string[] + ---@return std.Unpack + function f() + end + + a = f() + "#, + ); + + let a_ty = ws.expr_ty("a"); + assert_eq!(ws.humanize_type(a_ty), "string?"); + } + #[test] fn test_issue_484() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs index 1a66b2031..08978d8be 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/generic_param.rs @@ -1,10 +1,11 @@ use smol_str::SmolStr; -use crate::{LuaAttributeUse, LuaType}; +use crate::{GenericTplId, LuaAttributeUse, LuaType}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct GenericParam { pub name: SmolStr, + pub tpl_id: Option, pub type_constraint: Option, pub default_type: Option, pub attributes: Option>, @@ -19,9 +20,15 @@ impl GenericParam { ) -> Self { Self { name, + tpl_id: None, type_constraint, default_type, attributes, } } + + pub fn with_tpl_id(mut self, tpl_id: Option) -> Self { + self.tpl_id = tpl_id; + self + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs index cc5594eff..c535f984a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/generic/generic_constraint_mismatch.rs @@ -1,6 +1,6 @@ use emmylua_parser::{ LuaAst, LuaAstNode, LuaCallExpr, LuaClosureExpr, LuaComment, LuaDocGenericDeclList, - LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, + LuaDocGenericType, LuaDocTagAlias, LuaDocTagClass, LuaDocTagGeneric, LuaDocTagType, LuaDocType, }; use rowan::TextRange; use smol_str::SmolStr; @@ -13,7 +13,7 @@ use crate::semantic::{ use crate::{ DiagnosticCode, DocTypeInferContext, GenericTplId, LuaArrayType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaSignatureId, LuaStringTplType, LuaTupleType, LuaType, - LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, + LuaTypeNode, LuaUnionType, RenderLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, TypeSubstitutor, VariadicType, humanize_type, infer_doc_type, instantiate_type_generic, }; @@ -617,55 +617,142 @@ fn check_doc_tag_type( let type_list = doc_tag_type.get_type_list(); let doc_ctx = DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for doc_type in type_list { - let explicit_args = explicit_generic_args(&doc_type); - if explicit_args.is_empty() { - continue; - } + check_doc_type_generic_constraints(context, semantic_model, doc_ctx, &doc_type); + } + Some(()) +} - let type_ref = infer_doc_type(doc_ctx, &doc_type); - let generic_type = match type_ref { - LuaType::Generic(generic_type) => generic_type, - _ => continue, - }; +fn check_doc_type_generic_constraints( + context: &mut DiagnosticContext, + semantic_model: &SemanticModel, + doc_ctx: DocTypeInferContext<'_>, + doc_type: &LuaDocType, +) -> Option<()> { + let LuaDocType::Generic(generic_doc_type) = doc_type else { + return Some(()); + }; + + let explicit_args = explicit_generic_args(generic_doc_type); + if explicit_args.is_empty() { + return Some(()); + } - let generic_params = semantic_model + let name = generic_doc_type.get_name_type()?.get_name_text()?; + let type_decl = semantic_model.get_db().get_type_index().find_type_decl( + semantic_model.get_file_id(), + &name, + semantic_model .get_db() - .get_type_index() - .get_generic_params(&generic_type.get_base_type_id())?; - for (i, param_type) in generic_type - .get_params() - .iter() - .take(explicit_args.len()) - .enumerate() - { - let extend_type = generic_params.get(i)?.type_constraint.clone()?; - let result = semantic_model.type_check_detail(&extend_type, param_type); - if result.is_err() { - add_type_check_diagnostic( - context, - semantic_model, - explicit_args.get(i)?.get_range(), - &extend_type, - param_type, - result, - ); + .resolve_workspace_id(semantic_model.get_file_id()), + )?; + let type_id = type_decl.get_id(); + let generic_params = semantic_model + .get_db() + .get_type_index() + .get_generic_params(&type_id)?; + + let instantiate_arg = explicit_arg_instantiation_flags(&generic_params, explicit_args.len()); + let empty_substitutor = TypeSubstitutor::new(); + let param_types = explicit_args + .iter() + .enumerate() + .map(|(idx, doc_type)| { + let ty = infer_doc_type(doc_ctx, doc_type); + if instantiate_arg.get(idx).copied().unwrap_or(false) { + instantiate_type_generic(semantic_model.get_db(), &ty, &empty_substitutor) + } else { + ty } + }) + .collect::>(); + + let substitutor = + TypeSubstitutor::from_alias(semantic_model.get_db(), param_types.clone(), type_id); + + for (i, param_type) in param_types.iter().enumerate() { + let Some(explicit_arg) = explicit_args.get(i) else { + continue; + }; + let Some(extend_type) = generic_params + .get(i) + .and_then(|param| param.type_constraint.clone()) + else { + continue; + }; + + let mut extend_type = + instantiate_type_generic(semantic_model.get_db(), &extend_type, &substitutor); + extend_type = normalize_keyof_any_constraint(extend_type); + let result = semantic_model.type_check_detail(&extend_type, param_type); + if result.is_err() { + add_type_check_diagnostic( + context, + semantic_model, + explicit_arg.get_range(), + &extend_type, + param_type, + result, + ); } } + Some(()) } -fn explicit_generic_args(doc_type: &LuaDocType) -> Vec { - let LuaDocType::Generic(generic_doc_type) = doc_type else { - return Vec::new(); - }; - +fn explicit_generic_args(generic_doc_type: &LuaDocGenericType) -> Vec { generic_doc_type .get_generic_types() .map(|type_list| type_list.get_types().collect()) .unwrap_or_default() } +fn explicit_arg_instantiation_flags( + generic_params: &[crate::GenericParam], + explicit_arg_count: usize, +) -> Vec { + let mut flags = vec![false; explicit_arg_count]; + for (constraint_index, param) in generic_params.iter().enumerate().take(explicit_arg_count) { + let Some(constraint) = param.type_constraint.as_ref() else { + continue; + }; + + flags[constraint_index] = true; + for (arg_index, referenced_param) in + generic_params.iter().enumerate().take(explicit_arg_count) + { + let tpl_id = referenced_param + .tpl_id + .unwrap_or(GenericTplId::Type(arg_index as u32)); + if type_contains_tpl_ref(constraint, tpl_id) { + flags[arg_index] = true; + } + } + } + + flags +} + +fn type_contains_tpl_ref(ty: &LuaType, tpl_id: GenericTplId) -> bool { + ty.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::StrTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + _ => false, + }) +} + +fn normalize_keyof_any_constraint(ty: LuaType) -> LuaType { + match ty { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == crate::LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + && alias_call.get_operands()[0].is_any() => + { + LuaType::from_vec(vec![LuaType::String, LuaType::Integer, LuaType::Number]) + } + _ => ty, + } +} + #[allow(clippy::too_many_arguments)] fn check_param( context: &mut DiagnosticContext, diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..e55590b35 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -33,7 +33,7 @@ pub fn build_call_constraint_context( let mut substitutor = TypeSubstitutor::new(); let generic_tpls = collect_func_tpl_ids(¶ms); if !generic_tpls.is_empty() { - substitutor.add_need_infer_tpls(generic_tpls); + substitutor.prepare_inference_slots(generic_tpls); } // 读取显式传入的泛型实参 @@ -42,7 +42,7 @@ pub fn build_call_constraint_context( DocTypeInferContext::new(semantic_model.get_db(), semantic_model.get_file_id()); for (idx, doc_type) in type_list.get_types().enumerate() { let ty = infer_doc_type(doc_ctx, &doc_type); - substitutor.insert_type(GenericTplId::Func(idx as u32), ty, true); + substitutor.bind_type(GenericTplId::Func(idx as u32), ty); } } @@ -261,16 +261,16 @@ fn record_generic_assignment( match param_type { LuaType::TplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), true); + substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::ConstTplRef(tpl_ref) => { if !tpl_ref.get_tpl_id().is_conditional_infer() { - substitutor.insert_type(tpl_ref.get_tpl_id(), arg_type.clone(), false); + substitutor.bind_type(tpl_ref.get_tpl_id(), arg_type.clone()); } } LuaType::StrTplRef(str_tpl_ref) => { - substitutor.insert_type(str_tpl_ref.get_tpl_id(), arg_type.clone(), true); + substitutor.bind_type(str_tpl_ref.get_tpl_id(), arg_type.clone()); } LuaType::Variadic(variadic) => { if let Some(inner) = variadic.get_type(0) { diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..de85aa424 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -96,7 +96,7 @@ fn complete_type_generic_args_inner( for (idx, generic_param) in generic_params.iter().enumerate() { if let Some(provided_arg) = provided_args.get(idx) { let provided_arg = provided_arg.clone(); - substitutor.insert_type(GenericTplId::Type(idx as u32), provided_arg.clone(), true); + substitutor.bind_type(GenericTplId::Type(idx as u32), provided_arg.clone()); params.push(provided_arg); continue; } @@ -115,7 +115,7 @@ fn complete_type_generic_args_inner( completed_type.ty }; let instantiated = instantiate_type_generic(db, &default_type, &substitutor); - substitutor.insert_type(GenericTplId::Type(idx as u32), instantiated.clone(), true); + substitutor.bind_type(GenericTplId::Type(idx as u32), instantiated.clone()); params.push(instantiated); } else { missing_required_count += 1; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs similarity index 85% rename from crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs rename to crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs index 1c8e39996..b58290125 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/infer_call_func_generic.rs @@ -12,7 +12,6 @@ use crate::{ semantic::{ LuaInferCache, generic::{ - instantiate_type::instantiate_doc_function, tpl_context::TplContext, tpl_pattern::{ multi_param_tpl_pattern_match_multi_return, return_type_pattern_match_target_type, @@ -21,18 +20,19 @@ use crate::{ }, infer::InferFailReason, infer_expr, - overload_resolve::{callable_accepts_args, resolve_signature_by_args}, + overload_resolve::{ + callable_accepts_args, collect_callable_overload_groups, resolve_signature_by_args, + }, }, }; use crate::{ GenericTpl, LuaMemberOwner, LuaSemanticDeclId, LuaTypeOwner, SemanticDeclLevel, TypeVisitTrait, - collect_callable_overload_groups, infer_node_semantic_decl, - tpl_pattern_match_args_skip_unknown, + infer_node_semantic_decl, tpl_pattern_match_args_skip_unknown, }; use super::{TypeSubstitutor, instantiate_type_generic}; -pub fn instantiate_func_generic( +pub fn infer_call_func_generic( db: &DbIndex, cache: &mut LuaInferCache, func: &LuaFunctionType, @@ -53,36 +53,35 @@ pub fn instantiate_func_generic( .get_args() .collect::>(); let mut substitutor = TypeSubstitutor::new(); - let mut context = TplContext { - db, - cache, - substitutor: &mut substitutor, - call_expr: Some(call_expr.clone()), - }; - if !generic_tpls.is_empty() { - context.substitutor.add_need_infer_tpls(generic_tpls); + { + let mut context = TplContext::new(db, cache, &mut substitutor, Some(call_expr.clone())); + if !generic_tpls.is_empty() { + context.substitutor.prepare_inference_slots(generic_tpls); - if let Some(type_list) = call_expr.get_call_generic_type_list() { - // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 - apply_call_generic_type_list(db, file_id, &mut context, &type_list); - } else { - // 如果没有指定泛型, 则需要从调用参数中推断 - infer_generic_types_from_call( - db, - &mut context, - func, - &call_expr, - &mut func_params, - &arg_exprs, - )?; + if let Some(type_list) = call_expr.get_call_generic_type_list() { + // 如果使用了`obj:abc--[[@]]("abc")`强制指定了泛型, 那么我们只需要直接应用 + apply_call_generic_type_list(db, file_id, &mut context, &type_list); + } else { + // 如果没有指定泛型, 则需要从调用参数中推断 + infer_generic_types_from_call( + db, + &mut context, + func, + &call_expr, + &mut func_params, + &arg_exprs, + )?; + } } } if contain_self && let Some(self_type) = infer_self_type(db, cache, &call_expr) { substitutor.add_self_type(self_type); } + substitutor.finalize_inferred_types(db, func_generic_tpls(func).iter(), func.get_ret()); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, func, &substitutor) { + let func_ty = LuaType::DocFunction(func.clone().into()); + if let LuaType::DocFunction(f) = instantiate_type_generic(db, &func_ty, &substitutor) { Ok(f.deref().clone()) } else { Ok(func.clone()) @@ -100,11 +99,11 @@ fn apply_call_generic_type_list( let typ = infer_doc_type(doc_ctx, &doc_type); context .substitutor - .insert_type(GenericTplId::Func(i as u32), typ, true); + .bind_type(GenericTplId::Func(i as u32), typ); } } -pub fn as_doc_function_type( +fn as_doc_function_type( db: &DbIndex, callable_type: &LuaType, ) -> Result>, InferFailReason> { @@ -197,7 +196,7 @@ fn uses_erased_function_param(callable: &LuaFunctionType, call_arg_types: &[LuaT }) } -pub fn infer_callable_return_from_remaining_args( +fn infer_callable_return_from_remaining_args( context: &mut TplContext, callable_type: &LuaType, arg_exprs: &[LuaExpr], @@ -252,27 +251,36 @@ fn instantiate_callable_from_arg_types( .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls.clone()); - let mut callable_context = TplContext { - db: context.db, - cache: context.cache, - substitutor: &mut callable_substitutor, - call_expr: context.call_expr.clone(), - }; - if tpl_pattern_match_args_skip_unknown( - &mut callable_context, - &callable_param_types, - call_arg_types, - ) - .is_err() + callable_substitutor.prepare_inference_slots(callable_tpls.clone()); { - return None; + let mut callable_context = TplContext::new( + context.db, + context.cache, + &mut callable_substitutor, + context.call_expr.clone(), + ); + if tpl_pattern_match_args_skip_unknown( + &mut callable_context, + &callable_param_types, + call_arg_types, + ) + .is_err() + { + return None; + } } - let instantiated = match instantiate_doc_function(context.db, callable, &callable_substitutor) { - LuaType::DocFunction(func) => func, - _ => callable.clone(), - }; + callable_substitutor.finalize_inferred_types( + context.db, + callable_generic_tpls(callable).iter(), + callable.get_ret(), + ); + let callable_ty = LuaType::DocFunction(callable.clone()); + let instantiated = + match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); instantiated.get_ret().visit_type(&mut |ty| { @@ -299,9 +307,10 @@ fn instantiate_callable_from_arg_types( } for tpl_id in callback_return_tpls { - callable_substitutor.insert_type(tpl_id, LuaType::Unknown, true); + callable_substitutor.bind_type(tpl_id, LuaType::Unknown); } - match instantiate_doc_function(context.db, callable, &callable_substitutor) { + let callable_ty = LuaType::DocFunction(callable.clone()); + match instantiate_type_generic(context.db, &callable_ty, &callable_substitutor) { LuaType::DocFunction(func) => Some(func), _ => None, } @@ -377,6 +386,27 @@ fn collect_func_tpl_ids(func: &LuaFunctionType) -> (HashSet, bool) (generic_tpls, contain_self) } +fn func_generic_tpls(func: &LuaFunctionType) -> Vec> { + let mut generic_tpls = Vec::new(); + func.visit_nested_types(&mut |ty| match ty { + LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) => { + if generic_tpl.get_tpl_id().is_func() + && !generic_tpls + .iter() + .any(|it: &Arc| it.get_tpl_id() == generic_tpl.get_tpl_id()) + { + generic_tpls.push(generic_tpl.clone()); + } + } + _ => {} + }); + generic_tpls +} + +fn callable_generic_tpls(callable: &LuaFunctionType) -> Vec> { + func_generic_tpls(callable) +} + fn collect_func_tpl_with_fallback_deps( generic_tpl: &GenericTpl, generic_tpls: &mut HashSet, @@ -488,7 +518,7 @@ fn infer_generic_types_from_call( break; } - if context.substitutor.is_infer_all_tpl() { + if !context.substitutor.has_unresolved_inference_slots() { break; } @@ -511,7 +541,6 @@ fn infer_generic_types_from_call( Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), }; - if let Some(return_pattern) = as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) { @@ -553,7 +582,7 @@ fn infer_generic_types_from_call( } } - if !context.substitutor.is_infer_all_tpl() { + if context.substitutor.has_unresolved_inference_slots() { for (func_param_type, call_arg_expr) in unresolve_tpls { let closure_type = infer_expr(db, context.cache, call_arg_expr)?; @@ -573,7 +602,7 @@ pub fn build_self_type(db: &DbIndex, self_type: &LuaType) -> LuaType { for (i, generic_param) in generic.iter().enumerate() { let tpl_id = GenericTplId::Type(i as u32); let param = build_self_generic_arg(db, generic_param, &substitutor); - substitutor.insert_type(tpl_id, param.clone(), true); + substitutor.bind_type(tpl_id, param.clone()); params.push(param); } let generic = LuaGenericType::new(id.clone(), params); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs new file mode 100644 index 000000000..0764615b9 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/inference_widening.rs @@ -0,0 +1,506 @@ +use std::{ops::Deref, sync::Arc}; + +use hashbrown::{HashMap, HashSet}; +use rowan::TextRange; + +use crate::{ + DbIndex, GenericParam, GenericTpl, InFiled, LuaArrayType, LuaConditionalType, LuaFunctionType, + LuaGenericType, LuaMappedType, LuaMemberKey, LuaMemberOwner, LuaObjectType, LuaTupleType, + LuaType, LuaUnionType, TypeOps, TypeSubstitutor, VariadicType, instantiate_type_generic, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(in crate::semantic::generic) enum TplCandidateSource { + Plain, + ConstPreserving, + Finalized, +} + +pub(in crate::semantic::generic) fn finalize_inferred_tpl_candidate( + db: &DbIndex, + tpl: &GenericTpl, + raw_candidate: &LuaType, + candidate_source: TplCandidateSource, + top_level: bool, + return_top_level: bool, + substitutor: &TypeSubstitutor, +) -> LuaType { + if candidate_source == TplCandidateSource::ConstPreserving { + return raw_candidate.clone(); + } + + let primitive_constraint = tpl + .get_constraint() + .map(|constraint| { + let constraint = instantiate_type_generic(db, constraint, substitutor); + is_primitive_or_literal_type(&constraint) + }) + .unwrap_or(false); + let candidate = if primitive_constraint || !top_level || return_top_level { + raw_candidate.clone() + } else { + widen_literal_type(raw_candidate.clone()) + }; + finalize_tpl_candidate_type( + db, + candidate, + WideningContext::Root, + &mut WideningGuard::default(), + ) +} + +fn is_primitive_or_literal_type(ty: &LuaType) -> bool { + match ty { + LuaType::String + | LuaType::Number + | LuaType::Integer + | LuaType::Boolean + | LuaType::StringConst(_) + | LuaType::DocStringConst(_) + | LuaType::IntegerConst(_) + | LuaType::DocIntegerConst(_) + | LuaType::FloatConst(_) + | LuaType::BooleanConst(_) + | LuaType::DocBooleanConst(_) => true, + LuaType::Tuple(tuple) => tuple.get_types().iter().any(is_primitive_or_literal_type), + LuaType::Union(union) => union.into_vec().iter().any(is_primitive_or_literal_type), + LuaType::MultiLineUnion(union) => union + .get_unions() + .iter() + .any(|(ty, _)| is_primitive_or_literal_type(ty)), + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => is_primitive_or_literal_type(base), + VariadicType::Multi(types) => types.iter().any(is_primitive_or_literal_type), + }, + LuaType::Call(call) => call.get_operands().iter().any(is_primitive_or_literal_type), + _ => false, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WideningContext { + Root, + UnionMember, + ObjectProperty, + ArrayElement, + TupleElement, + VariadicElement, +} + +const MAX_WIDENING_DEPTH: u16 = 100; + +#[derive(Default)] +pub struct WideningGuard { + depth: u16, + active_table_ids: HashSet>, +} + +impl WideningGuard { + fn enter_level(&mut self) -> bool { + if self.depth >= MAX_WIDENING_DEPTH { + return false; + } + self.depth += 1; + true + } + + fn leave_level(&mut self) { + self.depth = self.depth.saturating_sub(1); + } + + fn enter_table(&mut self, table_id: &InFiled) -> bool { + self.active_table_ids.insert(table_id.clone()) + } + + fn leave_table(&mut self, table_id: &InFiled) { + self.active_table_ids.remove(table_id); + } +} + +fn finalize_tpl_candidate_type( + db: &DbIndex, + ty: LuaType, + context: WideningContext, + guard: &mut WideningGuard, +) -> LuaType { + if !guard.enter_level() { + return match ty { + LuaType::TableConst(_) => LuaType::Table, + ty => widen_literals_with_context(ty, context), + }; + } + + let widened = match ty { + LuaType::TableConst(table_id) => { + table_const_to_object(db, table_id, guard).unwrap_or(LuaType::Table) + } + LuaType::Object(object) => { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + finalize_tpl_candidate_type( + db, + ty.clone(), + WideningContext::ObjectProperty, + guard, + ), + ) + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + widen_type_with_context( + key.clone(), + WideningContext::ObjectProperty, + guard, + ), + finalize_tpl_candidate_type( + db, + value.clone(), + WideningContext::ObjectProperty, + guard, + ), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + LuaType::Array(array) => { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = + finalize_tpl_candidate_type(db, array.get_base().clone(), element_context, guard); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } + LuaType::Tuple(tuple) => { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| finalize_tpl_candidate_type(db, ty, WideningContext::TupleElement, guard)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + LuaType::Union(union) => { + let member_context = if matches!(context, WideningContext::Root) { + WideningContext::Root + } else { + WideningContext::UnionMember + }; + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| finalize_tpl_candidate_type(db, ty, member_context, guard)) + .collect(), + ) + .into(), + ) + } + ty => widen_type_with_context(ty, context, guard), + }; + + guard.leave_level(); + widened +} + +pub fn widen_type_with_context( + ty: LuaType, + context: WideningContext, + guard: &mut WideningGuard, +) -> LuaType { + if !guard.enter_level() { + return widen_literals_with_context(ty, context); + } + + let ty = widen_literals_with_context(ty, context); + + let widened = match ty { + LuaType::Array(array) => { + let element_context = match context { + WideningContext::TupleElement => WideningContext::TupleElement, + _ => WideningContext::ArrayElement, + }; + let base = widen_type_with_context(array.get_base().clone(), element_context, guard); + LuaType::Array(LuaArrayType::new(base, array.get_len().clone()).into()) + } + LuaType::Tuple(tuple) => { + let types = tuple + .get_types() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::TupleElement, guard)) + .collect(); + LuaType::Tuple(LuaTupleType::new(types, tuple.status).into()) + } + LuaType::Object(object) => { + let fields = object + .get_fields() + .iter() + .map(|(key, ty)| { + ( + key.clone(), + widen_type_with_context(ty.clone(), WideningContext::ObjectProperty, guard), + ) + }) + .collect(); + let index_access = object + .get_index_access() + .iter() + .map(|(key, value)| { + ( + widen_type_with_context( + key.clone(), + WideningContext::ObjectProperty, + guard, + ), + widen_type_with_context( + value.clone(), + WideningContext::ObjectProperty, + guard, + ), + ) + }) + .collect(); + LuaType::Object(LuaObjectType::new_with_fields(fields, index_access).into()) + } + LuaType::Union(union) => { + let member_context = if matches!(context, WideningContext::Root) { + WideningContext::Root + } else { + WideningContext::UnionMember + }; + LuaType::Union( + LuaUnionType::from_vec( + union + .into_vec() + .into_iter() + .map(|ty| widen_type_with_context(ty, member_context, guard)) + .collect(), + ) + .into(), + ) + } + LuaType::MultiLineUnion(multi) => LuaType::MultiLineUnion( + crate::LuaMultiLineUnion::new( + multi + .get_unions() + .iter() + .map(|(ty, description)| { + ( + widen_type_with_context( + ty.clone(), + WideningContext::UnionMember, + guard, + ), + description.clone(), + ) + }) + .collect(), + ) + .into(), + ), + LuaType::Intersection(intersection) => LuaType::Intersection( + crate::LuaIntersectionType::new( + intersection + .get_types() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::UnionMember, guard)) + .collect(), + ) + .into(), + ), + LuaType::Variadic(variadic) => LuaType::Variadic( + match variadic.deref() { + VariadicType::Base(base) => VariadicType::Base(widen_type_with_context( + base.clone(), + WideningContext::VariadicElement, + guard, + )), + VariadicType::Multi(types) => VariadicType::Multi( + types + .iter() + .cloned() + .map(|ty| { + widen_type_with_context(ty, WideningContext::VariadicElement, guard) + }) + .collect(), + ), + } + .into(), + ), + LuaType::Generic(generic) => LuaType::Generic( + LuaGenericType::new( + generic.get_base_type_id(), + generic + .get_params() + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) + .collect(), + ) + .into(), + ), + LuaType::TableGeneric(params) => LuaType::TableGeneric( + params + .iter() + .cloned() + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)) + .collect::>() + .into(), + ), + LuaType::DocFunction(func) => LuaType::DocFunction( + LuaFunctionType::new( + func.get_async_state(), + func.is_colon_define(), + func.is_variadic(), + func.get_params() + .iter() + .map(|(name, ty)| { + ( + name.clone(), + ty.clone().map(|ty| { + widen_type_with_context(ty, WideningContext::Root, guard) + }), + ) + }) + .collect(), + widen_type_with_context(func.get_ret().clone(), WideningContext::Root, guard), + ) + .into(), + ), + LuaType::TypeGuard(type_guard) => LuaType::TypeGuard( + widen_type_with_context(type_guard.deref().clone(), WideningContext::Root, guard) + .into(), + ), + LuaType::Conditional(conditional) => LuaType::Conditional( + LuaConditionalType::new( + widen_type_with_context( + conditional.get_checked_type().clone(), + WideningContext::Root, + guard, + ), + widen_type_with_context( + conditional.get_extends_type().clone(), + WideningContext::Root, + guard, + ), + widen_type_with_context( + conditional.get_true_type().clone(), + WideningContext::Root, + guard, + ), + widen_type_with_context( + conditional.get_false_type().clone(), + WideningContext::Root, + guard, + ), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ), + LuaType::Mapped(mapped) => LuaType::Mapped(Arc::new(LuaMappedType::new( + ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .clone() + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), + mapped + .param + .1 + .default_type + .clone() + .map(|ty| widen_type_with_context(ty, WideningContext::Root, guard)), + mapped.param.1.attributes.clone(), + ), + ), + widen_type_with_context(mapped.value.clone(), WideningContext::Root, guard), + mapped.is_readonly, + mapped.is_optional, + ))), + ty => ty, + }; + + guard.leave_level(); + widened +} + +fn widen_literals_with_context(ty: LuaType, context: WideningContext) -> LuaType { + match context { + WideningContext::Root => ty, + _ => widen_literal_type(ty), + } +} + +fn widen_literal_type(ty: LuaType) -> LuaType { + match ty { + LuaType::FloatConst(_) => LuaType::Number, + LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, + LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, + LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, + ty => ty, + } +} + +fn table_const_to_object( + db: &DbIndex, + table_id: InFiled, + guard: &mut WideningGuard, +) -> Option { + let owner = LuaMemberOwner::Element(table_id.clone()); + let members = db.get_member_index().get_members(&owner)?; + if !guard.enter_table(&table_id) { + return Some(LuaType::Table); + } + let mut fields = HashMap::new(); + let mut index_access = Vec::new(); + + for member in members { + let value = db + .get_type_index() + .get_type_cache(&member.get_id().into()) + .map(|cache| cache.as_type().clone()) + .unwrap_or(LuaType::Unknown); + let value = finalize_tpl_candidate_type(db, value, WideningContext::ObjectProperty, guard); + + match member.get_key() { + LuaMemberKey::Name(_) | LuaMemberKey::Integer(_) => { + fields + .entry(member.get_key().clone()) + .and_modify(|prev| { + *prev = TypeOps::Union.apply(db, prev, &value); + }) + .or_insert(value); + } + LuaMemberKey::ExprType(key) => { + index_access.push(( + widen_type_with_context(key.clone(), WideningContext::ObjectProperty, guard), + value, + )); + } + LuaMemberKey::None => {} + } + } + + guard.leave_table(&table_id); + + Some(LuaType::Object( + LuaObjectType::new_with_fields(fields, index_access).into(), + )) +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index 226a0fb4f..bfb24ad50 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -1,23 +1,24 @@ use hashbrown::{HashMap, HashSet}; -use std::ops::Deref; +use internment::ArcIntern; use crate::{ - DbIndex, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, + DbIndex, GenericTpl, GenericTplId, LuaConditionalType, LuaTypeDeclId, LuaTypeNode, TypeOps, check_type_compact, db_index::{LuaObjectType, LuaTupleType, LuaType}, semantic::{member::find_members_with_key, type_check::check_type_compact_with_level}, }; use super::{ - get_default_constructor, instantiate_type_generic, instantiate_type_generic_with_context, + TplCandidateSource, finalize_inferred_tpl_candidate, get_default_constructor, + instantiate_type_generic_inner, +}; +use crate::semantic::generic::type_substitutor::{ + GenericInstantiateContext, GenericInstantiateFrame, TplBinding, }; -use crate::semantic::generic::type_substitutor::GenericInstantiateContext; #[derive(Debug, Clone, Copy)] enum InferVariance { - // 协变 Covariant, - // 逆变 Contravariant, } @@ -38,27 +39,35 @@ struct InferCandidateSet { pub(super) fn instantiate_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { - if let Some(distributed) = instantiate_distributed_conditional(context, conditional) { + let Some(frame) = frame.enter() else { + return instantiate_conditional_residual(context, frame, conditional, None, None); + }; + + if let Some(distributed) = instantiate_distributed_conditional(context, frame, conditional) { return distributed; } - instantiate_conditional_once(context, conditional) + instantiate_conditional_once(context, frame, conditional) } fn instantiate_conditional_once( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> LuaType { let left_type = instantiate_conditional_operand( context, + frame, conditional.get_checked_type(), true, conditional.has_new, ); let right_type = instantiate_conditional_operand( context, + frame, conditional.get_extends_type(), false, conditional.has_new, @@ -78,115 +87,167 @@ fn instantiate_conditional_once( ) { instantiate_true_branch( context, + frame, conditional, - finalize_infer_assignments(infer_assignments), + finalize_infer_assignments(context, conditional, infer_assignments), ) - } else { - instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, + } else if is_deferred_conditional_operand(&left_type) + || right_type.any_type(|inner| match inner { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + !tpl.get_tpl_id().is_conditional_infer() + } + LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) => true, + _ => false, + }) + { + instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), ) + } else { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) }; } match check_conditional_extends(context.db, &left_type, &right_type) { - ConditionalCheck::True => instantiate_true_branch(context, conditional, HashMap::new()), - ConditionalCheck::False => instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ), + ConditionalCheck::True => { + instantiate_true_branch(context, frame, conditional, HashMap::new()) + } + ConditionalCheck::False => { + instantiate_type_generic_inner(context, frame, conditional.get_false_type()) + } ConditionalCheck::Both => { - let true_type = instantiate_true_branch(context, conditional, HashMap::new()); - let false_type = instantiate_type_generic( - context.db, - conditional.get_false_type(), - context.substitutor, - ); + if is_deferred_conditional_operand(&left_type) + || is_deferred_conditional_operand(&right_type) + { + return instantiate_conditional_residual( + context, + frame, + conditional, + Some(left_type), + Some(right_type), + ); + } + let true_type = instantiate_true_branch(context, frame, conditional, HashMap::new()); + let false_type = + instantiate_type_generic_inner(context, frame, conditional.get_false_type()); TypeOps::Union.apply(context.db, &true_type, &false_type) } } } +fn instantiate_conditional_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + conditional: &LuaConditionalType, + checked_type: Option, + extends_type: Option, +) -> LuaType { + let instantiate_branch = |branch: &LuaType| { + if branch.any_type(|ty| match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { + context.substitutor.get(tpl.get_tpl_id()).is_some() + } + LuaType::SelfInfer => context.substitutor.get_self_type().is_some(), + _ => false, + }) { + instantiate_type_generic_inner(context, frame, branch) + } else { + branch.clone() + } + }; + + LuaType::Conditional( + LuaConditionalType::new( + checked_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_checked_type()) + }), + extends_type.unwrap_or_else(|| { + instantiate_type_generic_inner(context, frame, conditional.get_extends_type()) + }), + instantiate_branch(conditional.get_true_type()), + instantiate_branch(conditional.get_false_type()), + conditional.get_infer_params().to_vec(), + conditional.has_new, + ) + .into(), + ) +} + /// 处理分布式条件类型, 与`TS`中的分布式条件类型处理方式相同, 只有裸模版参数才会被分布式. fn instantiate_distributed_conditional( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, ) -> Option { - let tpl_id = naked_checked_type_tpl_id(conditional.get_checked_type())?; + let tpl_id = match conditional.get_checked_type() { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_type() || tpl.get_tpl_id().is_func() => + { + tpl.get_tpl_id() + } + _ => return None, + }; let raw_checked_type = context.substitutor.get_raw_type(tpl_id)?; if raw_checked_type.is_never() { return Some(LuaType::Never); } - let members = union_members(raw_checked_type)?; + let members = match &raw_checked_type { + LuaType::Union(union) => union.into_vec(), + LuaType::MultiLineUnion(multi) => multi + .get_unions() + .iter() + .map(|(member, _)| member.clone()) + .collect(), + _ => return None, + }; let mut result = LuaType::Never; for member in members { let mut member_substitutor = context.substitutor.clone(); - member_substitutor.replace_type(tpl_id, member, false); + member_substitutor.bind(tpl_id, TplBinding::ReplaceConstType(member)); let member_context = context.with_substitutor(&member_substitutor); - let member_result = instantiate_conditional_once(&member_context, conditional); + let member_result = instantiate_conditional_once(&member_context, frame, conditional); result = TypeOps::Union.apply(context.db, &result, &member_result); } Some(result) } -fn naked_checked_type_tpl_id(checked_type: &LuaType) -> Option { - match checked_type { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) if tpl.get_tpl_id().is_type() => { - Some(tpl.get_tpl_id()) - } - _ => None, - } -} - -fn union_members(ty: &LuaType) -> Option> { - match ty { - LuaType::Union(union) => Some(union.into_vec()), - LuaType::MultiLineUnion(multi) => Some( - multi - .get_unions() - .iter() - .map(|(member, _)| member.clone()) - .collect(), - ), - _ => None, - } -} - fn instantiate_true_branch( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, conditional: &LuaConditionalType, infer_assignments: HashMap, ) -> LuaType { if infer_assignments.is_empty() { - return instantiate_type_generic( - context.db, - conditional.get_true_type(), - context.substitutor, - ); + return instantiate_type_generic_inner(context, frame, conditional.get_true_type()); } let mut true_substitutor = context.substitutor.clone(); for (tpl_id, ty) in infer_assignments { - true_substitutor.insert_conditional_infer_type(tpl_id, ty); + true_substitutor.bind(tpl_id, TplBinding::ConditionalInferType(ty)); } - instantiate_type_generic(context.db, conditional.get_true_type(), &true_substitutor) + let true_context = context.with_substitutor(&true_substitutor); + instantiate_type_generic_inner(&true_context, frame, conditional.get_true_type()) } fn contains_conditional_infer(ty: &LuaType) -> bool { - ty.any_type(conditional_infer_tpl_id) -} - -fn conditional_infer_tpl_id(ty: &LuaType) -> bool { - matches!( - ty, - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) - if tpl.get_tpl_id().is_conditional_infer() - ) + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if tpl.get_tpl_id().is_conditional_infer() + ) + }) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -217,6 +278,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::True; } + if literal_extends_base_type(source, target) { + return ConditionalCheck::True; + } + if let LuaType::Union(union) = source { let mut result = ConditionalCheck::False; for member in union.into_vec() { @@ -241,6 +306,10 @@ fn check_conditional_extends(db: &DbIndex, source: &LuaType, target: &LuaType) - return ConditionalCheck::False; } + if is_deferred_conditional_operand(source) || is_deferred_conditional_operand(target) { + return ConditionalCheck::Both; + } + if check_type_compact_with_level( db, source, @@ -263,6 +332,25 @@ fn merge_conditional_check(left: ConditionalCheck, right: ConditionalCheck) -> C } } +fn literal_extends_base_type(source: &LuaType, target: &LuaType) -> bool { + matches!( + (source, target), + ( + LuaType::StringConst(_) | LuaType::DocStringConst(_), + LuaType::String + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_), + LuaType::Integer + ) | ( + LuaType::IntegerConst(_) | LuaType::DocIntegerConst(_) | LuaType::FloatConst(_), + LuaType::Number, + ) | ( + LuaType::BooleanConst(_) | LuaType::DocBooleanConst(_), + LuaType::Boolean + ) + ) +} + fn collect_infer_assignments( db: &DbIndex, source: &LuaType, @@ -645,6 +733,8 @@ fn insert_infer_assignment( } fn finalize_infer_assignments( + context: &GenericInstantiateContext, + conditional: &LuaConditionalType, assignments: HashMap, ) -> HashMap { assignments @@ -653,29 +743,51 @@ fn finalize_infer_assignments( candidates .covariant .or(candidates.contravariant) - .map(|ty| (tpl_id, ty)) + .map(|raw_candidate| { + let Some(param) = conditional.get_infer_params().get(tpl_id.get_idx()) else { + return (tpl_id, raw_candidate); + }; + + let tpl = GenericTpl::new( + tpl_id, + ArcIntern::new(param.name.clone()), + param.type_constraint.clone(), + param.default_type.clone(), + ); + ( + tpl_id, + finalize_inferred_tpl_candidate( + context.db, + &tpl, + &raw_candidate, + TplCandidateSource::ConstPreserving, + true, + true, + context.substitutor, + ), + ) + }) }) .collect() } fn instantiate_conditional_operand( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, operand: &LuaType, checked: bool, has_new: bool, ) -> LuaType { - let mut result = instantiate_type_generic_with_context(context, operand); + let mut result = instantiate_type_generic_inner(context, frame, operand); if let LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) = operand { let tpl_id = tpl_ref.get_tpl_id(); if let Some(raw) = context.substitutor.get_raw_type(tpl_id) { result = raw.clone(); - } else if checked && result.contains_tpl_node() { - result = LuaType::Unknown; + } else if checked && result.is_never() { + result = LuaType::Never; } } - result = actualize_unresolved_templates(result); - if has_new && let LuaType::Ref(id) | LuaType::Def(id) = &result && let Some(decl) = context.db.get_type_index().get_type_decl(id) @@ -688,147 +800,17 @@ fn instantiate_conditional_operand( result } -// 条件类型判定只消费已经实例化后的实际类型, 残留的普通模板引用在这里递归收敛为 `unknown`. -// `infer` pattern 也以模板引用表示, 必须保留下来供后续结构匹配绑定. -fn actualize_unresolved_templates(ty: LuaType) -> LuaType { - match ty { - LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { - if tpl.get_tpl_id().is_conditional_infer() { - // Conditional infer 是右侧 pattern 的占位孔, 不能像普通未解模板一样抹成 unknown. - LuaType::TplRef(tpl) - } else { - LuaType::Unknown - } - } - LuaType::StrTplRef(_) => LuaType::Unknown, - LuaType::Array(array) => LuaType::Array( - crate::LuaArrayType::new( - actualize_unresolved_templates(array.get_base().clone()), - array.get_len().clone(), - ) - .into(), - ), - LuaType::Tuple(tuple) => LuaType::Tuple( - LuaTupleType::new( - tuple - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - tuple.status, - ) - .into(), - ), - LuaType::DocFunction(func) => LuaType::DocFunction( - crate::LuaFunctionType::new( - func.get_async_state(), - func.is_colon_define(), - func.is_variadic(), - func.get_params() - .iter() - .map(|(name, ty)| { - (name.clone(), ty.clone().map(actualize_unresolved_templates)) - }) - .collect(), - actualize_unresolved_templates(func.get_ret().clone()), - ) - .into(), - ), - LuaType::Object(object) => LuaType::Object( - LuaObjectType::new_with_fields( - object - .get_fields() - .iter() - .map(|(key, ty)| (key.clone(), actualize_unresolved_templates(ty.clone()))) - .collect(), - object - .get_index_access() - .iter() - .map(|(key, value)| { - ( - actualize_unresolved_templates(key.clone()), - actualize_unresolved_templates(value.clone()), - ) - }) - .collect(), - ) - .into(), - ), - LuaType::Union(union) => LuaType::from_vec( - union - .into_vec() - .into_iter() - .map(actualize_unresolved_templates) - .collect(), - ), - LuaType::MultiLineUnion(multi) => LuaType::from_vec( - multi - .get_unions() - .iter() - .map(|(ty, _)| actualize_unresolved_templates(ty.clone())) - .collect(), - ), - LuaType::Intersection(intersection) => LuaType::Intersection( - crate::LuaIntersectionType::new( - intersection - .get_types() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::Generic(generic) => LuaType::Generic( - crate::LuaGenericType::new( - generic.get_base_type_id(), - generic - .get_params() - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ) - .into(), - ), - LuaType::TableGeneric(params) => LuaType::TableGeneric( - params - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect::>() - .into(), - ), - LuaType::Variadic(variadic) => LuaType::Variadic( - match variadic.deref() { - crate::VariadicType::Base(base) => { - crate::VariadicType::Base(actualize_unresolved_templates(base.clone())) - } - crate::VariadicType::Multi(types) => crate::VariadicType::Multi( - types - .iter() - .cloned() - .map(actualize_unresolved_templates) - .collect(), - ), - } - .into(), - ), - LuaType::TypeGuard(guard) => { - LuaType::TypeGuard(actualize_unresolved_templates(guard.deref().clone()).into()) - } - LuaType::Conditional(conditional) => LuaType::Conditional( - LuaConditionalType::new( - actualize_unresolved_templates(conditional.get_checked_type().clone()), - actualize_unresolved_templates(conditional.get_extends_type().clone()), - actualize_unresolved_templates(conditional.get_true_type().clone()), - actualize_unresolved_templates(conditional.get_false_type().clone()), - conditional.get_infer_params().to_vec(), - conditional.has_new, - ) - .into(), - ), - ty => ty, - } +fn is_deferred_conditional_operand(ty: &LuaType) -> bool { + ty.any_type(|inner| { + matches!( + inner, + LuaType::TplRef(_) + | LuaType::ConstTplRef(_) + | LuaType::StrTplRef(_) + | LuaType::SelfInfer + | LuaType::Conditional(_) + | LuaType::Mapped(_) + | LuaType::Call(_) + ) + }) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs new file mode 100644 index 000000000..21a9ad912 --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_mapped_type.rs @@ -0,0 +1,248 @@ +use hashbrown::{HashMap, HashSet}; +use std::ops::Deref; + +use crate::{ + GenericParam, LuaAliasCallKind, LuaMappedType, LuaMemberKey, LuaObjectType, LuaTupleStatus, + LuaTupleType, LuaType, TypeOps, VariadicType, +}; + +use super::{ + GenericInstantiateContext, GenericInstantiateFrame, instantiate_special_generic, + instantiate_type_generic_inner, key_type_to_member_key, +}; +use crate::semantic::generic::type_substitutor::TplBinding; + +pub(super) fn instantiate_mapped_type( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let Some(frame) = frame.enter() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(constraint) = mapped.param.1.type_constraint.as_ref() else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let Some(key_domain) = resolve_mapped_key_domain(context, frame, constraint) else { + return instantiate_mapped_residual(context, frame, mapped); + }; + + let empty_object = + || LuaType::Object(LuaObjectType::new_with_fields(HashMap::new(), Vec::new()).into()); + + if key_domain.keys.is_empty() { + return empty_object(); + } + + let key_count = key_domain.keys.len(); + let mut visited = HashSet::with_capacity(key_count); + let mut field_indices: HashMap = HashMap::with_capacity(key_count); + let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::with_capacity(key_count); + let mut index_access: Vec<(LuaType, LuaType)> = Vec::with_capacity(key_count); + let mut local_substitutor = context.substitutor.clone(); + + for key_ty in key_domain.keys { + if !visited.insert(key_ty.clone()) { + continue; + } + + local_substitutor.bind(mapped.param.0, TplBinding::ReplaceConstType(key_ty.clone())); + let local_context = context.with_substitutor(&local_substitutor); + let mut value_ty = instantiate_type_generic_inner(&local_context, frame, &mapped.value); + if mapped.is_optional { + value_ty = TypeOps::Union.apply(context.db, &value_ty, &LuaType::Nil); + } + + if let Some(member_key) = key_type_to_member_key(&key_ty) { + if let Some(index) = field_indices.get(&member_key).copied() { + let (_, existing) = &mut fields[index]; + let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); + *existing = merged; + } else { + field_indices.insert(member_key.clone(), fields.len()); + fields.push((member_key, value_ty)); + } + } else { + index_access.push((key_ty, value_ty)); + } + } + + if fields.is_empty() && index_access.is_empty() { + return empty_object(); + } + + if key_domain.tuple_like + && index_access.is_empty() + && let Some(types) = mapped_tuple_types(&fields) + { + return LuaType::Tuple(LuaTupleType::new(types, LuaTupleStatus::InferResolve).into()); + } + + let field_map: HashMap = fields.into_iter().collect(); + LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()) +} + +struct MappedKeyDomain { + keys: Vec, + tuple_like: bool, +} + +fn resolve_mapped_key_domain( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + constraint: &LuaType, +) -> Option { + if let LuaType::Call(alias_call) = constraint + && alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 + { + let source = instantiate_type_generic_inner(context, frame, &alias_call.get_operands()[0]); + let keys = instantiate_special_generic::get_keyof_type(context.db, &source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + return Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }); + } + + let instantiated = instantiate_type_generic_inner(context, frame, constraint); + match &instantiated { + LuaType::Call(alias_call) + if alias_call.get_call_kind() == LuaAliasCallKind::KeyOf + && alias_call.get_operands().len() == 1 => + { + let source = &alias_call.get_operands()[0]; + let keys = instantiate_special_generic::get_keyof_type(context.db, source)?; + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&keys, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + keys: atoms, + tuple_like: source.is_tuple() || matches!(source, LuaType::Variadic(_)), + }) + } + _ => { + let mut atoms = Vec::new(); + if !collect_mapped_key_atoms(&instantiated, &mut atoms) { + return None; + } + Some(MappedKeyDomain { + tuple_like: instantiated.is_tuple(), + keys: atoms, + }) + } + } +} + +fn instantiate_mapped_residual( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + mapped: &LuaMappedType, +) -> LuaType { + let param = ( + mapped.param.0, + GenericParam::new( + mapped.param.1.name.clone(), + mapped + .param + .1 + .type_constraint + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped + .param + .1 + .default_type + .as_ref() + .map(|ty| instantiate_type_generic_inner(context, frame, ty)), + mapped.param.1.attributes.clone(), + ), + ); + + LuaType::Mapped( + LuaMappedType::new( + param, + instantiate_type_generic_inner(context, frame, &mapped.value), + mapped.is_readonly, + mapped.is_optional, + ) + .into(), + ) +} + +fn mapped_tuple_types(fields: &[(LuaMemberKey, LuaType)]) -> Option> { + let mut indexed = fields + .iter() + .filter_map(|(key, ty)| match key { + LuaMemberKey::Integer(i) => Some((*i, ty.clone())), + _ => None, + }) + .collect::>(); + + if indexed.len() != fields.len() { + return None; + } + + indexed.sort_by_key(|(index, _)| *index); + let starts_at_zero = indexed.first().is_some_and(|(index, _)| *index == 0); + let expected_start = if starts_at_zero { 0 } else { 1 }; + for (offset, (index, _)) in indexed.iter().enumerate() { + if *index != expected_start + offset as i64 { + return None; + } + } + + Some(indexed.into_iter().map(|(_, ty)| ty).collect()) +} + +fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) -> bool { + match key_ty { + LuaType::Union(union) => { + for member in union.into_vec() { + if !collect_mapped_key_atoms(&member, acc) { + return false; + } + } + true + } + LuaType::MultiLineUnion(multi) => { + for (member, _) in multi.get_unions() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), + VariadicType::Multi(types) => { + for member in types { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + }, + LuaType::Tuple(tuple) => { + for member in tuple.get_types() { + if !collect_mapped_key_atoms(member, acc) { + return false; + } + } + true + } + LuaType::Never => true, + LuaType::Unknown | LuaType::Call(_) | LuaType::Mapped(_) => false, + _ => { + acc.push(key_ty.clone()); + true + } + } +} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs index 57fae183d..80f10166f 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_special_generic.rs @@ -1,6 +1,6 @@ use crate::{ DbIndex, LuaAliasCallKind, LuaAliasCallType, LuaMemberInfo, LuaMemberKey, LuaObjectType, - LuaTupleStatus, LuaTupleType, LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, + LuaType, LuaTypeNode, TypeOps, VariadicType, get_member_map, semantic::{ generic::key_type_to_member_key, member::{find_members, infer_raw_member_type}, @@ -10,16 +10,20 @@ use crate::{ use hashbrown::HashMap; use std::{ops::Deref, vec}; -use super::{GenericInstantiateContext, TypeSubstitutor, instantiate_type_generic_with_context}; +use super::{ + GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, + instantiate_type_generic_inner, +}; pub(super) fn instantiate_alias_call( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, alias_call: &LuaAliasCallType, ) -> LuaType { let operand_exprs = alias_call.get_operands(); let operands = operand_exprs .iter() - .map(|it| instantiate_type_generic_with_context(context, it)) + .map(|it| instantiate_type_generic_inner(context, frame, it)) .collect::>(); match alias_call.get_call_kind() { @@ -42,16 +46,12 @@ pub(super) fn instantiate_alias_call( return LuaType::Unknown; } - let members = get_keyof_members(context.db, &operands[0]).unwrap_or_default(); - let member_key_types = members - .iter() - .filter_map(|m| match &m.key { - LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), - LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), - _ => None, - }) - .collect::>(); - LuaType::Tuple(LuaTupleType::new(member_key_types, LuaTupleStatus::InferResolve).into()) + match get_keyof_type(context.db, &operands[0]) { + Some(key_type) => key_type, + None => { + LuaType::Call(LuaAliasCallType::new(LuaAliasCallKind::KeyOf, operands).into()) + } + } } // 条件类型不在此处理 LuaAliasCallKind::Extends => { @@ -76,7 +76,10 @@ pub(super) fn instantiate_alias_call( instantiate_select_call(&operands[0], &operands[1]) } - LuaAliasCallKind::Unpack => instantiate_unpack_call(context.db, &operands), + LuaAliasCallKind::Unpack => { + let operands = resolve_unpack_operands(context, frame, operand_exprs); + instantiate_unpack_call(context.db, &operands) + } LuaAliasCallKind::RawGet => { if operands.len() != 2 { return LuaType::Unknown; @@ -101,6 +104,29 @@ pub(super) fn instantiate_alias_call( } } +pub(super) fn get_keyof_type(db: &DbIndex, ty: &LuaType) -> Option { + let members = get_keyof_members(db, ty)?; + let member_key_types = members + .iter() + .filter_map(|m| match &m.key { + LuaMemberKey::Integer(i) => Some(LuaType::DocIntegerConst(*i)), + LuaMemberKey::Name(s) => Some(LuaType::DocStringConst(s.clone().into())), + LuaMemberKey::ExprType(typ) => Some(typ.clone()), + _ => None, + }) + .collect::>(); + + if member_key_types.is_empty() { + if members.is_empty() { + return Some(LuaType::Never); + } + + return None; + } + + Some(LuaType::from_vec(member_key_types)) +} + fn instantiate_merge_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.len() != 2 { return LuaType::Unknown; @@ -215,6 +241,48 @@ fn instantiate_select_call(source: &LuaType, index: &LuaType) -> LuaType { } } +fn resolve_unpack_operands( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + operand_exprs: &[LuaType], +) -> Vec { + operand_exprs + .iter() + .enumerate() + .map(|(index, operand)| { + if index != 0 { + return instantiate_type_generic_inner(context, frame, operand); + } + let raw = match operand { + LuaType::TplRef(tpl_ref) | LuaType::ConstTplRef(tpl_ref) => context + .substitutor + .get(tpl_ref.get_tpl_id()) + .and_then(|value| match value { + SubstitutorValue::None => None, + SubstitutorValue::Type { value, .. } => Some(value.raw().clone()), + SubstitutorValue::MultiTypes { values, .. } => Some(LuaType::Variadic( + VariadicType::Multi( + values.iter().map(|value| value.raw().clone()).collect(), + ) + .into(), + )), + SubstitutorValue::Params(params) => Some( + params + .first() + .unwrap_or(&(String::new(), None)) + .1 + .clone() + .unwrap_or(LuaType::Unknown), + ), + SubstitutorValue::MultiBase(base) => Some(base.clone()), + }), + _ => None, + }; + raw.unwrap_or_else(|| instantiate_type_generic_inner(context, frame, operand)) + }) + .collect() +} + fn instantiate_unpack_call(db: &DbIndex, operands: &[LuaType]) -> LuaType { if operands.is_empty() { return LuaType::Unknown; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index ae9d07e77..c15dd4c44 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -1,141 +1,83 @@ mod complete_generic_args; +mod infer_call_func_generic; +mod inference_widening; mod instantiate_conditional_generic; -mod instantiate_func_generic; +mod instantiate_mapped_type; mod instantiate_special_generic; -use hashbrown::{HashMap, HashSet}; -use std::{ops::Deref, sync::Arc}; +use hashbrown::HashMap; +use std::ops::Deref; use crate::{ - DbIndex, GenericTpl, GenericTplId, LuaArrayType, LuaMappedType, LuaMemberKey, - LuaOperatorMetaMethod, LuaSignatureId, LuaTupleStatus, LuaTupleType, LuaTypeDeclId, - LuaTypeNode, TypeOps, + DbIndex, GenericTpl, LuaArrayType, LuaMemberKey, LuaOperatorMetaMethod, LuaSignatureId, + LuaTupleType, LuaTypeDeclId, LuaTypeNode, db_index::{ LuaFunctionType, LuaGenericType, LuaIntersectionType, LuaObjectType, LuaType, LuaUnionType, VariadicType, }, - semantic::infer::InferFailReason, }; use super::type_substitutor::{ - GenericInstantiateContext, SubstitutorValue, TypeSubstitutor, UninferredTplPolicy, + GenericInstantiateContext, GenericInstantiateFrame, SubstitutorValue, TypeSubstitutor, + UninferredTplPolicy, }; pub use complete_generic_args::{ GenericArgumentCompletion, complete_type_generic_args, complete_type_generic_args_in_type, }; -pub use instantiate_func_generic::{build_self_type, infer_self_type, instantiate_func_generic}; +pub use infer_call_func_generic::{build_self_type, infer_call_func_generic, infer_self_type}; +pub(in crate::semantic::generic) use inference_widening::{ + TplCandidateSource, finalize_inferred_tpl_candidate, +}; +pub use inference_widening::{WideningContext, WideningGuard, widen_type_with_context}; +use instantiate_mapped_type::instantiate_mapped_type as instantiate_mapped_type_inner; pub use instantiate_special_generic::get_keyof_members; -pub(crate) fn collect_callable_overload_groups( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, -) -> Result<(), InferFailReason> { - let mut visiting_aliases = HashSet::new(); - collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) -} - -fn collect_callable_overload_groups_inner( - db: &DbIndex, - callable_type: &LuaType, - groups: &mut Vec>>, - visiting_aliases: &mut HashSet, -) -> Result<(), InferFailReason> { - match callable_type { - LuaType::Ref(type_id) | LuaType::Def(type_id) => { - let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { - return Ok(()); - }; - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - - let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(type_id); - result?; - } - LuaType::Generic(generic) => { - let type_id = generic.get_base_type_id(); - if !visiting_aliases.insert(type_id.clone()) { - return Ok(()); - } - let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); - let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { - visiting_aliases.remove(&type_id); - return Ok(()); - }; - - let result = if let Some(origin_type) = - type_decl.get_alias_origin(db, Some(&substitutor)) - { - collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) - } else { - Ok(()) - }; - visiting_aliases.remove(&type_id); - result?; - } - LuaType::Union(union) => { - for member in union.into_vec() { - collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; - } - } - LuaType::Intersection(intersection) => { - for member in intersection.get_types() { - collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; - } - } - LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), - LuaType::Signature(sig_id) => { - let Some(signature) = db.get_signature_index().get(sig_id) else { - return Ok(()); - }; - let mut overloads = signature.overloads.to_vec(); - overloads.push(signature.to_doc_func_type()); - groups.push(overloads); - } - _ => {} - } - - Ok(()) -} - pub fn instantiate_type_generic( db: &DbIndex, ty: &LuaType, substitutor: &TypeSubstitutor, ) -> LuaType { let context = GenericInstantiateContext::new(db, substitutor); - instantiate_type_generic_with_context(&context, ty) + let frame = context.root_frame(); + match ty { + LuaType::DocFunction(doc_func) => instantiate_doc_function(&context, frame, doc_func), + _ => instantiate_type_generic_inner(&context, frame, ty), + } } -pub(super) fn instantiate_type_generic_with_context( +pub(super) fn instantiate_type_generic_inner( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, ty: &LuaType, ) -> LuaType { + let Some(frame) = frame.enter() else { + return ty.clone(); + }; + match ty { - LuaType::Array(array_type) => instantiate_array(context, array_type.get_base()), - LuaType::Tuple(tuple) => instantiate_tuple(context, tuple), - LuaType::DocFunction(doc_func) => instantiate_doc_function_with_context( - &context.with_policy(UninferredTplPolicy::PreserveTplRef), + LuaType::Array(array_type) => instantiate_array(context, frame, array_type.get_base()), + LuaType::Tuple(tuple) => instantiate_tuple(context, frame, tuple), + LuaType::DocFunction(doc_func) => instantiate_doc_function( + context, + frame.with_policy(UninferredTplPolicy::PreserveTplRef), doc_func, ), - LuaType::Object(object) => instantiate_object(context, object), - LuaType::Union(union) => instantiate_union(context, union), - LuaType::Intersection(intersection) => instantiate_intersection(context, intersection), - LuaType::Generic(generic) => instantiate_generic_with_context(context, generic), - LuaType::TableGeneric(table_params) => instantiate_table_generic(context, table_params), - LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context), - LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context), - LuaType::Signature(sig_id) => instantiate_signature(context, sig_id), + LuaType::Object(object) => instantiate_object(context, frame, object), + LuaType::Union(union) => instantiate_union(context, frame, union), + LuaType::Intersection(intersection) => { + instantiate_intersection(context, frame, intersection) + } + LuaType::Generic(generic) => instantiate_generic(context, frame, generic), + LuaType::TableGeneric(table_params) => { + instantiate_table_generic(context, frame, table_params) + } + LuaType::TplRef(tpl) => instantiate_tpl_ref(tpl, context, frame), + LuaType::ConstTplRef(tpl) => instantiate_const_tpl_ref(tpl, context, frame), + LuaType::Signature(sig_id) => instantiate_signature(context, frame, sig_id), LuaType::Call(alias_call) => { - instantiate_special_generic::instantiate_alias_call(context, alias_call) + instantiate_special_generic::instantiate_alias_call(context, frame, alias_call) } - LuaType::Variadic(variadic) => instantiate_variadic_type(context, variadic), + LuaType::Variadic(variadic) => instantiate_variadic_type(context, frame, variadic), LuaType::SelfInfer => { if let Some(typ) = context.substitutor.get_self_type() { typ.clone() @@ -144,29 +86,34 @@ pub(super) fn instantiate_type_generic_with_context( } } LuaType::TypeGuard(guard) => { - let inner = instantiate_type_generic_with_context(context, guard.deref()); + let inner = instantiate_type_generic_inner(context, frame, guard.deref()); LuaType::TypeGuard(inner.into()) } LuaType::Conditional(conditional) => { - instantiate_conditional_generic::instantiate_conditional(context, conditional) + instantiate_conditional_generic::instantiate_conditional(context, frame, conditional) } - LuaType::Mapped(mapped) => instantiate_mapped_type(context, mapped.deref()), + LuaType::Mapped(mapped) => instantiate_mapped_type_inner(context, frame, mapped.deref()), _ => ty.clone(), } } -fn instantiate_types<'a, I>(context: &GenericInstantiateContext, types: I) -> Vec +fn instantiate_types<'a, I>( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + types: I, +) -> Vec where I: IntoIterator, { types .into_iter() - .map(|ty| instantiate_type_generic_with_context(context, ty)) + .map(|ty| instantiate_type_generic_inner(context, frame, ty)) .collect() } fn instantiate_type_pairs<'a, I>( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, pairs: I, ) -> Vec<(LuaType, LuaType)> where @@ -176,19 +123,27 @@ where .into_iter() .map(|(key, value)| { ( - instantiate_type_generic_with_context(context, key), - instantiate_type_generic_with_context(context, value), + instantiate_type_generic_inner(context, frame, key), + instantiate_type_generic_inner(context, frame, value), ) }) .collect() } -fn instantiate_array(context: &GenericInstantiateContext, base: &LuaType) -> LuaType { - let base = instantiate_type_generic_with_context(context, base); +fn instantiate_array( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + base: &LuaType, +) -> LuaType { + let base = instantiate_type_generic_inner(context, frame, base); LuaType::Array(LuaArrayType::from_base_type(base).into()) } -fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) -> LuaType { +fn instantiate_tuple( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + tuple: &LuaTupleType, +) -> LuaType { let mut new_types = Vec::new(); for t in tuple.get_types() { if let LuaType::Variadic(inner) = t { @@ -198,18 +153,20 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => new_types - .push(instantiate_uninferred_tpl_fallback(tpl, context)), - SubstitutorValue::MultiTypes(types) => { - for typ in types { - new_types.push(typ.clone()); - } - } + .push(instantiate_uninferred_tpl_fallback(tpl, context, frame)), SubstitutorValue::Params(params) => { for (_, ty) in params { new_types.push(ty.clone().unwrap_or(LuaType::Unknown)); } } - SubstitutorValue::Type(ty) => new_types.push(ty.default().clone()), + SubstitutorValue::MultiTypes { values, .. } => { + new_types.extend( + values.iter().map(|value| value.resolved().clone()), + ); + } + SubstitutorValue::Type { value, .. } => { + new_types.push(value.resolved().clone()) + } SubstitutorValue::MultiBase(base) => new_types.push(base.clone()), } } else { @@ -223,23 +180,15 @@ fn instantiate_tuple(context: &GenericInstantiateContext, tuple: &LuaTupleType) break; } - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); new_types.push(t); } LuaType::Tuple(LuaTupleType::new(new_types, tuple.status).into()) } -pub fn instantiate_doc_function( - db: &DbIndex, - doc_func: &LuaFunctionType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_doc_function_with_context(&context, doc_func) -} - -fn instantiate_doc_function_with_context( +fn instantiate_doc_function( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, doc_func: &LuaFunctionType, ) -> LuaType { let tpl_func_params = doc_func.get_params(); @@ -258,19 +207,20 @@ fn instantiate_doc_function_with_context( match origin_param_type { LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Base(base) => match base { - LuaType::TplRef(tpl) => { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - let ty = instantiate_uninferred_tpl_fallback(tpl, context); + let ty = + instantiate_uninferred_tpl_fallback(tpl, context, frame); new_params.push((origin_param.0.clone(), Some(ty))); } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + SubstitutorValue::Type { value, .. } => { + let resolved_type = value.resolved().clone(); // 如果参数是 `...: T...` if origin_param.0 == "..." { // 类型是 tuple, 那么我们将展开 tuple - if let LuaType::Tuple(tuple) = resolved_type { + if let LuaType::Tuple(tuple) = &resolved_type { let base_index = new_params.len(); for (i, typ) in tuple.get_types().iter().enumerate() { let param_name = format!("var{}", base_index + i); @@ -288,7 +238,7 @@ fn instantiate_doc_function_with_context( new_params.push(( origin_param.0.clone(), Some(LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), + VariadicType::Base(resolved_type).into(), )), )); } @@ -297,10 +247,11 @@ fn instantiate_doc_function_with_context( new_params.push(param.clone()); } } - SubstitutorValue::MultiTypes(types) => { - for (i, typ) in types.iter().enumerate() { + SubstitutorValue::MultiTypes { values, .. } => { + for (i, value) in values.iter().enumerate() { let param_name = format!("var{}", i); - new_params.push((param_name, Some(typ.clone()))); + new_params + .push((param_name, Some(value.resolved().clone()))); } } _ => { @@ -318,7 +269,7 @@ fn instantiate_doc_function_with_context( } } LuaType::Generic(generic) => { - let new_type = instantiate_generic_with_context(context, generic); + let new_type = instantiate_generic(context, frame, generic); // 如果是 rest 参数且实例化后的类型是 tuple, 那么我们将展开 tuple if let LuaType::Tuple(tuple_type) = &new_type { let base_index = new_params.len(); @@ -336,13 +287,13 @@ fn instantiate_doc_function_with_context( VariadicType::Multi(_) => (), }, _ => { - let new_type = instantiate_type_generic_with_context(context, origin_param_type); + let new_type = instantiate_type_generic_inner(context, frame, origin_param_type); new_params.push((origin_param.0.clone(), Some(new_type))); } } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); + let mut inst_ret_type = instantiate_type_generic_inner(context, frame, tpl_ret); // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple if let LuaType::Variadic(_) = &&tpl_ret && let LuaType::Tuple(tuple) = &inst_ret_type @@ -380,52 +331,57 @@ fn instantiate_doc_function_with_context( ) } -fn instantiate_object(context: &GenericInstantiateContext, object: &LuaObjectType) -> LuaType { +fn instantiate_object( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + object: &LuaObjectType, +) -> LuaType { let new_fields = object .get_fields() .iter() .map(|(key, field)| { ( key.clone(), - instantiate_type_generic_with_context(context, field), + instantiate_type_generic_inner(context, frame, field), ) }) .collect::>(); - let new_index_access = instantiate_type_pairs(context, object.get_index_access().iter()); + let new_index_access = instantiate_type_pairs(context, frame, object.get_index_access().iter()); LuaType::Object(LuaObjectType::new_with_fields(new_fields, new_index_access).into()) } -fn instantiate_union(context: &GenericInstantiateContext, union: &LuaUnionType) -> LuaType { - LuaType::from_vec(instantiate_types(context, union.into_vec().iter())) +fn instantiate_union( + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, + union: &LuaUnionType, +) -> LuaType { + LuaType::from_vec(instantiate_types(context, frame, union.into_vec().iter())) } fn instantiate_intersection( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, intersection: &LuaIntersectionType, ) -> LuaType { LuaType::Intersection( - LuaIntersectionType::new(instantiate_types(context, intersection.get_types().iter())) - .into(), + LuaIntersectionType::new(instantiate_types( + context, + frame, + intersection.get_types().iter(), + )) + .into(), ) } -pub fn instantiate_generic( - db: &DbIndex, - generic: &LuaGenericType, - substitutor: &TypeSubstitutor, -) -> LuaType { - let context = GenericInstantiateContext::new(db, substitutor); - instantiate_generic_with_context(&context, generic) -} - -fn instantiate_generic_with_context( +fn instantiate_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, generic: &LuaGenericType, ) -> LuaType { let generic_params = generic.get_params(); - let new_params = instantiate_types(context, generic_params.iter()); + let new_params = instantiate_types(context, frame, generic_params.iter()); let base = generic.get_base_type(); let type_decl_id = if let LuaType::Ref(id) = base { @@ -438,9 +394,14 @@ fn instantiate_generic_with_context( && let Some(type_decl) = context.db.get_type_index().get_type_decl(&type_decl_id) && type_decl.is_alias() { - let new_substitutor = TypeSubstitutor::from_alias(new_params.clone(), type_decl_id.clone()); - if let Some(origin) = type_decl.get_alias_origin(context.db, Some(&new_substitutor)) { - return origin; + let Some(alias_context) = context.enter_alias(&type_decl_id) else { + return LuaType::Generic(LuaGenericType::new(type_decl_id, new_params).into()); + }; + let new_substitutor = + TypeSubstitutor::from_alias(context.db, new_params.clone(), type_decl_id.clone()); + let alias_context = alias_context.with_substitutor(&new_substitutor); + if let Some(origin) = type_decl.get_alias_ref() { + return instantiate_type_generic_inner(&alias_context, frame, origin); } } @@ -449,41 +410,57 @@ fn instantiate_generic_with_context( fn instantiate_table_generic( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, table_params: &[LuaType], ) -> LuaType { - LuaType::TableGeneric(instantiate_types(context, table_params.iter()).into()) + LuaType::TableGeneric(instantiate_types(context, frame, table_params.iter()).into()) } fn instantiate_uninferred_tpl_fallback( tpl: &GenericTpl, context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, ) -> LuaType { // 一些情况下需要保留 TplRef, 例如高阶函数调用 - if context.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { + if frame.should_preserve_tpl_ref() && tpl.get_default_type().is_none() { return LuaType::TplRef(tpl.clone().into()); } // 显式默认值优先, 然后是 extends 约束, 最后才是 unknown. if let Some(default_type) = tpl.get_default_type() { - return instantiate_type_generic_with_context(context, default_type); + return instantiate_type_generic_inner(context, frame, default_type); } if let Some(constraint) = tpl.get_constraint() { - return instantiate_type_generic_with_context(context, constraint); + return instantiate_type_generic_inner(context, frame, constraint); } LuaType::Unknown } -fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { +fn instantiate_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + return instantiate_uninferred_tpl_fallback(tpl, context, frame); + } + SubstitutorValue::Type { value, .. } => { + return value.resolved().clone(); } - SubstitutorValue::Type(ty) => return ty.default().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + SubstitutorValue::MultiTypes { values, .. } => { + return LuaType::Variadic( + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), + ); } SubstitutorValue::Params(params) => { return params @@ -500,15 +477,29 @@ fn instantiate_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType::TplRef(tpl.clone().into()) } -fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateContext) -> LuaType { +fn instantiate_const_tpl_ref( + tpl: &GenericTpl, + context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, +) -> LuaType { if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - return instantiate_uninferred_tpl_fallback(tpl, context); + return instantiate_uninferred_tpl_fallback(tpl, context, frame); + } + SubstitutorValue::Type { value, .. } => { + return value.resolved().clone(); } - SubstitutorValue::Type(ty) => return ty.raw().clone(), - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); + SubstitutorValue::MultiTypes { values, .. } => { + return LuaType::Variadic( + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), + ); } SubstitutorValue::Params(params) => { return params @@ -527,20 +518,22 @@ fn instantiate_const_tpl_ref(tpl: &GenericTpl, context: &GenericInstantiateConte fn instantiate_signature( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, signature_id: &LuaSignatureId, ) -> LuaType { if let Some(signature) = context.db.get_signature_index().get(signature_id) { let origin_type = { let fake_doc_function = signature.to_doc_func_type(); - instantiate_doc_function_with_context(context, &fake_doc_function) + instantiate_doc_function(context, frame, &fake_doc_function) }; if signature.overloads.is_empty() { return origin_type; } else { let mut result = Vec::new(); for overload in signature.overloads.iter() { - result.push(instantiate_doc_function_with_context( + result.push(instantiate_doc_function( context, + frame, &(*overload).clone(), )); } @@ -554,6 +547,7 @@ fn instantiate_signature( fn instantiate_variadic_type( context: &GenericInstantiateContext, + frame: GenericInstantiateFrame, variadic: &VariadicType, ) -> LuaType { match variadic { @@ -562,28 +556,34 @@ fn instantiate_variadic_type( if let Some(value) = context.substitutor.get(tpl.get_tpl_id()) { match value { SubstitutorValue::None => { - let fallback = instantiate_uninferred_tpl_fallback(tpl, context); + let fallback = instantiate_uninferred_tpl_fallback(tpl, context, frame); return match fallback { LuaType::Variadic(_) | LuaType::Never => fallback, LuaType::Nil | LuaType::Any | LuaType::Unknown => fallback, _ => LuaType::Variadic(VariadicType::Base(fallback).into()), }; } - SubstitutorValue::Type(ty) => { - let resolved_type = ty.default(); + SubstitutorValue::Type { value, .. } => { + let resolved_type = value.resolved().clone(); if matches!( resolved_type, LuaType::Nil | LuaType::Any | LuaType::Unknown | LuaType::Never ) { - return resolved_type.clone(); + return resolved_type; } + return LuaType::Variadic(VariadicType::Base(resolved_type).into()); + } + SubstitutorValue::MultiTypes { values, .. } => { return LuaType::Variadic( - VariadicType::Base(resolved_type.clone()).into(), + VariadicType::Multi( + values + .iter() + .map(|value| value.resolved().clone()) + .collect(), + ) + .into(), ); } - SubstitutorValue::MultiTypes(types) => { - return LuaType::Variadic(VariadicType::Multi(types.clone()).into()); - } SubstitutorValue::Params(params) => { let types = params .iter() @@ -600,7 +600,7 @@ fn instantiate_variadic_type( } } LuaType::Generic(generic) => { - return instantiate_generic_with_context(context, generic); + return instantiate_generic(context, frame, generic); } _ => {} }, @@ -608,7 +608,7 @@ fn instantiate_variadic_type( if types.iter().any(LuaTypeNode::contains_tpl_node) { let mut new_types = Vec::new(); for t in types { - let t = instantiate_type_generic_with_context(context, t); + let t = instantiate_type_generic_inner(context, frame, t); match t { LuaType::Never => {} LuaType::Variadic(variadic) => match variadic.deref() { @@ -630,92 +630,6 @@ fn instantiate_variadic_type( LuaType::Variadic(variadic.clone().into()) } -fn instantiate_mapped_type(context: &GenericInstantiateContext, mapped: &LuaMappedType) -> LuaType { - let constraint = mapped - .param - .1 - .type_constraint - .as_ref() - .map(|ty| instantiate_type_generic_with_context(context, ty)); - - if let Some(constraint) = constraint { - let mut key_types = Vec::new(); - collect_mapped_key_atoms(&constraint, &mut key_types); - - let mut visited = HashSet::new(); - let mut fields: Vec<(LuaMemberKey, LuaType)> = Vec::new(); - let mut index_access: Vec<(LuaType, LuaType)> = Vec::new(); - - for key_ty in key_types { - if !visited.insert(key_ty.clone()) { - continue; - } - - let value_ty = instantiate_mapped_value(context, mapped, mapped.param.0, &key_ty); - - if let Some(member_key) = key_type_to_member_key(&key_ty) { - if let Some((_, existing)) = fields.iter_mut().find(|(key, _)| key == &member_key) { - let merged = LuaType::from_vec(vec![existing.clone(), value_ty]); - *existing = merged; - } else { - fields.push((member_key, value_ty)); - } - } else { - index_access.push((key_ty, value_ty)); - } - } - - if !fields.is_empty() || !index_access.is_empty() { - // key 从 0 开始递增才被视为元组 - if constraint.is_tuple() { - let mut index = 0; - let mut is_tuple = true; - for (key, _) in &fields { - if let LuaMemberKey::Integer(i) = key { - if *i != index { - is_tuple = false; - break; - } - index += 1; - } else { - is_tuple = false; - break; - } - } - if is_tuple { - let types = fields.into_iter().map(|(_, ty)| ty).collect(); - return LuaType::Tuple( - LuaTupleType::new(types, LuaTupleStatus::InferResolve).into(), - ); - } - } - let field_map: HashMap = fields.into_iter().collect(); - return LuaType::Object(LuaObjectType::new_with_fields(field_map, index_access).into()); - } - } - - instantiate_type_generic_with_context(context, &mapped.value) -} - -fn instantiate_mapped_value( - context: &GenericInstantiateContext, - mapped: &LuaMappedType, - tpl_id: GenericTplId, - replacement: &LuaType, -) -> LuaType { - let mut local_substitutor = context.substitutor.clone(); - local_substitutor.insert_type(tpl_id, replacement.clone(), true); - let local_context = context.with_substitutor(&local_substitutor); - let mut result = instantiate_type_generic_with_context(&local_context, &mapped.value); - // 根据 readonly 和 optional 属性进行处理 - if mapped.is_optional { - result = TypeOps::Union.apply(context.db, &result, &LuaType::Nil); - } - // TODO: 处理 readonly, 但目前 readonly 的实现存在问题, 这里我们先跳过 - - result -} - pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { match key_ty { LuaType::DocStringConst(s) => Some(LuaMemberKey::Name(s.deref().clone())), @@ -726,36 +640,6 @@ pub(super) fn key_type_to_member_key(key_ty: &LuaType) -> Option { } } -fn collect_mapped_key_atoms(key_ty: &LuaType, acc: &mut Vec) { - match key_ty { - LuaType::Union(union) => { - for member in union.into_vec() { - collect_mapped_key_atoms(&member, acc); - } - } - LuaType::MultiLineUnion(multi) => { - for (member, _) in multi.get_unions() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => collect_mapped_key_atoms(base, acc), - VariadicType::Multi(types) => { - for member in types { - collect_mapped_key_atoms(member, acc); - } - } - }, - LuaType::Tuple(tuple) => { - for member in tuple.get_types() { - collect_mapped_key_atoms(member, acc); - } - } - LuaType::Unknown | LuaType::Never => {} - _ => acc.push(key_ty.clone()), - } -} - pub(super) fn get_default_constructor(db: &DbIndex, decl_id: &LuaTypeDeclId) -> Option { let ids = db .get_operator_index() diff --git a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs index 90e34baa3..582e1e9ee 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/mod.rs @@ -11,7 +11,6 @@ pub use call_constraint::{ }; use emmylua_parser::LuaAstNode; use emmylua_parser::LuaExpr; -pub(crate) use instantiate_type::collect_callable_overload_groups; pub use instantiate_type::*; use rowan::NodeOrToken; pub use tpl_context::TplContext; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/test.rs b/crates/emmylua_code_analysis/src/semantic/generic/test.rs index 21dee2f3f..be0188d9a 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/test.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/test.rs @@ -298,4 +298,37 @@ result = { "# )); } + + #[test] + fn test_123() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param x T + ---@return T + function f(x) + return x + end + + A = f("hello") + B = f({value = "hello"}) + C = B.value + "#, + ); + + let a_ty = ws.expr_ty("A"); + assert_eq!(ws.humanize_type(a_ty), "\"hello\""); + + let b_ty = ws.expr_ty("B"); + let b_desc = ws.humanize_type_detailed(b_ty); + assert!( + b_desc.contains("value: string"), + "unexpected type: {}", + b_desc + ); + + let c_ty = ws.expr_ty("C"); + assert_eq!(ws.humanize_type(c_ty), "string"); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs index 02200a4a6..bab917412 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_context.rs @@ -1,6 +1,10 @@ use emmylua_parser::LuaCallExpr; -use crate::{DbIndex, LuaInferCache, TypeSubstitutor}; +use super::instantiate_type::TplCandidateSource; +use crate::{ + DbIndex, GenericTplId, LuaInferCache, LuaType, TypeSubstitutor, + semantic::generic::type_substitutor::TplBinding, +}; #[derive(Debug)] pub struct TplContext<'a> { @@ -8,4 +12,66 @@ pub struct TplContext<'a> { pub cache: &'a mut LuaInferCache, pub substitutor: &'a mut TypeSubstitutor, pub call_expr: Option, + inference_top_level: bool, +} + +impl<'a> TplContext<'a> { + pub fn new( + db: &'a DbIndex, + cache: &'a mut LuaInferCache, + substitutor: &'a mut TypeSubstitutor, + call_expr: Option, + ) -> Self { + Self { + db, + cache, + substitutor, + call_expr, + inference_top_level: true, + } + } + + pub fn with_inference_top_level( + &mut self, + top_level: bool, + f: impl FnOnce(&mut Self) -> R, + ) -> R { + let previous = self.inference_top_level; + self.inference_top_level = previous && top_level; + let result = f(self); + self.inference_top_level = previous; + result + } + + pub(in crate::semantic::generic) fn insert_type( + &mut self, + tpl_id: GenericTplId, + replace_type: LuaType, + source: TplCandidateSource, + ) { + self.substitutor.bind( + tpl_id, + TplBinding::InferredType { + ty: replace_type, + source, + top_level: self.inference_top_level, + }, + ); + } + + pub(in crate::semantic::generic) fn insert_multi_types( + &mut self, + tpl_id: GenericTplId, + types: Vec, + source: TplCandidateSource, + ) { + self.substitutor.bind( + tpl_id, + TplBinding::InferredMultiTypes { + types, + source, + top_level: self.inference_top_level, + }, + ); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs index aa556ca88..75d498ca4 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/generic_tpl_pattern.rs @@ -1,6 +1,6 @@ use crate::{ InferFailReason, InferGuard, InferGuardRef, LuaGenericType, LuaType, LuaTypeNode, TplContext, - TypeSubstitutor, instantiate_generic, instantiate_type_generic, + TypeSubstitutor, instantiate_type_generic, semantic::generic::tpl_pattern::{ TplPatternMatchResult, tpl_pattern_match, variadic_tpl_pattern_match, }, @@ -49,6 +49,7 @@ fn generic_tpl_pattern_match_inner( .ok_or(InferFailReason::None)?; if target_decl.is_alias() { let substitutor = TypeSubstitutor::from_alias( + context.db, target_generic.get_params().clone(), target_base.clone(), ); @@ -125,7 +126,8 @@ fn generic_tpl_pattern_match_inner( _ => { // 对于 @alias 类型, 我们能拿到的 target 实际上很有可能是实例化后的类型, 因此我们需要实例化后再进行匹配 let substitutor = TypeSubstitutor::new(); - let typ = instantiate_generic(context.db, source_generic, &substitutor); + let generic_ty = LuaType::Generic(source_generic.clone().into()); + let typ = instantiate_type_generic(context.db, &generic_ty, &substitutor); if LuaType::from(source_generic.clone()) != typ { tpl_pattern_match(context, &typ, target)?; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 77900f2d4..9bf830086 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -23,7 +23,10 @@ use crate::{ }, }; -use super::type_substitutor::TypeSubstitutor; +use super::{ + instantiate_type::TplCandidateSource::{ConstPreserving, Plain}, + type_substitutor::{TplBinding, TypeSubstitutor}, +}; use std::collections::HashMap; type TplPatternMatchResult = Result<(), InferFailReason>; @@ -159,16 +162,12 @@ pub fn tpl_pattern_match( match pattern { LuaType::TplRef(tpl) => { if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target.clone(), true); + context.insert_type(tpl.get_tpl_id(), target.clone(), Plain); } } LuaType::ConstTplRef(tpl) => { if tpl.get_tpl_id().is_func() { - context - .substitutor - .insert_type(tpl.get_tpl_id(), target, false); + context.insert_type(tpl.get_tpl_id(), target, ConstPreserving); } } LuaType::StrTplRef(str_tpl) => { @@ -176,33 +175,45 @@ pub fn tpl_pattern_match( let prefix = str_tpl.get_prefix(); let suffix = str_tpl.get_suffix(); let type_name = SmolStr::new(format!("{}{}{}", prefix, s, suffix)); - context.substitutor.insert_type( + context.insert_type( str_tpl.get_tpl_id(), get_str_tpl_infer_type(&type_name), - true, + Plain, ); } } LuaType::Array(array_type) => { - array_tpl_pattern_match(context, array_type.get_base(), &target)?; + context.with_inference_top_level(false, |context| { + array_tpl_pattern_match(context, array_type.get_base(), &target) + })?; } LuaType::TableGeneric(table_generic_params) => { - table_generic_tpl_pattern_match(context, table_generic_params, &target)?; + context.with_inference_top_level(false, |context| { + table_generic_tpl_pattern_match(context, table_generic_params, &target) + })?; } LuaType::Generic(generic) => { - generic_tpl_pattern_match(context, generic, &target)?; + context.with_inference_top_level(false, |context| { + generic_tpl_pattern_match(context, generic, &target) + })?; } LuaType::Union(union) => { union_tpl_pattern_match(context, union, &target)?; } LuaType::DocFunction(doc_func) => { - func_tpl_pattern_match(context, doc_func, &target)?; + context.with_inference_top_level(false, |context| { + func_tpl_pattern_match(context, doc_func, &target) + })?; } LuaType::Tuple(tuple) => { - tuple_tpl_pattern_match(context, tuple, &target)?; + context.with_inference_top_level(false, |context| { + tuple_tpl_pattern_match(context, tuple, &target) + })?; } LuaType::Object(obj) => { - object_tpl_pattern_match(context, obj, &target)?; + context.with_inference_top_level(false, |context| { + object_tpl_pattern_match(context, obj, &target) + })?; } _ => {} } @@ -210,16 +221,6 @@ pub fn tpl_pattern_match( Ok(()) } -pub fn constant_decay(typ: LuaType) -> LuaType { - match &typ { - LuaType::FloatConst(_) => LuaType::Number, - LuaType::DocIntegerConst(_) | LuaType::IntegerConst(_) => LuaType::Integer, - LuaType::DocStringConst(_) | LuaType::StringConst(_) => LuaType::String, - LuaType::DocBooleanConst(_) | LuaType::BooleanConst(_) => LuaType::Boolean, - _ => typ, - } -} - fn object_tpl_pattern_match( context: &mut TplContext, origin_obj: &LuaObjectType, @@ -646,7 +647,7 @@ fn param_type_list_pattern_match_type_list( if i >= targets.len() { if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } break; } @@ -655,12 +656,12 @@ fn param_type_list_pattern_match_type_list( let tpl_id = generic_tpl.get_tpl_id(); if let Some(inferred_type_value) = context.substitutor.get(tpl_id) { match inferred_type_value { - SubstitutorValue::Type(_) => { + SubstitutorValue::Type { .. } => { continue; } - SubstitutorValue::MultiTypes(types) => { - if types.len() > 1 { - target_offset += types.len() - 1; + SubstitutorValue::MultiTypes { values, .. } => { + if values.len() > 1 { + target_offset += values.len() - 1; } continue; } @@ -717,9 +718,7 @@ pub(crate) fn return_type_pattern_match_target_type( VariadicType::Base(source_base) => { if let LuaType::TplRef(type_ref) = source_base { let tpl_id = type_ref.get_tpl_id(); - context - .substitutor - .insert_type(tpl_id, target_base.clone(), true); + context.insert_type(tpl_id, target_base.clone(), Plain); } } VariadicType::Multi(source_multi) => { @@ -730,22 +729,14 @@ pub(crate) fn return_type_pattern_match_target_type( && let LuaType::TplRef(type_ref) = base { let tpl_id = type_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); + context.insert_type(tpl_id, target_base.clone(), Plain); } break; } LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type( - tpl_id, - target_base.clone(), - true, - ); + context.insert_type(tpl_id, target_base.clone(), Plain); } _ => {} } @@ -784,12 +775,14 @@ fn func_varargs_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - substitutor.insert_params( + substitutor.bind( tpl_id, - target_rest_params - .iter() - .map(|(n, t)| (n.clone(), t.clone())) - .collect(), + TplBinding::VariadicParams( + target_rest_params + .iter() + .map(|(n, t)| (n.clone(), t.clone())) + .collect(), + ), ); } } @@ -810,7 +803,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -820,42 +813,28 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.insert_type(tpl_id, LuaType::Nil, Plain); } 1 => { - context.substitutor.insert_type( - tpl_id, - types[0].clone(), - true, - ); + context.insert_type(tpl_id, types[0].clone(), Plain); } _ => { - context.substitutor.insert_multi_types( - tpl_id, - types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); + context.insert_multi_types(tpl_id, types.to_vec(), Plain); } }, VariadicType::Base(base) => { - context.substitutor.insert_multi_base(tpl_id, base.clone()); + context + .substitutor + .bind(tpl_id, TplBinding::VariadicBase(base.clone())); } }, arg => { - context.substitutor.insert_type(tpl_id, arg.clone(), true); + context.insert_type(tpl_id, arg.clone(), Plain); } } } _ => { - context.substitutor.insert_multi_types( - tpl_id, - target_rest_types - .iter() - .map(|t| constant_decay(t.clone())) - .collect(), - ); + context.insert_multi_types(tpl_id, target_rest_types.to_vec(), Plain); } } } @@ -863,19 +842,17 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, false); + context.insert_type(tpl_id, LuaType::Nil, ConstPreserving); } 1 => { - context.substitutor.insert_type( - tpl_id, - target_rest_types[0].clone(), - false, - ); + context.insert_type(tpl_id, target_rest_types[0].clone(), ConstPreserving); } _ => { - context - .substitutor - .insert_multi_types(tpl_id, target_rest_types.to_vec()); + context.insert_multi_types( + tpl_id, + target_rest_types.to_vec(), + ConstPreserving, + ); } } } @@ -895,7 +872,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.get(i) { Some(t) => { - context.substitutor.insert_type(tpl_id, t.clone(), true); + context.insert_type(tpl_id, t.clone(), Plain); } None => { break; @@ -946,9 +923,10 @@ fn tuple_tpl_pattern_match( VariadicType::Base(base) => { if let LuaType::TplRef(tpl_ref) = base { let tpl_id = tpl_ref.get_tpl_id(); - context - .substitutor - .insert_multi_base(tpl_id, target_array_base.get_base().clone()); + context.substitutor.bind( + tpl_id, + TplBinding::VariadicBase(target_array_base.get_base().clone()), + ); } } VariadicType::Multi(_) => {} diff --git a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs index b045bda1d..6dbf090da 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/type_substitutor.rs @@ -1,7 +1,11 @@ use hashbrown::{HashMap, HashSet}; -use super::tpl_pattern::constant_decay; -use crate::{DbIndex, GenericTplId, LuaType, LuaTypeDeclId}; +use super::instantiate_type::{TplCandidateSource, finalize_inferred_tpl_candidate}; +use crate::{DbIndex, GenericTpl, GenericTplId, LuaType, LuaTypeDeclId}; +use std::sync::Arc; + +const MAX_INSTANTIATION_DEPTH: usize = 128; +const MAX_ALIAS_STACK: usize = 32; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) enum UninferredTplPolicy { @@ -11,11 +15,35 @@ pub(super) enum UninferredTplPolicy { PreserveTplRef, } +pub(in crate::semantic::generic) enum TplBinding { + FinalizedType(LuaType), + InferredType { + ty: LuaType, + source: TplCandidateSource, + top_level: bool, + }, + ReplaceConstType(LuaType), + ConditionalInferType(LuaType), + VariadicParams(Vec<(String, Option)>), + InferredMultiTypes { + types: Vec, + source: TplCandidateSource, + top_level: bool, + }, + VariadicBase(LuaType), +} + #[derive(Debug)] pub struct GenericInstantiateContext<'a> { pub db: &'a DbIndex, pub substitutor: &'a TypeSubstitutor, + alias_stack: Arc<[LuaTypeDeclId]>, +} + +#[derive(Debug, Clone, Copy)] +pub(super) struct GenericInstantiateFrame { policy: UninferredTplPolicy, + depth: usize, } impl<'a> GenericInstantiateContext<'a> { @@ -23,32 +51,68 @@ impl<'a> GenericInstantiateContext<'a> { Self { db, substitutor, - policy: UninferredTplPolicy::Fallback, + alias_stack: Arc::from([]), } } - pub(super) fn with_policy(&self, policy: UninferredTplPolicy) -> GenericInstantiateContext<'a> { - GenericInstantiateContext { - db: self.db, - substitutor: self.substitutor, - policy, + pub(super) fn root_frame(&self) -> GenericInstantiateFrame { + GenericInstantiateFrame { + policy: UninferredTplPolicy::Fallback, + depth: 0, } } - pub fn with_substitutor<'b>( + pub(super) fn with_substitutor<'b>( &'b self, substitutor: &'b TypeSubstitutor, ) -> GenericInstantiateContext<'b> { GenericInstantiateContext { db: self.db, substitutor, - policy: self.policy, + alias_stack: self.alias_stack.clone(), + } + } + + pub(super) fn enter_alias( + &self, + alias_type_id: &LuaTypeDeclId, + ) -> Option> { + if self.alias_stack.len() >= MAX_ALIAS_STACK + || self.alias_stack.iter().any(|id| id == alias_type_id) + { + return None; } + + let mut alias_stack = Vec::with_capacity(self.alias_stack.len() + 1); + alias_stack.extend(self.alias_stack.iter().cloned()); + alias_stack.push(alias_type_id.clone()); + Some(GenericInstantiateContext { + db: self.db, + substitutor: self.substitutor, + alias_stack: Arc::from(alias_stack), + }) + } +} + +impl GenericInstantiateFrame { + pub(super) fn with_policy(self, policy: UninferredTplPolicy) -> Self { + Self { policy, ..self } } pub fn should_preserve_tpl_ref(&self) -> bool { self.policy == UninferredTplPolicy::PreserveTplRef } + + pub(super) fn enter(self) -> Option { + if self.depth >= MAX_INSTANTIATION_DEPTH { + return None; + } + + Some(Self { + depth: self.depth + 1, + ..self + }) + } } #[derive(Debug, Clone)] @@ -78,7 +142,11 @@ impl TypeSubstitutor { for (i, ty) in type_array.into_iter().enumerate() { tpl_replace_map.insert( GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, ); } Self { @@ -88,14 +156,29 @@ impl TypeSubstitutor { } } - pub fn from_alias(type_array: Vec, alias_type_id: LuaTypeDeclId) -> Self { + pub fn from_alias( + db: &DbIndex, + type_array: Vec, + alias_type_id: LuaTypeDeclId, + ) -> Self { + let params = db.get_type_index().get_generic_params(&alias_type_id); + let mut tpl_replace_map = HashMap::new(); for (i, ty) in type_array.into_iter().enumerate() { + let tpl_id = params + .and_then(|params| params.get(i)) + .and_then(|param| param.tpl_id) + .unwrap_or(GenericTplId::Type(i as u32)); tpl_replace_map.insert( - GenericTplId::Type(i as u32), - SubstitutorValue::Type(SubstitutorTypeValue::new(ty, true)), + tpl_id, + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, ); } + Self { tpl_replace_map, alias_type_id: Some(alias_type_id), @@ -103,7 +186,7 @@ impl TypeSubstitutor { } } - pub fn add_need_infer_tpls(&mut self, tpl_ids: HashSet) { + pub fn prepare_inference_slots(&mut self, tpl_ids: HashSet) { for tpl_id in tpl_ids { // conditional infer id 只属于条件类型内部匹配, 不参与普通调用/类型泛型推导. if tpl_id.is_conditional_infer() { @@ -116,62 +199,104 @@ impl TypeSubstitutor { } } - pub fn is_infer_all_tpl(&self) -> bool { - for value in self.tpl_replace_map.values() { - if let SubstitutorValue::None = value { - return false; - } - } - true - } - - pub fn insert_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType, decay: bool) { - // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. - if tpl_id.is_conditional_infer() { - return; - } - - self.insert_type_value(tpl_id, SubstitutorTypeValue::new(replace_type, decay)); - } - - pub(super) fn replace_type( - &mut self, - tpl_id: GenericTplId, - replace_type: LuaType, - decay: bool, - ) { - if tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, decay)), - ); + pub fn has_unresolved_inference_slots(&self) -> bool { + self.tpl_replace_map + .values() + .any(|value| matches!(value, SubstitutorValue::None)) } - pub fn insert_conditional_infer_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { - // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. - if !tpl_id.is_conditional_infer() { - return; - } - - self.tpl_replace_map.insert( - tpl_id, - SubstitutorValue::Type(SubstitutorTypeValue::new(replace_type, false)), - ); + pub fn bind_type(&mut self, tpl_id: GenericTplId, replace_type: LuaType) { + self.bind(tpl_id, TplBinding::FinalizedType(replace_type)); } - fn insert_type_value(&mut self, tpl_id: GenericTplId, value: SubstitutorTypeValue) { - if !self.can_insert_type(tpl_id) { - return; + pub(in crate::semantic::generic) fn bind(&mut self, tpl_id: GenericTplId, binding: TplBinding) { + match binding { + TplBinding::ConditionalInferType(replace_type) => { + // 只有 conditional true 分支提交 infer 结果时允许写入 scoped conditional infer id. + if !tpl_id.is_conditional_infer() { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, false), + source: TplCandidateSource::ConstPreserving, + top_level: true, + }, + ); + } + TplBinding::ReplaceConstType(replace_type) => { + if tpl_id.is_conditional_infer() { + return; + } + + self.tpl_replace_map.insert( + tpl_id, + SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, false), + source: TplCandidateSource::ConstPreserving, + top_level: true, + }, + ); + } + binding => { + // 普通替换入口不能写入 conditional infer, 避免条件类型局部绑定泄露到外层. + if tpl_id.is_conditional_infer() || !self.can_bind(tpl_id) { + return; + } + + let value = match binding { + TplBinding::FinalizedType(replace_type) => SubstitutorValue::Type { + value: SubstitutorTypeValue::new(replace_type, true), + source: TplCandidateSource::Finalized, + top_level: true, + }, + TplBinding::InferredType { + ty, + source, + top_level, + } => SubstitutorValue::Type { + value: SubstitutorTypeValue::new(ty, false), + source, + top_level, + }, + TplBinding::VariadicParams(params) => { + let params = params + .into_iter() + .map(|(name, ty)| (name, ty.map(into_ref_type))) + .collect(); + SubstitutorValue::Params(params) + } + TplBinding::InferredMultiTypes { + types, + source, + top_level, + } => SubstitutorValue::MultiTypes { + values: types + .into_iter() + .map(|ty| { + SubstitutorTypeValue::new( + ty, + source == TplCandidateSource::Finalized, + ) + }) + .collect(), + source, + top_level, + }, + TplBinding::VariadicBase(type_base) => SubstitutorValue::MultiBase(type_base), + TplBinding::ReplaceConstType(_) | TplBinding::ConditionalInferType(_) => { + unreachable!("handled before regular binding") + } + }; + + self.tpl_replace_map.insert(tpl_id, value); + } } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Type(value)); } - fn can_insert_type(&self, tpl_id: GenericTplId) -> bool { + fn can_bind(&self, tpl_id: GenericTplId) -> bool { if let Some(value) = self.tpl_replace_map.get(&tpl_id) { return value.is_none(); } @@ -179,61 +304,90 @@ impl TypeSubstitutor { true } - pub fn insert_params(&mut self, tpl_id: GenericTplId, params: Vec<(String, Option)>) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - let params = params - .into_iter() - .map(|(name, ty)| (name, ty.map(into_ref_type))) - .collect(); - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::Params(params)); - } - - pub fn insert_multi_types(&mut self, tpl_id: GenericTplId, types: Vec) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiTypes(types)); - } - - pub fn insert_multi_base(&mut self, tpl_id: GenericTplId, type_base: LuaType) { - if tpl_id.is_conditional_infer() { - return; - } - - if !self.can_insert_type(tpl_id) { - return; - } - - self.tpl_replace_map - .insert(tpl_id, SubstitutorValue::MultiBase(type_base)); - } - - pub fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { + pub(super) fn get(&self, tpl_id: GenericTplId) -> Option<&SubstitutorValue> { self.tpl_replace_map.get(&tpl_id) } pub fn get_raw_type(&self, tpl_id: GenericTplId) -> Option<&LuaType> { match self.tpl_replace_map.get(&tpl_id) { - Some(SubstitutorValue::Type(ty)) => Some(ty.raw()), + Some(SubstitutorValue::Type { value, .. }) => Some(value.raw()), _ => None, } } + pub(super) fn finalize_inferred_types<'a>( + &mut self, + db: &DbIndex, + generic_tpls: impl IntoIterator>, + return_type: &LuaType, + ) { + for tpl in generic_tpls { + let tpl_id = tpl.get_tpl_id(); + let return_top_level = is_tpl_at_top_level(db, return_type, tpl_id); + let Some(value) = self.tpl_replace_map.get(&tpl_id) else { + continue; + }; + + let finalized_value = match value { + SubstitutorValue::Type { + value, + source, + top_level, + } => { + if value.is_finalized() { + None + } else { + Some(SubstitutorValue::Type { + value: value.finalized( + db, + tpl.as_ref(), + *source, + *top_level, + return_top_level, + self, + ), + source: TplCandidateSource::Finalized, + top_level: true, + }) + } + } + SubstitutorValue::MultiTypes { + values, + source, + top_level, + } => { + if *source == TplCandidateSource::Finalized { + None + } else { + let values = values + .iter() + .map(|value| { + value.finalized( + db, + tpl.as_ref(), + *source, + *top_level, + return_top_level, + self, + ) + }) + .collect(); + Some(SubstitutorValue::MultiTypes { + values, + source: TplCandidateSource::Finalized, + top_level: true, + }) + } + } + _ => None, + }; + + if let Some(finalized_value) = finalized_value { + self.tpl_replace_map.insert(tpl_id, finalized_value); + } + } + } + pub fn check_recursion(&self, type_id: &LuaTypeDeclId) -> bool { if let Some(alias_type_id) = &self.alias_type_id && alias_type_id == type_id @@ -256,49 +410,67 @@ impl TypeSubstitutor { #[derive(Debug, Clone)] pub struct SubstitutorTypeValue { raw: LuaType, - decayed: DecayedType, -} - -#[derive(Debug, Clone)] -enum DecayedType { - Same, - Cached(LuaType), + finalized: Option, } impl SubstitutorTypeValue { - pub fn new(raw: LuaType, decay: bool) -> Self { + fn new(raw: LuaType, already_finalized: bool) -> Self { let raw = into_ref_type(raw); - let decayed = if decay { - let decayed = into_ref_type(constant_decay(raw.clone())); - if decayed == raw { - DecayedType::Same - } else { - DecayedType::Cached(decayed) - } - } else { - DecayedType::Same - }; - Self { raw, decayed } + let finalized = already_finalized.then(|| raw.clone()); + Self { raw, finalized } } pub fn raw(&self) -> &LuaType { &self.raw } - pub fn default(&self) -> &LuaType { - match &self.decayed { - DecayedType::Same => &self.raw, - DecayedType::Cached(decayed) => decayed, + pub(super) fn resolved(&self) -> &LuaType { + self.finalized.as_ref().unwrap_or(&self.raw) + } + + fn is_finalized(&self) -> bool { + self.finalized.is_some() + } + + fn finalized( + &self, + db: &DbIndex, + tpl: &GenericTpl, + source: TplCandidateSource, + top_level: bool, + return_top_level: bool, + substitutor: &TypeSubstitutor, + ) -> Self { + let finalized = finalize_inferred_tpl_candidate( + db, + tpl, + &self.raw, + source, + top_level, + return_top_level, + substitutor, + ); + Self { + raw: self.raw.clone(), + finalized: Some(finalized), } } } #[derive(Debug, Clone)] -pub enum SubstitutorValue { +pub(super) enum SubstitutorValue { None, - Type(SubstitutorTypeValue), + Type { + value: SubstitutorTypeValue, + source: TplCandidateSource, + top_level: bool, + }, Params(Vec<(String, Option)>), - MultiTypes(Vec), + MultiTypes { + values: Vec, + source: TplCandidateSource, + top_level: bool, + }, MultiBase(LuaType), } @@ -308,6 +480,79 @@ impl SubstitutorValue { } } +fn is_tpl_at_top_level(db: &DbIndex, ty: &LuaType, tpl_id: GenericTplId) -> bool { + is_tpl_at_top_level_with_guard(db, ty, tpl_id, &mut HashSet::new()) +} + +fn is_tpl_at_top_level_with_guard( + db: &DbIndex, + ty: &LuaType, + tpl_id: GenericTplId, + visited_aliases: &mut HashSet, +) -> bool { + match ty { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) => tpl.get_tpl_id() == tpl_id, + LuaType::Union(union) => union.into_vec().iter().any(|member| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::MultiLineUnion(multi) => multi.get_unions().iter().any(|(member, _)| { + let mut branch_aliases = visited_aliases.clone(); + is_tpl_at_top_level_with_guard(db, member, tpl_id, &mut branch_aliases) + }), + LuaType::Generic(generic) => { + let type_decl_id = generic.get_base_type_id_ref(); + let Some(alias_param) = + get_transparent_alias_param_index(db, type_decl_id, visited_aliases) + else { + return false; + }; + + generic.get_params().get(alias_param).is_some_and(|param| { + is_tpl_at_top_level_with_guard(db, param, tpl_id, visited_aliases) + }) + } + _ => false, + } +} + +fn get_transparent_alias_param_index( + db: &DbIndex, + type_decl_id: &LuaTypeDeclId, + visited_aliases: &mut HashSet, +) -> Option { + if !visited_aliases.insert(type_decl_id.clone()) { + return None; + } + + let type_decl = db.get_type_index().get_type_decl(type_decl_id)?; + if !type_decl.is_alias() { + return None; + }; + let origin = type_decl.get_alias_ref()?; + + match origin { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + LuaType::Generic(generic) => { + get_transparent_alias_param_index(db, generic.get_base_type_id_ref(), visited_aliases) + .and_then(|alias_param| generic.get_params().get(alias_param)) + .and_then(|param| match param { + LuaType::TplRef(tpl) | LuaType::ConstTplRef(tpl) + if matches!(tpl.get_tpl_id(), GenericTplId::Type(_)) => + { + Some(tpl.get_tpl_id().get_idx()) + } + _ => None, + }) + } + _ => None, + } +} + fn into_ref_type(ty: LuaType) -> LuaType { match ty { LuaType::Def(type_decl_id) => LuaType::Ref(type_decl_id), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..a3cb628af 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -8,7 +8,7 @@ use super::{ super::{InferGuard, LuaInferCache, instantiate_type_generic, resolve_signature}, InferFailReason, InferResult, }; -use crate::semantic::overload_resolve::callable_accepts_args; +use crate::semantic::overload_resolve::{callable_accepts_args, collect_callable_overload_groups}; use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, @@ -17,14 +17,11 @@ use crate::{ use crate::{ InferGuardRef, semantic::{ - generic::{ - TypeSubstitutor, collect_callable_overload_groups, get_tpl_ref_extend_type, - instantiate_doc_function, - }, + generic::{TypeSubstitutor, get_tpl_ref_extend_type}, infer::narrow::get_type_at_call_expr_inline_cast, }, }; -use crate::{build_self_type, infer_self_type, instantiate_func_generic, semantic::infer_expr}; +use crate::{build_self_type, infer_call_func_generic, infer_self_type, semantic::infer_expr}; use infer_require::infer_require_call; use infer_setmetatable::infer_setmetatable_call; @@ -134,14 +131,13 @@ pub fn infer_call_expr_func( }; let result = if let Ok(func_ty) = result { - let func_ty = match func_ty.get_ret() { - LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { - Ok(func_ty) => Arc::new(func_ty), - Err(_) => func_ty, - } + let func_ty = if func_ty.get_ret().contain_tpl() || func_ty.get_ret().is_call() { + match infer_call_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { + Ok(func_ty) => Arc::new(func_ty), + Err(_) => func_ty, } - _ => func_ty, + } else { + func_ty }; let func_ret = func_ty.get_ret(); @@ -223,7 +219,7 @@ fn infer_doc_function( call_expr: LuaCallExpr, ) -> InferCallFuncResult { if func.contain_tpl() { - let result = instantiate_func_generic(db, cache, func, call_expr)?; + let result = infer_call_func_generic(db, cache, func, call_expr)?; return Ok(Arc::new(result)); } @@ -271,9 +267,10 @@ fn filter_callable_overloads_by_call_args( let has_tpls = !callable_tpls.is_empty(); let mut substitutor = TypeSubstitutor::new(); - substitutor.add_need_infer_tpls(callable_tpls); + substitutor.prepare_inference_slots(callable_tpls); let match_func = if has_tpls { - match instantiate_doc_function(db, func, &substitutor) { + let func_ty = LuaType::DocFunction(func.clone()); + match instantiate_type_generic(db, &func_ty, &substitutor) { LuaType::DocFunction(doc_func) => doc_func, _ => func.clone(), } @@ -362,13 +359,15 @@ fn infer_type_doc_function( }; if has_generic_tpl { - let result = instantiate_func_generic(db, cache, &f, call_expr.clone())?; + let result = infer_call_func_generic(db, cache, &f, call_expr.clone())?; overloads.push(Arc::new(result)); } else if f.contain_self() { let mut substitutor = TypeSubstitutor::new(); let self_type = build_self_type(db, call_expr_type); substitutor.add_self_type(self_type); - if let LuaType::DocFunction(f) = instantiate_doc_function(db, &f, &substitutor) + let func_ty = LuaType::DocFunction(f.clone()); + if let LuaType::DocFunction(f) = + instantiate_type_generic(db, &func_ty, &substitutor) { overloads.push(f); } @@ -903,6 +902,30 @@ mod tests { assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + #[test] + fn test_top_level_generic_literal_keeps_function_param_and_return_consistent() { + let mut ws = VirtualWorkspace::new(); + let file_id = ws.def( + r#" + ---@generic T + ---@param value T + ---@return T + local function id(value) end + + id("hello") + "#, + ); + let call_expr = ws.get_node::(file_id); + let semantic_model = ws.analysis.compilation.get_semantic_model(file_id).unwrap(); + let func = semantic_model + .infer_call_expr_func(call_expr, None) + .unwrap(); + + let param_ty = func.get_params()[0].1.clone().unwrap(); + assert_eq!(ws.humanize_type(param_ty), "\"hello\""); + assert_eq!(ws.humanize_type(func.get_ret().clone()), "\"hello\""); + } + #[test] fn test_union_call_ignores_unresolved_alias_member() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs index 5a8fac6a1..915401539 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/mod.rs @@ -701,14 +701,22 @@ fn infer_generic_member( ) -> InferResult { let base_type = generic_type.get_base_type(); - let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); - if let LuaType::Ref(base_type_decl_id) = &base_type { let type_index = db.get_type_index(); - if let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) - && type_decl.is_alias() - && let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) + let Some(type_decl) = type_index.get_type_decl(base_type_decl_id) else { + return Err(InferFailReason::None); + }; + let generic_params = generic_type.get_params(); + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), base_type_decl_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; + + if type_decl.is_alias() + && let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) { return infer_member_by_lookup(db, cache, &origin_type, lookup, &infer_guard.fork()); } @@ -724,11 +732,12 @@ fn infer_generic_member( if let Some(result) = result { return Ok(result); } - } - let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + let member_type = infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard)?; + return Ok(instantiate_type_generic(db, &member_type, &substitutor)); + } - Ok(instantiate_type_generic(db, &member_type, &substitutor)) + infer_member_by_lookup(db, cache, &base_type, lookup, infer_guard) } fn infer_instance_member( @@ -995,17 +1004,24 @@ fn infer_member_by_index_generic( return Err(InferFailReason::None); }; let generic_params = generic.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_index = db.get_type_index(); let type_decl = type_index .get_type_decl(&type_decl_id) .ok_or(InferFailReason::None)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), type_decl_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; if type_decl.is_alias() { - if let Some(origin_type) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin_type) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + { return infer_member_by_operator_key_type( db, cache, - &instantiate_type_generic(db, &origin_type, &substitutor), + &origin_type, key_type, &infer_guard.fork(), ); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..7b9d10f3a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -16,7 +16,7 @@ use crate::{ var_ref_id::get_var_expr_var_ref_id, }, }, - semantic::instantiate_func_generic, + semantic::infer_call_func_generic, }; pub fn get_type_at_call_expr( @@ -226,7 +226,7 @@ fn get_type_guard_call_info( let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { let Ok(inst_func) = cache.with_no_flow(|cache| { - instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) + infer_call_func_generic(db, cache, func_type.as_ref(), call_expr) }) else { return Ok(None); }; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 78c464852..61728c1e1 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,7 +7,7 @@ use crate::{ LuaSignature, LuaType, TypeOps, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, - instantiate_func_generic, + infer_call_func_generic, }, }; @@ -579,7 +579,7 @@ fn instantiate_return_rows( return_type.clone(), ); match cache - .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) + .with_no_flow(|cache| infer_call_func_generic(db, cache, &func, call_expr.clone())) { Ok(instantiated) => instantiated.get_ret().clone(), Err(_) => return_type, diff --git a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs index 606ba27b5..58284dc3d 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/find_members.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/find_members.rs @@ -495,10 +495,17 @@ fn find_generic_members( .iter() .map(|param| ctx.instantiate_type(db, param)) .collect(); - let substitutor = TypeSubstitutor::from_type_array(instantiated_params); let type_decl = db.get_type_index().get_type_decl(&base_ref_id)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, instantiated_params, base_ref_id.clone()) + } else { + TypeSubstitutor::from_type_array(instantiated_params) + }; let ctx_with_substitutor = ctx.with_substitutor(substitutor.clone()); - if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { + if let Some(origin) = type_decl + .get_alias_ref() + .map(|origin| instantiate_type_generic(db, origin, &substitutor)) + { return find_members_guard(db, &origin, &ctx_with_substitutor, filter); } diff --git a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs index cadd3988e..7edaffb2a 100644 --- a/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs +++ b/crates/emmylua_code_analysis/src/semantic/member/infer_raw_member.rs @@ -214,14 +214,18 @@ fn infer_generic_raw_member_type( ) -> RawGetMemberTypeResult { let base_ref_id = generic_type.get_base_type_id_ref(); let generic_params = generic_type.get_params(); - let substitutor = TypeSubstitutor::from_type_array(generic_params.clone()); let type_decl = db .get_type_index() .get_type_decl(&base_ref_id) .ok_or(InferFailReason::None)?; + let substitutor = if type_decl.is_alias() { + TypeSubstitutor::from_alias(db, generic_params.clone(), base_ref_id.clone()) + } else { + TypeSubstitutor::from_type_array(generic_params.clone()) + }; if let Some(origin) = type_decl.get_alias_origin(db, Some(&substitutor)) { - return infer_raw_member_type(db, &origin, member_key); + return infer_raw_member_type_guard(db, &origin, member_key, infer_guard); } let base_ref_type = LuaType::Ref(base_ref_id.clone()); diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs new file mode 100644 index 000000000..fd5f568af --- /dev/null +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/collect_callable_overloads.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; + +use hashbrown::HashSet; + +use crate::db_index::{DbIndex, LuaFunctionType, LuaType, LuaTypeDeclId}; + +use super::super::{generic::TypeSubstitutor, infer::InferFailReason}; + +pub(crate) fn collect_callable_overload_groups( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, +) -> Result<(), InferFailReason> { + let mut visiting_aliases = HashSet::new(); + collect_callable_overload_groups_inner(db, callable_type, groups, &mut visiting_aliases) +} + +fn collect_callable_overload_groups_inner( + db: &DbIndex, + callable_type: &LuaType, + groups: &mut Vec>>, + visiting_aliases: &mut HashSet, +) -> Result<(), InferFailReason> { + match callable_type { + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + let Some(type_decl) = db.get_type_index().get_type_decl(type_id) else { + return Ok(()); + }; + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + + let result = if let Some(origin_type) = type_decl.get_alias_origin(db, None) { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(type_id); + result?; + } + LuaType::Generic(generic) => { + let type_id = generic.get_base_type_id(); + if !visiting_aliases.insert(type_id.clone()) { + return Ok(()); + } + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + let Some(type_decl) = db.get_type_index().get_type_decl(&type_id) else { + visiting_aliases.remove(&type_id); + return Ok(()); + }; + + let result = if let Some(origin_type) = + type_decl.get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overload_groups_inner(db, &origin_type, groups, visiting_aliases) + } else { + Ok(()) + }; + visiting_aliases.remove(&type_id); + result?; + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overload_groups_inner(db, &member, groups, visiting_aliases)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overload_groups_inner(db, member, groups, visiting_aliases)?; + } + } + LuaType::DocFunction(doc_func) => groups.push(vec![doc_func.clone()]), + LuaType::Signature(sig_id) => { + let Some(signature) = db.get_signature_index().get(sig_id) else { + return Ok(()); + }; + let mut overloads = signature.overloads.to_vec(); + overloads.push(signature.to_doc_func_type()); + groups.push(overloads); + } + _ => {} + } + + Ok(()) +} diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index a6447a91c..203d97870 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -1,3 +1,4 @@ +mod collect_callable_overloads; mod resolve_signature_by_args; use std::sync::Arc; @@ -8,10 +9,11 @@ use crate::db_index::{DbIndex, LuaFunctionType, LuaType}; use super::{ LuaInferCache, - generic::instantiate_func_generic, + generic::infer_call_func_generic, infer::{InferCallFuncResult, InferFailReason, infer_expr_list_types, try_infer_expr_no_flow}, }; +pub(crate) use collect_callable_overloads::collect_callable_overload_groups; pub(crate) use resolve_signature_by_args::{callable_accepts_args, resolve_signature_by_args}; pub fn resolve_signature( @@ -78,7 +80,7 @@ fn resolve_signature_by_generic( ) -> InferCallFuncResult { let mut instantiate_funcs = Vec::new(); for func in overloads { - let instantiate_func = instantiate_func_generic(db, cache, &func, call_expr.clone())?; + let instantiate_func = infer_call_func_generic(db, cache, &func, call_expr.clone())?; instantiate_funcs.push(Arc::new(instantiate_func)); } resolve_signature_by_args( diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs index e9827448b..763258b1f 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/complex_type/mod.rs @@ -37,8 +37,11 @@ pub fn check_complex_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { return check_general_type_compact( context, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs index afee3eddc..e7af2fa1c 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/func_type.rs @@ -23,8 +23,11 @@ pub fn check_doc_func_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { return check_general_type_compact( context, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs index 0929c7ed5..85bec1bb7 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/generic_type.rs @@ -24,9 +24,13 @@ pub fn check_generic_type_compact( .get_type_decl(&source_generic.get_base_type_id()) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(source_generic.get_params().clone(), base_id.clone()); - if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { + let substitutor = TypeSubstitutor::from_alias( + context.db, + source_generic.get_params().clone(), + base_id.clone(), + ); + if let Some(alias_ref) = decl.get_alias_ref() { + let alias_origin = instantiate_type_generic(context.db, alias_ref, &substitutor); return check_general_type_compact( context, &alias_origin, diff --git a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs index 6f994cb0c..70865eec2 100644 --- a/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/type_check/simple_type.rs @@ -310,8 +310,11 @@ pub fn check_simple_type_compact( if let Some(decl) = context.db.get_type_index().get_type_decl(&base_id) && decl.is_alias() { - let substitutor = - TypeSubstitutor::from_alias(generic.get_params().clone(), base_id.clone()); + let substitutor = TypeSubstitutor::from_alias( + context.db, + generic.get_params().clone(), + base_id.clone(), + ); if let Some(alias_origin) = decl.get_alias_origin(context.db, Some(&substitutor)) { diff --git a/crates/emmylua_ls/src/handlers/definition/goto_function.rs b/crates/emmylua_ls/src/handlers/definition/goto_function.rs index f60d9b395..971c74aae 100644 --- a/crates/emmylua_ls/src/handlers/definition/goto_function.rs +++ b/crates/emmylua_ls/src/handlers/definition/goto_function.rs @@ -1,6 +1,6 @@ use emmylua_code_analysis::{ LuaCompilation, LuaDeclId, LuaFunctionType, LuaSemanticDeclId, LuaSignature, LuaSignatureId, - LuaType, SemanticDeclLevel, SemanticModel, instantiate_func_generic, + LuaType, SemanticDeclLevel, SemanticModel, infer_call_func_generic, }; use emmylua_parser::{ LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, LuaSyntaxToken, LuaTokenKind, @@ -291,7 +291,7 @@ pub fn compare_function_types( call_expr: &LuaCallExpr, ) -> Option { if func.contain_tpl() { - let instantiated_func = instantiate_func_generic( + let instantiated_func = infer_call_func_generic( semantic_model.get_db(), &mut semantic_model.get_cache().borrow_mut(), func, diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..c8f3e4818 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -3,8 +3,8 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, - instantiate_func_generic, try_extract_signature_id_from_field, + TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, infer_call_func_generic, + instantiate_type_generic, try_extract_signature_id_from_field, }; use crate::handlers::hover::{ @@ -104,7 +104,7 @@ fn build_function_call_hover( signature.get_type_params(), signature.get_return_type(), ); - let instantiated_signature = instantiate_func_generic( + let instantiated_signature = infer_call_func_generic( db, &mut builder.semantic_model.get_cache().borrow_mut(), &base_function, @@ -486,7 +486,7 @@ fn instantiate_call_return_overloads( row_return_type, ); let instantiated_row = - instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) + infer_call_func_generic(db, &mut cache, &row_function, call_expr.clone()) .ok() .map(|func| match func.get_ret() { LuaType::Variadic(variadic) => match variadic.as_ref() { @@ -702,8 +702,8 @@ fn hover_instantiate_function_type( return None; } match typ { - LuaType::DocFunction(f) => { - if let LuaType::DocFunction(f) = instantiate_doc_function(db, f, substitutor) { + LuaType::DocFunction(_) => { + if let LuaType::DocFunction(f) = instantiate_type_generic(db, typ, substitutor) { Some(f) } else { None diff --git a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs index d4cd08773..61b81de1f 100644 --- a/crates/emmylua_ls/src/handlers/hover/hover_builder.rs +++ b/crates/emmylua_ls/src/handlers/hover/hover_builder.rs @@ -336,7 +336,7 @@ pub fn substitutor_form_expr( let mut substitutor = TypeSubstitutor::new(); if let LuaType::Generic(generic) = prefix_type { for (i, param) in generic.get_params().iter().enumerate() { - substitutor.insert_type(GenericTplId::Type(i as u32), param.clone(), true); + substitutor.bind_type(GenericTplId::Type(i as u32), param.clone()); } return Some(substitutor); } else { diff --git a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs index 8132c2f3e..c7f6c9066 100644 --- a/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs +++ b/crates/emmylua_ls/src/handlers/test/semantic_token_test.rs @@ -28,19 +28,6 @@ mod tests { result } - fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { - let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); - for i in 0..600 { - let table_key = 3_121_212; - let field_key = 1_111_112 + i; - content.push_str(&format!( - "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", - i + 1, - )); - } - content - } - #[gtest] fn test_1() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); @@ -148,7 +135,22 @@ m.foo() Ok(()) } + #[cfg(feature = "full-test")] + fn make_issue_1028_repeated_prefix_guard_chain_content() -> String { + let mut content = String::from("V_cfad19afc42b = V_cfad19afc42b or {}\n"); + for i in 0..600 { + let table_key = 3_121_212; + let field_key = 1_111_112 + i; + content.push_str(&format!( + "if V_cfad19afc42b[{table_key}] and V_cfad19afc42b[{table_key}][{field_key}] then\n V_cfad19afc42b[{table_key}][{field_key}][\"__STR_{i}__\"] = \"__STR_{}__\"\nend\n\n", + i + 1, + )); + } + content + } + #[gtest] + #[cfg(feature = "full-test")] fn test_issue_1028_i18n_semantic_tokens_repeated_prefix_guard_chain() -> Result<()> { let mut ws = ProviderVirtualWorkspace::new(); let content = make_issue_1028_repeated_prefix_guard_chain_content();