diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs index 6425ad89e..9392ba728 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs @@ -71,12 +71,12 @@ where pub(in crate::compilation::analyzer) fn analyze_func_body_missing_return_flags_with( body: LuaBlock, infer_expr_type: &mut F, -) -> Result<(bool, bool, bool), InferFailReason> +) -> Result<(bool, bool), InferFailReason> where F: FnMut(&LuaExpr) -> Result, { let flow = analyze_block_returns(body, infer_expr_type)?; - Ok((flow.can_fall_through, flow.can_break, flow.is_infinite)) + Ok((flow.can_fall_through, flow.can_break)) } fn analyze_block_returns( diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs index a5b9bc2f1..13ff48003 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/mod.rs @@ -20,7 +20,7 @@ use unresolve::UnResolve; pub(super) fn analyze_func_body_missing_return_flags_with( body: LuaBlock, infer_expr_type: &mut F, -) -> Result<(bool, bool, bool), InferFailReason> +) -> Result<(bool, bool), InferFailReason> where F: FnMut(&LuaExpr) -> Result, { diff --git a/crates/emmylua_code_analysis/src/compilation/mod.rs b/crates/emmylua_code_analysis/src/compilation/mod.rs index 86121007a..601f12d09 100644 --- a/crates/emmylua_code_analysis/src/compilation/mod.rs +++ b/crates/emmylua_code_analysis/src/compilation/mod.rs @@ -12,7 +12,7 @@ use emmylua_parser::{LuaBlock, LuaExpr}; pub(crate) fn analyze_func_body_missing_return_flags_with( body: LuaBlock, infer_expr_type: &mut F, -) -> Result<(bool, bool, bool), InferFailReason> +) -> Result<(bool, bool), InferFailReason> where F: FnMut(&LuaExpr) -> Result, { diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 25cca95d7..728e8c141 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -201,6 +201,84 @@ mod test { assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); } + #[test] + fn test_pcall_narrows_type_guarded_callable_return_after_error_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@param a string|fun(): integer + local function foo(a) + if type(a) == "string" then + return + end + + local ok, result = pcall(a) + if not ok then + return + end + + narrowed = result + end + "#, + ); + + let narrowed = ws.expr_ty("narrowed"); + assert_eq!(ws.humanize_type(narrowed), "integer"); + } + + #[test] + fn test_pcall_narrows_type_guarded_callable_return_with_forwarded_arg() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@param a string|fun(value: integer): string + local function foo(a) + if type(a) == "string" then + return + end + + local ok, result = pcall(a, 1) + if not ok then + return + end + + narrowed = result + end + "#, + ); + + let narrowed = ws.expr_ty("narrowed"); + assert_eq!(ws.humanize_type(narrowed), "string"); + } + + #[test] + fn test_pcall_narrows_type_guarded_callable_return_with_table_arg() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@param cb string|fun(a: {}): integer + local function foo(cb) + if type(cb) == "string" then + return + end + + local ok, result = pcall(cb, {}) + if not ok then + return + end + + narrowed = result + end + "#, + ); + + let narrowed = ws.expr_ty("narrowed"); + assert_eq!(ws.humanize_type(narrowed), "integer"); + } + #[test] fn test_pcall_any_callable_splits_success_unknown_and_failure_string() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs index b9a3709b5..ce3e362e8 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs @@ -111,20 +111,19 @@ fn check_missing_return( // 检测缺少返回语句需要处理 if while if min_expected_return_count > 0 { let range = if let Some(block) = closure_expr.get_block() { - let (can_fall_through, can_break, is_infinite) = - analyze_func_body_missing_return_flags_with( - block.clone(), - &mut |expr: &LuaExpr| { - Ok(semantic_model - .infer_expr(expr.clone()) - .unwrap_or(LuaType::Unknown)) - }, - ) - .ok()?; - - // `MissingReturn` currently ignores runtime-dependent divergence if - // a later `return` is still reachable. - if !can_fall_through && !can_break && !is_infinite { + let (can_fall_through, can_break) = analyze_func_body_missing_return_flags_with( + block.clone(), + &mut |expr: &LuaExpr| { + Ok(semantic_model + .infer_expr(expr.clone()) + .unwrap_or(LuaType::Unknown)) + }, + ) + .ok()?; + + // Non-terminating paths satisfy `MissingReturn`; only paths that + // can leave the function body without returning should warn. + if !can_fall_through && !can_break { return Some(()); } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs index 49442418f..d75206147 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs @@ -153,6 +153,22 @@ mod tests { )); } + #[test] + fn test_assert_optional_return_is_not_redundant() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::RedundantReturnValue, + r#" + --- @return string + function foo() + local res --- @type string? + return assert(res) + end + "# + )); + } + #[test] fn test_not_return_anno() { let mut ws = VirtualWorkspace::new(); @@ -290,7 +306,7 @@ mod tests { "# )); - assert!(!ws.has_no_diagnostic( + assert!(ws.has_no_diagnostic( DiagnosticCode::MissingReturn, r#" local A @@ -724,6 +740,25 @@ mod tests { ); } + #[test] + fn test_missing_return_accepts_non_terminating_truthy_while() { + assert_missing_return_ok( + r#" + --- @param ready boolean + --- @return string + function foo(ready) + while true do + if ready then + return 'ready' + end + end + + error('unreachable') + end + "#, + ); + } + #[test] fn test_missing_return_accepts_infinite_repeat_with_break_before_return() { assert_missing_return_ok( @@ -743,8 +778,8 @@ mod tests { } #[test] - fn test_missing_return_rejects_dynamic_while_with_infinite_body_before_return() { - assert_missing_return_error( + fn test_missing_return_accepts_dynamic_while_with_infinite_body_before_return() { + assert_missing_return_ok( r#" ---@return number local function foo(a) @@ -760,8 +795,8 @@ mod tests { } #[test] - fn test_missing_return_rejects_dynamic_while_with_break_or_infinite_body_before_return() { - assert_missing_return_error( + fn test_missing_return_accepts_dynamic_while_with_break_or_infinite_body_before_return() { + assert_missing_return_ok( r#" ---@return number local function foo(a, b) @@ -781,8 +816,8 @@ mod tests { } #[test] - fn test_missing_return_rejects_stalling_numeric_for_before_return() { - assert_missing_return_error( + fn test_missing_return_accepts_non_terminating_numeric_for_before_return() { + assert_missing_return_ok( r#" ---@return number local function foo() @@ -798,8 +833,8 @@ mod tests { } #[test] - fn test_missing_return_rejects_stalling_generic_for_before_return() { - assert_missing_return_error( + fn test_missing_return_accepts_non_terminating_generic_for_before_return() { + assert_missing_return_ok( r#" local function iter(_, done) if done then diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs index e0031023a..f12bcf547 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs @@ -283,6 +283,56 @@ mod tests { )); } + #[test] + fn test_pcall_return_after_type_guard() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + --- @param a string|fun(): integer + --- @return integer? + function foo(a) + if type(a) == 'string' then + return + end + + local ok, result = pcall(a) + if not ok then + return + end + + return result + end + "# + )); + } + + #[test] + fn test_pcall_return_after_type_guard_with_table_arg() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + --- @param cb string|fun(a: {}): integer + --- @return integer? + function foo(cb) + if type(cb) == 'string' then + return + end + + local ok, result = pcall(cb, {}) + if not ok then + return + end + + return result + end + "# + )); + } + #[test] fn test_variadic_return_type_mismatch() { let mut ws = VirtualWorkspace::new(); @@ -689,4 +739,37 @@ mod tests { "# )); } + + #[test] + fn test_asserted_array_member_return_field() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + let code = r#" + --- @return { a: integer } + function foo() + local arr --- @type integer[] + local i --- @type integer? + local a --- @type integer? + i = _ --[[@as integer]] + a = assert(arr[i]) + return { a = a } + end + "#; + assert!(ws.has_no_diagnostic(DiagnosticCode::ReturnTypeMismatch, code)); + assert!(ws.has_no_diagnostic(DiagnosticCode::AssignTypeMismatch, code)); + } + + #[test] + fn test_and_or_function_guard_return() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + --- @param f string|(fun():string) + --- @return string + function foo(f) + return type(f) == 'function' and f() or f + end + "# + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs index f18c99686..cfefece06 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/undefined_field_test.rs @@ -873,4 +873,31 @@ mod test { "# )); } + + #[test] + fn test_array_index_with_integer_literal_union() { + let mut ws = VirtualWorkspace::new(); + let code = r#" + ---@alias IntegerPartIndex + ---| 1 + ---| 2 + + ---@alias NumericPartIndex + ---| 1 + ---| number + + local parts --- @type string[] + local id --- @type 1|2 + local alias_id --- @type IntegerPartIndex + local numeric_id --- @type NumericPartIndex + result = parts[id] + alias_result = parts[alias_id] + numeric_result = parts[numeric_id] + "#; + + assert!(ws.has_no_diagnostic(DiagnosticCode::UndefinedField, code)); + assert_eq!(ws.expr_ty("result"), ws.ty("string?")); + assert_eq!(ws.expr_ty("alias_result"), ws.ty("string?")); + assert_eq!(ws.expr_ty("numeric_result"), ws.ty("string?")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 1c8e39996..4a8eaff7c 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -3,7 +3,7 @@ use emmylua_parser::{LuaCallExpr, LuaExpr}; use hashbrown::HashSet; use std::{ops::Deref, sync::Arc}; -use crate::semantic::infer::infer_expr_list_types; +use crate::semantic::infer::{InferResult, infer_expr_list_types}; use crate::{ DocTypeInferContext, FileId, GenericParam, GenericTplId, LuaFunctionType, LuaGenericType, LuaTypeNode, @@ -202,24 +202,26 @@ pub fn infer_callable_return_from_remaining_args( callable_type: &LuaType, arg_exprs: &[LuaExpr], ) -> Result, InferFailReason> { - if arg_exprs.is_empty() { - return Ok(None); - } - - let call_arg_types = - match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { + let call_arg_types = if arg_exprs.is_empty() { + Vec::new() + } else { + match infer_expr_list_types( + context.db, + context.cache, + arg_exprs, + None, + infer_call_arg_type, + ) { Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), Err(_) => arg_exprs .iter() .map(|arg_expr| { - infer_expr(context.db, context.cache, arg_expr.clone()) + infer_call_arg_type(context.db, context.cache, arg_expr.clone()) .unwrap_or(LuaType::Unknown) }) .collect::>(), - }; - if call_arg_types.is_empty() { - return Ok(None); - } + } + }; // Preserve any known remaining-arg shape, including arity, even when some later arguments // collapse to `unknown`. This avoids unioning returns from overloads that are impossible @@ -227,6 +229,17 @@ pub fn infer_callable_return_from_remaining_args( infer_callable_return_from_arg_types(context, callable_type, &call_arg_types) } +fn infer_call_arg_type(db: &DbIndex, cache: &mut LuaInferCache, arg_expr: LuaExpr) -> InferResult { + if !cache.is_no_flow() || !matches!(&arg_expr, LuaExpr::TableExpr(_)) { + return infer_expr(db, cache, arg_expr); + } + + // Generic call matching stays no-flow, but direct table literal arguments + // are local shapes and do not need flow replay. + let table_exprs = [arg_expr.get_syntax_id()]; + cache.with_replay_overlay(&[], &table_exprs, |cache| infer_expr(db, cache, arg_expr)) +} + fn instantiate_callable_from_arg_types( context: &mut TplContext, callable: &Arc, @@ -485,6 +498,9 @@ fn infer_generic_types_from_call( let mut unresolve_tpls = vec![]; for i in 0..func_params.len() { if i >= arg_exprs.len() { + if let LuaType::Variadic(variadic) = &func_params[i].1 { + variadic_tpl_pattern_match(context, variadic, &[])?; + } break; } @@ -506,7 +522,7 @@ fn infer_generic_types_from_call( continue; } - let arg_type = match infer_expr(db, context.cache, call_arg_expr.clone()) { + let arg_type = match infer_call_arg_type(db, context.cache, call_arg_expr.clone()) { Ok(t) => t, Err(InferFailReason::FieldNotFound) => LuaType::Nil, // 对于未找到的字段, 我们认为是 nil 以执行后续推断 Err(e) => return Err(e), @@ -532,7 +548,7 @@ fn infer_generic_types_from_call( (LuaType::Variadic(variadic), _) => { let mut arg_types = vec![]; for arg_expr in &arg_exprs[i..] { - let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; + let arg_type = infer_call_arg_type(db, context.cache, arg_expr.clone())?; arg_types.push(arg_type); } variadic_tpl_pattern_match(context, variadic, &arg_types)?; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 77900f2d4..467da24e6 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -644,10 +644,7 @@ fn param_type_list_pattern_match_type_list( LuaType::Variadic(inner) => { let i = i + target_offset; if i >= targets.len() { - if let VariadicType::Base(LuaType::TplRef(tpl_ref)) = inner.deref() { - let tpl_id = tpl_ref.get_tpl_id(); - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); - } + variadic_tpl_pattern_match(context, inner, &[])?; break; } @@ -810,7 +807,8 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + // Zero varargs are an empty sequence, not one nil return slot. + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -820,7 +818,7 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { context.substitutor.insert_type( @@ -863,7 +861,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, false); + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { context.substitutor.insert_type( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs index a155ba0e7..85fd20a77 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_index/infer_array.rs @@ -5,7 +5,7 @@ use emmylua_parser::{ use crate::{ DbIndex, InferFailReason, LuaArrayLen, LuaArrayType, LuaInferCache, LuaMemberKey, LuaType, - TypeOps, + TypeOps, check_type_compact, semantic::infer::{infer_index::infer_expr_for_index, narrow::get_var_expr_var_ref_id}, }; @@ -39,8 +39,9 @@ pub(super) fn infer_array_member_by_key( return Ok(array_member_fallback(db, base)); } - if !key_type.is_integer() { - if key_type.is_number() { + let is_integer_key = key_type.is_integer() || key_type_matches(db, &LuaType::Integer, key_type); + if !is_integer_key { + if key_type.is_number() || key_type_matches(db, &LuaType::Number, key_type) { return Ok(array_member_fallback(db, base)); } @@ -64,6 +65,17 @@ pub(super) fn infer_array_member_by_key( Ok(array_member_fallback(db, base)) } +fn key_type_matches(db: &DbIndex, expected: &LuaType, actual: &LuaType) -> bool { + !matches!( + actual, + LuaType::Any + | LuaType::Unknown + | LuaType::TplRef(_) + | LuaType::StrTplRef(_) + | LuaType::ConstTplRef(_) + ) && check_type_compact(db, expected, actual).is_ok() +} + pub(super) fn array_member_fallback(db: &DbIndex, base: &LuaType) -> LuaType { match base { LuaType::Any | LuaType::Unknown => base.clone(), diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 27f0b4753..f084809ce 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -65,6 +65,9 @@ pub(in crate::semantic) enum ExprTypeContinuation { call_expr: LuaCallExpr, condition_flow: InferConditionFlow, }, + Truthiness { + condition_flow: InferConditionFlow, + }, ArrayLen { subquery_condition_flow: InferConditionFlow, max_adjustment: i64, @@ -582,14 +585,24 @@ pub(super) fn get_type_at_condition_flow( continue; } LuaExpr::CallExpr(call_expr) => { - return get_type_at_call_expr( + let action = get_type_at_call_expr( db, cache, var_ref_id, flow_node, - call_expr, + call_expr.clone(), condition_flow, - ); + )?; + if !matches!(action, ConditionFlowAction::Continue) { + return Ok(action); + } + + let antecedent_flow_id = get_single_antecedent(flow_node)?; + return Ok(ConditionFlowAction::NeedExprType { + flow_id: antecedent_flow_id, + expr: LuaExpr::CallExpr(call_expr), + resume: ExprTypeContinuation::Truthiness { condition_flow }, + }); } LuaExpr::IndexExpr(index_expr) => { return get_type_at_index_expr(db, cache, var_ref_id, index_expr, condition_flow); @@ -699,6 +712,16 @@ pub(in crate::semantic::infer::narrow) fn resolve_expr_type_continuation( call_expr, condition_flow, ), + ExprTypeContinuation::Truthiness { condition_flow } => Ok(match condition_flow { + _ if expr_type.is_never() => ConditionFlowAction::Result(LuaType::Never), + InferConditionFlow::TrueCondition if expr_type.is_always_falsy() => { + ConditionFlowAction::Result(LuaType::Never) + } + InferConditionFlow::FalseCondition if expr_type.is_always_truthy() => { + ConditionFlowAction::Result(LuaType::Never) + } + _ => ConditionFlowAction::Continue, + }), ExprTypeContinuation::ArrayLen { subquery_condition_flow, max_adjustment, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 9488b8128..a29a511b3 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -299,6 +299,33 @@ fn collect_expr_dependency_queries( return; } + if let LuaExpr::CallExpr(call_expr) = expr { + // Call arguments live under LuaCallArgList, not as direct LuaExpr children. + if let Some(prefix_expr) = call_expr.get_prefix_expr() { + collect_expr_dependency_queries( + db, + tree, + cache, + fallback_flow_id, + &prefix_expr, + dependency_queries, + ); + } + if let Some(arg_list) = call_expr.get_args_list() { + for arg in arg_list.get_args() { + collect_expr_dependency_queries( + db, + tree, + cache, + fallback_flow_id, + &arg, + dependency_queries, + ); + } + } + return; + } + if let LuaExpr::IndexExpr(index_expr) = expr { // A resolved IndexRef overlay lets replay short-circuit the whole // expression; if it fails, prefix/key overlays are still available.