diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs index 1653f0ca8..75442caad 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/field_or_operator_def_tags.rs @@ -110,7 +110,7 @@ pub fn analyze_field(analyzer: &mut DocAnalyzer, tag: LuaDocTagField) -> Option< ), ("key".to_string(), Some(key_type_ref.clone())), ], - field_type.clone(), + vec![field_type.clone()], ))), ); analyzer @@ -207,7 +207,7 @@ pub fn analyze_operator(analyzer: &mut DocAnalyzer, tag: LuaDocTagOperator) -> O false, false, operands, - return_type, + vec![return_type], ))), ); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs index fe2ff27d1..ec8cea4dd 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/doc/infer_type.rs @@ -696,21 +696,13 @@ fn infer_func_type(analyzer: &mut DocTypeAnalyzeContext<'_>, func: &LuaDocFuncTy is_colon = false } - let return_type = if return_types.len() == 1 { - return_types[0].clone() - } else if return_types.len() > 1 { - LuaType::Variadic(VariadicType::Multi(return_types).into()) - } else { - LuaType::Nil - }; - LuaType::DocFunction( LuaFunctionType::new( async_state, is_colon, is_variadic, params_result, - return_type, + return_types, ) .into(), ) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs index f8e2371f0..eb9fc1fe1 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/closure.rs @@ -1,17 +1,17 @@ -use std::ops::Deref; - use emmylua_parser::{ - LuaAst, LuaAstNode, LuaCallArgList, LuaCallExpr, LuaClosureExpr, LuaFuncStat, LuaVarExpr, + LuaAst, LuaAstNode, LuaCallArgList, LuaCallExpr, LuaClosureExpr, LuaExpr, LuaFuncStat, + LuaVarExpr, }; use crate::{ - DbIndex, InferFailReason, LuaInferCache, LuaType, SignatureReturnStatus, TypeOps, VariadicType, + DbIndex, InferFailReason, LuaInferCache, LuaType, SignatureReturnStatus, TypeOps, compilation::analyzer::unresolve::{ UnResolveCallClosureParams, UnResolveClosureReturn, UnResolveParentAst, UnResolveParentClosureParams, UnResolveReturn, }, - db_index::{LuaDocReturnInfo, LuaSignatureId}, + db_index::{LuaDocReturnInfo, LuaSignatureId, return_row::merge_return_rows_with}, infer_expr, + semantic::infer_return_expr_list_types, }; use super::{LuaAnalyzer, LuaReturnPoint, analyze_func_body_returns_with}; @@ -220,120 +220,53 @@ pub fn analyze_return_point( cache: &mut LuaInferCache, return_points: &[LuaReturnPoint], ) -> Result, InferFailReason> { - let mut return_type = None; + let mut return_row: Option> = None; for point in return_points { - let point_type = match point { - LuaReturnPoint::Expr(expr) => Some(infer_expr(db, cache, expr.clone())?), - LuaReturnPoint::MuliExpr(exprs) => { - let mut multi_return = Vec::with_capacity(exprs.len()); - for expr in exprs { - multi_return.push(infer_expr(db, cache, expr.clone())?); - } - Some(LuaType::Variadic(VariadicType::Multi(multi_return).into())) + let point_row = match point { + LuaReturnPoint::Expr(expr) => { + Some(infer_return_row(db, cache, std::slice::from_ref(expr))?) } - LuaReturnPoint::Nil => Some(LuaType::Nil), + LuaReturnPoint::MuliExpr(exprs) => Some(infer_return_row(db, cache, exprs)?), + LuaReturnPoint::Empty => Some(Vec::new()), _ => None, }; - if let Some(point_type) = point_type { - return_type = Some(match return_type { - Some(return_type) => union_return_expr(db, return_type, point_type), - None => point_type, + if let Some(point_row) = point_row { + return_row = Some(match return_row { + Some(return_row) => { + let rows = [return_row.as_slice(), point_row.as_slice()]; + merge_return_rows_with(&rows, |types| { + types + .into_iter() + .reduce(|left, right| TypeOps::Union.apply(db, &left, &right)) + .unwrap_or(LuaType::Never) + }) + } + None => point_row, }); } } - let return_type = return_type.unwrap_or(LuaType::Unknown); - - Ok(vec![LuaDocReturnInfo { - type_ref: return_type, - description: None, - name: None, - attributes: None, - }]) + let return_row = return_row.unwrap_or_else(|| vec![LuaType::Unknown]); + + Ok(return_row + .into_iter() + .map(|type_ref| LuaDocReturnInfo { + type_ref, + description: None, + name: None, + attributes: None, + }) + .collect()) } -fn union_return_expr(db: &DbIndex, left: LuaType, right: LuaType) -> LuaType { - match (&left, &right) { - (LuaType::Variadic(left_variadic), LuaType::Variadic(right_variadic)) => { - match (&left_variadic.deref(), &right_variadic.deref()) { - (VariadicType::Base(left_base), VariadicType::Base(right_base)) => { - let union_base = TypeOps::Union.apply(db, left_base, right_base); - LuaType::Variadic(VariadicType::Base(union_base).into()) - } - (VariadicType::Multi(left_multi), VariadicType::Multi(right_multi)) => { - let mut new_multi = vec![]; - let max_len = left_multi.len().max(right_multi.len()); - for i in 0..max_len { - let left_type = left_multi.get(i).cloned().unwrap_or(LuaType::Nil); - let right_type = right_multi.get(i).cloned().unwrap_or(LuaType::Nil); - new_multi.push(TypeOps::Union.apply(db, &left_type, &right_type)); - } - LuaType::Variadic(VariadicType::Multi(new_multi).into()) - } - // difficult to merge the type, use let - _ => left.clone(), - } - } - (LuaType::Variadic(variadic), _) => { - let first_type = variadic.get_type(0).cloned().unwrap_or(LuaType::Unknown); - let first_union_type = TypeOps::Union.apply(db, &first_type, &right); - - match variadic.deref() { - VariadicType::Base(base) => { - let union_base = TypeOps::Union.apply(db, base, &LuaType::Nil); - LuaType::Variadic( - VariadicType::Multi(vec![ - first_union_type, - LuaType::Variadic(VariadicType::Base(union_base).into()), - ]) - .into(), - ) - } - VariadicType::Multi(multi) => { - let mut new_multi = multi.clone(); - if !new_multi.is_empty() { - new_multi[0] = first_union_type; - for mult in new_multi.iter_mut().skip(1) { - *mult = TypeOps::Union.apply(db, mult, &LuaType::Nil); - } - } else { - new_multi.push(first_union_type); - } - - LuaType::Variadic(VariadicType::Multi(new_multi).into()) - } - } - } - (_, LuaType::Variadic(variadic)) => { - let first_type = variadic.get_type(0).cloned().unwrap_or(LuaType::Unknown); - let first_union_type = TypeOps::Union.apply(db, &left, &first_type); - match variadic.deref() { - VariadicType::Base(base) => { - let union_base = TypeOps::Union.apply(db, base, &LuaType::Nil); - LuaType::Variadic( - VariadicType::Multi(vec![ - first_union_type, - LuaType::Variadic(VariadicType::Base(union_base).into()), - ]) - .into(), - ) - } - VariadicType::Multi(multi) => { - let mut new_multi = multi.clone(); - if !new_multi.is_empty() { - new_multi[0] = first_union_type; - for mult in new_multi.iter_mut().skip(1) { - *mult = TypeOps::Union.apply(db, mult, &LuaType::Nil); - } - } else { - new_multi.push(first_union_type); - } - - LuaType::Variadic(VariadicType::Multi(new_multi).into()) - } - } - } - _ => TypeOps::Union.apply(db, &left, &right), - } +fn infer_return_row( + db: &DbIndex, + cache: &mut LuaInferCache, + exprs: &[LuaExpr], +) -> Result, InferFailReason> { + Ok(infer_return_expr_list_types(db, cache, exprs, infer_expr)? + .into_iter() + .map(|(ty, _)| ty) + .collect()) } diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs index 6425ad89e..8a6e8293b 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/func_body.rs @@ -9,7 +9,7 @@ use crate::{InferFailReason, LuaType}; pub enum LuaReturnPoint { Expr(LuaExpr), MuliExpr(Vec), - Nil, + Empty, Error, } @@ -62,7 +62,7 @@ where { let mut flow = analyze_block_returns(body, infer_expr_type)?; if flow.can_fall_through || flow.can_break { - flow.return_points.push(LuaReturnPoint::Nil); + flow.return_points.push(LuaReturnPoint::Empty); } Ok(flow.return_points) @@ -356,7 +356,7 @@ fn analyze_call_expr_stat_returns( fn analyze_return_stat_returns(return_stat: LuaReturnStat) -> ReturnFlow { let exprs: Vec = return_stat.get_expr_list().collect(); let return_point = match exprs.len() { - 0 => LuaReturnPoint::Nil, + 0 => LuaReturnPoint::Empty, 1 => LuaReturnPoint::Expr(exprs[0].clone()), _ => LuaReturnPoint::MuliExpr(exprs), }; diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs index 2677dc457..0b56f81ce 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/module.rs @@ -3,6 +3,7 @@ use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr}; use crate::{ InferFailReason, LuaDeclId, LuaSemanticDeclId, LuaSignatureId, compilation::analyzer::unresolve::UnResolveModule, db_index::LuaType, infer_expr, + semantic::adjusted_result_slot_type, }; use super::{LuaAnalyzer, LuaReturnPoint, analyze_func_body_returns_with}; @@ -46,7 +47,7 @@ pub fn analyze_chunk_return(analyzer: &mut LuaAnalyzer, chunk: LuaChunk) -> Opti .db .get_module_index_mut() .get_module_mut(analyzer.file_id)?; - module_info.export_type = Some(expr_type.get_result_slot_type(0).unwrap_or(expr_type)); + module_info.export_type = Some(adjusted_result_slot_type(&expr_type, 0)); module_info.semantic_id = semantic_id; if let Some(visibility) = visibility { module_info.merge_visibility(visibility); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs index 51d7983fd..f805c5b80 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/lua/stats.rs @@ -12,6 +12,7 @@ use crate::{ unresolve::{UnResolveDecl, UnResolveMember}, }, db_index::{LuaDeclId, LuaMember, LuaMemberFeature, LuaMemberId, LuaMemberOwner, LuaType}, + semantic::{adjusted_result_slot_type, assignment_rhs_source}, }; use super::LuaAnalyzer; @@ -49,7 +50,7 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) match analyzer.infer_expr(&expr) { Ok(expr_type) => { - let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); + let expr_type = adjusted_result_slot_type(&expr_type, 0); let decl_id = LuaDeclId::new(analyzer.file_id, position); // 当`call`参数包含表时, 表可能未被分析, 需要延迟 if let LuaType::Instance(instance) = &expr_type @@ -105,26 +106,17 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) if let Some(last_expr) = last_expr { match analyzer.infer_expr(last_expr) { Ok(last_expr_type) => { - if last_expr_type.contain_multi_return() { - for i in expr_count..name_count { - let name = name_list.get(i)?; - let position = name.get_position(); - let decl_id = LuaDeclId::new(analyzer.file_id, position); - let ret_type = last_expr_type.get_result_slot_type(i - expr_count + 1); - if let Some(ret_type) = ret_type { - bind_type( - analyzer.db, - decl_id.into(), - LuaTypeCache::InferType(ret_type.clone()), - ); - } else { - analyzer.db.get_type_index_mut().bind_type( - decl_id.into(), - LuaTypeCache::InferType(LuaType::Unknown), - ); - } - } - return Some(()); + for i in expr_count..name_count { + let name = name_list.get(i)?; + let position = name.get_position(); + let decl_id = LuaDeclId::new(analyzer.file_id, position); + let (_, slot) = assignment_rhs_source(expr_count, i)?; + let ret_type = adjusted_result_slot_type(&last_expr_type, slot); + bind_type( + analyzer.db, + decl_id.into(), + LuaTypeCache::InferType(ret_type), + ); } } Err(reason) => { @@ -132,11 +124,12 @@ pub fn analyze_local_stat(analyzer: &mut LuaAnalyzer, local_stat: LuaLocalStat) let name = name_list.get(i)?; let position = name.get_position(); let decl_id = LuaDeclId::new(analyzer.file_id, position); + let (_, slot) = assignment_rhs_source(expr_count, i)?; let unresolve = UnResolveDecl { file_id: analyzer.file_id, decl_id, expr: last_expr.clone(), - ret_idx: i - expr_count + 1, + ret_idx: slot, }; analyzer @@ -310,7 +303,7 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta } let expr_type = match analyzer.infer_expr(expr) { - Ok(expr_type) => expr_type.get_result_slot_type(0).unwrap_or(expr_type), + Ok(expr_type) => adjusted_result_slot_type(&expr_type, 0), Err(InferFailReason::None) => LuaType::Unknown, Err(reason) => { match type_owner { @@ -363,18 +356,17 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta { match analyzer.infer_expr(last_expr) { Ok(last_expr_type) => { - if last_expr_type.contain_multi_return() { - for i in expr_count..var_count { - let var = var_list.get(i)?; - let type_owner = get_var_owner(analyzer, var.clone()); - set_index_expr_owner(analyzer, var.clone()); - assign_merge_type_owner_and_expr_type( - analyzer, - type_owner, - &last_expr_type, - i - expr_count + 1, - ); - } + for i in expr_count..var_count { + let var = var_list.get(i)?; + let type_owner = get_var_owner(analyzer, var.clone()); + set_index_expr_owner(analyzer, var.clone()); + let (_, slot) = assignment_rhs_source(expr_count, i)?; + assign_merge_type_owner_and_expr_type( + analyzer, + type_owner, + &last_expr_type, + slot, + ); } } Err(_) => { @@ -382,11 +374,12 @@ pub fn analyze_assign_stat(analyzer: &mut LuaAnalyzer, assign_stat: LuaAssignSta let var = var_list.get(i)?; let type_owner = get_var_owner(analyzer, var.clone()); set_index_expr_owner(analyzer, var.clone()); + let (_, slot) = assignment_rhs_source(expr_count, i)?; merge_type_owner_and_unresolve_expr( analyzer, type_owner, last_expr.clone(), - i - expr_count + 1, + slot, ); } } @@ -404,7 +397,7 @@ fn assign_merge_type_owner_and_expr_type( expr_type: &LuaType, idx: usize, ) -> Option<()> { - let expr_type = expr_type.get_result_slot_type(idx).unwrap_or(LuaType::Nil); + let expr_type = adjusted_result_slot_type(expr_type, idx); bind_type(analyzer.db, type_owner, LuaTypeCache::InferType(expr_type)); diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs index dfb1e9ea0..c50863d09 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve.rs @@ -18,7 +18,7 @@ use crate::{ }, db_index::{DbIndex, LuaMemberOwner, LuaType}, find_members_with_key, - semantic::{LuaInferCache, infer_expr}, + semantic::{LuaInferCache, adjusted_result_slot_type, infer_expr}, }; use super::{ @@ -34,9 +34,7 @@ pub fn try_resolve_decl( let expr = decl.expr.clone(); let expr_type = infer_expr(db, cache, expr)?; let decl_id = decl.decl_id; - let expr_type = expr_type - .get_result_slot_type(decl.ret_idx) - .unwrap_or(LuaType::Unknown); + let expr_type = adjusted_result_slot_type(&expr_type, decl.ret_idx); bind_type(db, decl_id.into(), LuaTypeCache::InferType(expr_type)); Ok(()) @@ -75,9 +73,7 @@ pub fn try_resolve_member( if let Some(expr) = unresolve_member.expr.clone() { let expr_type = infer_expr(db, cache, expr)?; - let expr_type = expr_type - .get_result_slot_type(unresolve_member.ret_idx) - .unwrap_or(LuaType::Unknown); + let expr_type = adjusted_result_slot_type(&expr_type, unresolve_member.ret_idx); let member_id = unresolve_member.member_id; bind_type(db, member_id.into(), LuaTypeCache::InferType(expr_type)); @@ -169,7 +165,7 @@ pub fn try_resolve_module( ) -> ResolveResult { let expr = module.expr.clone(); let expr_type = infer_expr(db, cache, expr)?; - let expr_type = expr_type.get_result_slot_type(0).unwrap_or(expr_type); + let expr_type = adjusted_result_slot_type(&expr_type, 0); let module_info = db .get_module_index_mut() .get_module_mut(module.file_id) diff --git a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs index 120588356..734ce2455 100644 --- a/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs +++ b/crates/emmylua_code_analysis/src/compilation/analyzer/unresolve/resolve_closure.rs @@ -5,7 +5,8 @@ use emmylua_parser::{LuaAstNode, LuaIndexMemberExpr, LuaTableExpr, LuaVarExpr}; use crate::{ DbIndex, InferFailReason, InferGuard, InferGuardRef, LuaDocParamInfo, LuaDocReturnInfo, LuaFunctionType, LuaInferCache, LuaSignature, LuaType, SignatureReturnStatus, TypeOps, - get_real_type, infer_call_expr_func, infer_expr, infer_table_should_be, + db_index::return_row::merge_return_rows, get_real_type, infer_call_expr_func, infer_expr, + infer_table_should_be, }; use super::{ @@ -137,13 +138,13 @@ pub fn try_resolve_closure_return( _ => {} } - let ret_type = if let Some(param_type) = call_doc_func.get_params().get(param_idx) { + let ret_types = if let Some(param_type) = call_doc_func.get_params().get(param_idx) { let Some(param_type) = get_real_type(db, param_type.1.as_ref().unwrap_or(&LuaType::Any)) else { return Ok(()); }; if let LuaType::DocFunction(func) = param_type { - func.get_ret().clone() + func.get_return_row().to_vec() } else { return Ok(()); } @@ -156,7 +157,7 @@ pub fn try_resolve_closure_return( .get_mut(&closure_return.signature_id) .ok_or(InferFailReason::None)?; - if ret_type.contain_tpl() { + if ret_types.iter().any(|ty| ty.contain_tpl()) { return try_convert_to_func_body_infer(db, cache, closure_return); } @@ -168,12 +169,19 @@ pub fn try_resolve_closure_return( _ => return Ok(()), } - signature.return_docs.push(LuaDocReturnInfo { - name: None, - type_ref: ret_type.clone(), - description: None, - attributes: None, - }); + if ret_types.is_empty() { + signature.resolve_return = SignatureReturnStatus::DocResolve; + return Ok(()); + } + + signature + .return_docs + .extend(ret_types.into_iter().map(|type_ref| LuaDocReturnInfo { + name: None, + type_ref, + description: None, + attributes: None, + })); signature.resolve_return = SignatureReturnStatus::DocResolve; Ok(()) @@ -307,8 +315,6 @@ fn resolve_closure_member_type( .get(&closure_params.signature_id) .ok_or(InferFailReason::None)?; let mut final_params = signature.get_type_params().to_vec(); - let mut final_ret = LuaType::Never; - let mut has_final_ret = false; let mut multi_function_type = Vec::new(); for typ in union_types.into_vec() { @@ -336,7 +342,7 @@ fn resolve_closure_member_type( } let mut variadic_type = LuaType::Unknown; - for doc_func in multi_function_type { + for doc_func in &multi_function_type { let mut doc_params = doc_func.get_params().to_vec(); match (doc_func.is_colon_define(), signature.is_colon_define) { (true, false) => { @@ -383,9 +389,6 @@ fn resolve_closure_member_type( final_params.push((param.0.clone(), param.1.clone())); } } - - has_final_ret = true; - final_ret = TypeOps::Union.apply(db, &final_ret, doc_func.get_ret()); } if !variadic_type.is_unknown() @@ -394,11 +397,15 @@ fn resolve_closure_member_type( param.1 = Some(variadic_type); } - let final_ret = if !has_final_ret { - LuaType::Unknown - } else { - final_ret - }; + if multi_function_type.is_empty() { + return Ok(()); + } + + let return_rows = multi_function_type + .iter() + .map(|doc_func| doc_func.get_return_row()) + .collect::>(); + let final_ret = merge_return_rows(&return_rows); resolve_doc_function( db, @@ -491,12 +498,24 @@ fn resolve_doc_function( { signature.resolve_return = SignatureReturnStatus::DocResolve; signature.return_docs.clear(); - signature.return_docs.push(LuaDocReturnInfo { - name: None, - type_ref: doc_func.get_ret().clone(), - description: None, - attributes: None, - }); + if doc_func.get_return_row().is_empty() { + return Ok(()); + } + + signature + .return_docs + .extend( + doc_func + .get_return_row() + .iter() + .cloned() + .map(|type_ref| LuaDocReturnInfo { + name: None, + type_ref, + description: None, + attributes: None, + }), + ); } Ok(()) diff --git a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs index f5a6e3303..9849b04a4 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/callable_return_infer_test.rs @@ -1,6 +1,6 @@ #[cfg(test)] mod test { - use crate::VirtualWorkspace; + use crate::{LuaType, VirtualWorkspace}; #[test] fn test_higher_order_generic_return_infer() { @@ -29,6 +29,310 @@ mod test { assert_eq!(ws.expr_ty("payload"), ws.ty("integer")); } + #[test] + fn test_higher_order_empty_variadic_return_tail_assignment_pads_nil() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + end + + ---@type fun() + local none + + ok, payload = wrap(none) + "#, + ); + + assert_eq!(ws.expr_ty("ok"), ws.ty("boolean")); + assert_eq!(ws.expr_ty("payload"), ws.ty("nil")); + } + + #[test] + fn test_higher_order_fixed_missing_return_slot_infers_nil() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B + ---@param f fun(): A, B + ---@return A, B + local function call(f) + end + + ---@return string + local function one() + end + + first, second = call(one) + "#, + ); + + assert_eq!(ws.expr_ty("first"), ws.ty("string")); + assert_eq!(ws.expr_ty("second"), ws.ty("nil")); + } + + #[test] + fn test_higher_order_empty_variadic_return_tail_keeps_call_arity() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T, R + ---@param f fun(...: T...): R... + ---@param ... T... + ---@return boolean, R... + local function wrap(f, ...) + end + + ---@type fun() + local none + + ---@overload fun(x: boolean): "one" + ---@overload fun(x: boolean, y: nil): "two" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(wrap(none)) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"one\""); + } + + #[test] + fn test_higher_order_empty_callback_return_with_remaining_args_keeps_zero_arity() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, R + ---@param f fun(x: A): R... + ---@param x A + ---@return R... + local function apply(f, x) + end + + ---@return + local function none(x) + end + + ---@overload fun(): "zero" + ---@overload fun(x: nil): "one" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(apply(none, "x")) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"zero\""); + } + + #[test] + fn test_empty_return_call_keeps_argument_arity() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return + local function none() + end + + ---@overload fun(): "zero" + ---@overload fun(x: nil): "one" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(none()) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"zero\""); + } + + #[test] + fn test_inferred_empty_return_call_keeps_argument_arity() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + local function none() + return + end + + ---@overload fun(): "zero" + ---@overload fun(x: nil): "one" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(none()) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"zero\""); + } + + #[test] + fn test_empty_return_call_is_nil_in_scalar_context() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return + local function none() + end + + value = none() or 1 + "#, + ); + + assert_eq!(ws.expr_ty("value"), LuaType::IntegerConst(1)); + } + + #[test] + fn test_non_tail_empty_return_call_pads_nil_argument() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return + local function none() + end + + ---@overload fun(x: nil, y: integer): "two" + ---@overload fun(x: integer): "one" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(none(), 1) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"two\""); + } + + #[test] + fn test_variadic_generic_empty_tail_arg_keeps_zero_arity() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic T + ---@param ... T... + ---@return T... + local function pass(...) + end + + ---@return + local function none() + end + + ---@overload fun(): "zero" + ---@overload fun(x: nil): "one" + ---@param ... unknown + ---@return "fallback" + local function arity(...) + end + + which = arity(pass(none())) + "#, + ); + + let which = ws.expr_ty("which"); + assert_eq!(ws.humanize_type(which), "\"zero\""); + } + + #[test] + fn test_union_generic_pack_return_keeps_deep_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@type (fun(...: T...): T...) | nil + local run + + first, second, third = run(1, "x", true) + "#, + ); + + assert_eq!(ws.expr_ty("first"), ws.ty("integer")); + assert_eq!(ws.expr_ty("second"), ws.ty("string")); + assert_eq!(ws.expr_ty("third"), ws.ty("boolean")); + } + + #[test] + fn test_inferred_callback_return_keeps_multiple_slots() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@generic A, B + ---@param f fun(): A, B + ---@return A, B + local function call(f) + end + + local function two() + return 1, "x" + end + + first, second = call(two) + "#, + ); + + assert_eq!(ws.expr_ty("first"), ws.ty("integer")); + assert_eq!(ws.expr_ty("second"), ws.ty("string")); + } + + #[test] + fn test_inferred_return_preserves_unbounded_vararg_tail() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@param ... string + local function pass(...) + return ... + end + + first, second = pass("x", "y") + "#, + ); + + assert_eq!(ws.expr_ty("first"), ws.ty("string")); + assert_eq!(ws.expr_ty("second"), ws.ty("string")); + } + + #[test] + fn test_inferred_return_preserves_unbounded_tail_call() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return string... + local function many() + end + + local function pass() + return many() + end + + first, second = pass() + "#, + ); + + assert_eq!(ws.expr_ty("first"), ws.ty("string")); + assert_eq!(ws.expr_ty("second"), ws.ty("string")); + } + #[test] fn test_higher_order_return_infer_keeps_concrete_callable_result() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs index 086279db4..93c460046 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/generic_test.rs @@ -534,7 +534,7 @@ mod test { "#, )); assert_eq!(ws.expr_ty("G"), ws.ty("integer")); - assert_eq!(ws.expr_ty("H"), ws.ty("any")); + assert_eq!(ws.expr_ty("H"), ws.ty("nil")); } { diff --git a/crates/emmylua_code_analysis/src/db_index/mod.rs b/crates/emmylua_code_analysis/src/db_index/mod.rs index ebcd2284f..3c88c3d44 100644 --- a/crates/emmylua_code_analysis/src/db_index/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/mod.rs @@ -9,6 +9,7 @@ mod module; mod operators; mod property; mod reference; +pub(crate) mod return_row; mod schema; mod semantic_decl; mod signature; diff --git a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs index 6cfcb9ff1..dafa95619 100644 --- a/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs +++ b/crates/emmylua_code_analysis/src/db_index/operators/lua_operator.rs @@ -5,7 +5,7 @@ use rowan::{TextRange, TextSize}; use crate::{ DbIndex, FileId, InFiled, InferFailReason, LuaConstructorReturnMode, LuaFunctionType, LuaSignature, LuaSignatureId, SignatureReturnStatus, - db_index::{LuaType, LuaTypeDeclId}, + db_index::{LuaType, LuaTypeDeclId, return_row::return_type_to_row}, }; use super::lua_operator_meta_method::LuaOperatorMetaMethod; @@ -83,7 +83,7 @@ impl LuaOperator { pub fn get_result(&self, db: &DbIndex) -> Result { match &self.func { - OperatorFunction::Func(func) => Ok(func.get_ret().clone()), + OperatorFunction::Func(func) => Ok(func.get_return_type()), OperatorFunction::Signature(signature_id) => { let signature = db.get_signature_index().get(signature_id); if let Some(signature) = signature { @@ -104,8 +104,7 @@ impl LuaOperator { } => { let signature = db.get_signature_index().get(id); if let Some(signature) = signature { - let return_type = get_constructor_return_type(signature, return_mode); - return Ok(return_type); + return Ok(get_constructor_return_type(signature, return_mode)); } Ok(LuaType::Any) @@ -142,7 +141,7 @@ impl LuaOperator { is_colon_define, signature.is_vararg, params, - return_type, + return_type_to_row(return_type), ); return LuaType::DocFunction(Arc::new(func_type)); } diff --git a/crates/emmylua_code_analysis/src/db_index/return_row.rs b/crates/emmylua_code_analysis/src/db_index/return_row.rs new file mode 100644 index 000000000..37ac9bec5 --- /dev/null +++ b/crates/emmylua_code_analysis/src/db_index/return_row.rs @@ -0,0 +1,236 @@ +use std::sync::Arc; + +use crate::{ + LuaAliasCallKind, LuaAliasCallType, LuaType, VariadicType, db_index::union_type_shallow, +}; + +pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { + get_overload_row_slot_if_present(row, idx).unwrap_or(LuaType::Nil) +} + +pub(crate) fn row_to_return_type(mut row: Vec) -> LuaType { + match row.len() { + 0 => LuaType::Nil, + 1 => row.pop().unwrap_or(LuaType::Nil), + _ => LuaType::Variadic(VariadicType::Multi(row).into()), + } +} + +/// Convert a row while preserving call-result arity: no values is not scalar nil. +pub(crate) fn row_to_multi_return_type(row: Vec) -> LuaType { + if row.is_empty() { + LuaType::Variadic(VariadicType::Multi(Vec::new()).into()) + } else { + row_to_return_type(row) + } +} + +pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec { + match return_type { + LuaType::Variadic(variadic) => match variadic.as_ref() { + VariadicType::Multi(types) => types.clone(), + VariadicType::Base(_) => vec![LuaType::Variadic(variadic)], + }, + typ => vec![typ], + } +} + +/// Minimum number of values a documented return row requires. +/// +/// Explicit `nil` slots still count for arity. They only collapse to +/// `LuaType::Nil` when a row is used as a single expression type. +pub(crate) fn return_row_min_len(row: &[LuaType]) -> Option { + let mut min_len = match row.last() { + None => 0, + Some(LuaType::Variadic(variadic)) => row.len() - 1 + variadic.get_min_len().unwrap_or(0), + Some(_) => row.len(), + }; + + for idx in (0..min_len).rev() { + let ty = get_overload_row_slot_if_present(row, idx)?; + if matches!(ty, LuaType::Nil) { + break; + } + if ty.is_optional() { + min_len -= 1; + } else { + break; + } + } + + Some(min_len) +} + +/// Maximum number of values a documented return row can produce. +pub(crate) fn return_row_max_len(row: &[LuaType]) -> Option { + match row.last() { + None => Some(0), + Some(LuaType::Variadic(variadic)) => variadic.get_max_len().map(|len| row.len() - 1 + len), + Some(_) => Some(row.len()), + } +} + +pub(crate) fn merge_return_rows(rows: &[&[LuaType]]) -> Vec { + merge_return_rows_with(rows, LuaType::from_vec) +} + +pub(crate) fn merge_return_rows_shallow(rows: &[&[LuaType]]) -> Vec { + merge_return_rows_with(rows, |types| { + types + .into_iter() + .reduce(union_type_shallow) + .unwrap_or(LuaType::Never) + }) +} + +/// Merges return rows using Lua result-slot adjustment. +/// +/// The caller supplies only the slot merge policy. Row shape decisions stay +/// here: missing slots become `nil`, finite variadic tails are expanded, and +/// unbounded tails keep one representative variadic slot. +pub(crate) fn merge_return_rows_with( + rows: &[&[LuaType]], + merge_slot_types: impl Fn(Vec) -> LuaType, +) -> Vec { + let Some(prefix_max_len) = rows.iter().map(|row| row_merge_prefix_len(row)).max() else { + return Vec::new(); + }; + if prefix_max_len == 0 { + return Vec::new(); + } + + let (has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = + rows.iter() + .fold((false, false), |(has_unbounded, has_tpl_unbounded), row| { + let Some(last) = row.last() else { + return (has_unbounded, has_tpl_unbounded); + }; + let LuaType::Variadic(variadic) = last else { + return (has_unbounded, has_tpl_unbounded); + }; + + let has_unbounded_row = variadic.get_max_len().is_none(); + ( + has_unbounded || has_unbounded_row, + has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), + ) + }); + let merge_len = if has_unbounded_variadic_tail { + prefix_max_len + 1 + } else { + prefix_max_len + }; + + let mut types = Vec::with_capacity(merge_len); + for idx in 0..merge_len { + let slot_types = rows + .iter() + .map(|row| get_overload_row_slot_if_present(row, idx).unwrap_or(LuaType::Nil)) + .collect::>(); + types.push(merge_slot_types(slot_types)); + } + if has_unbounded_variadic_tail + && !has_tpl_unbounded_variadic_tail + && let Some(last) = types.last_mut() + && !matches!(last, LuaType::Variadic(_)) + { + *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); + } + + types +} + +fn row_merge_prefix_len(row: &[LuaType]) -> usize { + let Some(last) = row.last() else { + return 0; + }; + + if let LuaType::Variadic(variadic) = last { + row.len() - 1 + variadic_merge_prefix_len(variadic) + } else { + row.len() + } +} + +fn variadic_merge_prefix_len(variadic: &VariadicType) -> usize { + if let Some(len) = variadic.get_max_len() { + return len; + } + + match variadic { + VariadicType::Base(_) => 1, + VariadicType::Multi(types) => match types.last() { + Some(LuaType::Variadic(variadic)) => { + types.len() - 1 + variadic_merge_prefix_len(variadic) + } + Some(_) => types.len(), + None => 0, + }, + } +} + +fn overload_row_tpl_slot( + call_kind: LuaAliasCallKind, + variadic: &Arc, + index: i64, +) -> LuaType { + LuaType::Call( + LuaAliasCallType::new( + call_kind, + vec![ + LuaType::Variadic(variadic.clone()), + LuaType::IntegerConst(index), + ], + ) + .into(), + ) +} + +fn get_overload_row_slot_if_present(row: &[LuaType], idx: usize) -> Option { + let row_len = row.len(); + if row_len == 0 { + return None; + } + + if idx + 1 < row_len { + return Some(row[idx].clone()); + } + + let last_idx = row_len - 1; + let last_ty = &row[last_idx]; + let offset = idx - last_idx; + if let LuaType::Variadic(variadic) = last_ty { + if let Some(slot) = variadic.get_type(offset).cloned() { + if slot.contain_tpl() { + if offset > 0 && matches!(variadic.as_ref(), VariadicType::Base(_)) { + return Some(overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )); + } + + return Some(overload_row_tpl_slot( + LuaAliasCallKind::Index, + variadic, + offset as i64, + )); + } + return Some(slot); + } + + if variadic.get_max_len().is_some() { + return None; + } + + Some(overload_row_tpl_slot( + LuaAliasCallKind::Select, + variadic, + (offset + 1) as i64, + )) + } else if offset == 0 { + Some(last_ty.clone()) + } else { + None + } +} diff --git a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs index 46983b02c..a9b2a7c0a 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/mod.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/mod.rs @@ -1,5 +1,4 @@ mod async_state; -mod return_rows; #[allow(clippy::module_inception)] mod signature; diff --git a/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs b/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs deleted file mode 100644 index 2ae56aacc..000000000 --- a/crates/emmylua_code_analysis/src/db_index/signature/return_rows.rs +++ /dev/null @@ -1,237 +0,0 @@ -use std::sync::Arc; - -use crate::{ - LuaAliasCallKind, LuaAliasCallType, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaType, - VariadicType, db_index::union_type_shallow, -}; - -pub(super) fn get_return_type( - return_docs: &[LuaDocReturnInfo], - return_overloads: &[LuaDocReturnOverloadInfo], -) -> LuaType { - let return_docs_type = row_to_return_type( - return_docs - .iter() - .map(|info| info.type_ref.clone()) - .collect(), - ); - if return_overloads.is_empty() { - return return_docs_type; - } - - let overload_return_type = rows_to_return_type( - &return_overloads - .iter() - .map(|overload| overload.type_refs.as_slice()) - .collect::>(), - ); - if return_docs.is_empty() { - overload_return_type - } else { - merge_return_type(overload_return_type, return_docs_type) - } -} - -pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { - get_overload_row_slot_if_present(row, idx).unwrap_or(LuaType::Nil) -} - -pub(crate) fn row_to_return_type(mut row: Vec) -> LuaType { - match row.len() { - 0 => LuaType::Nil, - 1 => row.pop().unwrap_or(LuaType::Nil), - _ => LuaType::Variadic(VariadicType::Multi(row).into()), - } -} - -pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec { - match return_type { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Multi(types) => types.clone(), - VariadicType::Base(_) => vec![LuaType::Variadic(variadic)], - }, - typ => vec![typ], - } -} - -fn rows_to_return_type(rows: &[&[LuaType]]) -> LuaType { - let Some(base_max_len) = rows.iter().map(|row| row.len()).max() else { - return LuaType::Nil; - }; - if base_max_len == 0 { - return LuaType::Nil; - } - - let (has_variadic_tail, has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = - rows.iter().fold( - (false, false, false), - |(has_var, has_unbounded, has_tpl_unbounded), row| { - let Some(last) = row.last() else { - return (has_var, has_unbounded, has_tpl_unbounded); - }; - let LuaType::Variadic(variadic) = last else { - return (has_var, has_unbounded, has_tpl_unbounded); - }; - - let has_unbounded_row = variadic.get_max_len().is_none(); - ( - true, - has_unbounded || has_unbounded_row, - has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), - ) - }, - ); - let max_len = if has_variadic_tail { - base_max_len + 1 - } else { - base_max_len - }; - let fill_missing_with_nil = |idx: usize| idx < base_max_len || has_unbounded_variadic_tail; - - let mut types = Vec::with_capacity(max_len); - for idx in 0..max_len { - let slot_types = rows - .iter() - .filter_map(|row| { - get_overload_row_slot_if_present(row, idx) - .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)) - }) - .collect(); - types.push(LuaType::from_vec(slot_types)); - } - if has_unbounded_variadic_tail - && !has_tpl_unbounded_variadic_tail - && let Some(last) = types.last_mut() - && !matches!(last, LuaType::Variadic(_)) - { - *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); - } - - row_to_return_type(types) -} - -fn merge_return_rows(left_row: &[LuaType], right_row: &[LuaType]) -> LuaType { - let base_max_len = left_row.len().max(right_row.len()); - let (has_variadic_tail, has_unbounded_variadic_tail, has_tpl_unbounded_variadic_tail) = - [left_row, right_row].iter().fold( - (false, false, false), - |(has_var, has_unbounded, has_tpl_unbounded), row| { - let Some(last) = row.last() else { - return (has_var, has_unbounded, has_tpl_unbounded); - }; - let LuaType::Variadic(variadic) = last else { - return (has_var, has_unbounded, has_tpl_unbounded); - }; - - let has_unbounded_row = variadic.get_max_len().is_none(); - ( - true, - has_unbounded || has_unbounded_row, - has_tpl_unbounded || (has_unbounded_row && variadic.contain_tpl()), - ) - }, - ); - let max_len = if has_variadic_tail { - base_max_len + 1 - } else { - base_max_len - }; - let fill_missing_with_nil = |idx: usize| idx < base_max_len || has_unbounded_variadic_tail; - - let mut types = Vec::with_capacity(max_len); - for idx in 0..max_len { - let left_type = get_overload_row_slot_if_present(left_row, idx) - .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)); - let right_type = get_overload_row_slot_if_present(right_row, idx) - .or(fill_missing_with_nil(idx).then_some(LuaType::Nil)); - - let merged_type = match (left_type, right_type) { - (Some(left), Some(right)) => union_type_shallow(left, right), - (Some(left), None) | (None, Some(left)) => left, - (None, None) => continue, - }; - types.push(merged_type); - } - if has_unbounded_variadic_tail - && !has_tpl_unbounded_variadic_tail - && let Some(last) = types.last_mut() - && !matches!(last, LuaType::Variadic(_)) - { - *last = LuaType::Variadic(VariadicType::Base(last.clone()).into()); - } - - row_to_return_type(types) -} - -fn merge_return_type(left: LuaType, right: LuaType) -> LuaType { - match (&left, &right) { - (LuaType::Variadic(_), _) | (_, LuaType::Variadic(_)) => { - let left_row = return_type_to_row(left); - let right_row = return_type_to_row(right); - merge_return_rows(&left_row, &right_row) - } - _ => union_type_shallow(left, right), - } -} - -fn overload_row_tpl_slot( - call_kind: LuaAliasCallKind, - variadic: &Arc, - index: i64, -) -> LuaType { - LuaType::Call( - LuaAliasCallType::new( - call_kind, - vec![ - LuaType::Variadic(variadic.clone()), - LuaType::IntegerConst(index), - ], - ) - .into(), - ) -} - -fn get_overload_row_slot_if_present(row: &[LuaType], idx: usize) -> Option { - let row_len = row.len(); - if row_len == 0 { - return None; - } - - if idx + 1 < row_len { - return Some(row[idx].clone()); - } - - let last_idx = row_len - 1; - let last_ty = &row[last_idx]; - let offset = idx - last_idx; - if let LuaType::Variadic(variadic) = last_ty { - if let Some(slot) = variadic.get_type(offset).cloned() { - if slot.contain_tpl() { - if offset > 0 && matches!(variadic.as_ref(), VariadicType::Base(_)) { - return Some(overload_row_tpl_slot( - LuaAliasCallKind::Select, - variadic, - (offset + 1) as i64, - )); - } - - return Some(overload_row_tpl_slot( - LuaAliasCallKind::Index, - variadic, - offset as i64, - )); - } - return Some(slot); - } - - Some(overload_row_tpl_slot( - LuaAliasCallKind::Select, - variadic, - (offset + 1) as i64, - )) - } else if offset == 0 { - Some(last_ty.clone()) - } else { - None - } -} diff --git a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs index d2a6b818d..98f1fbfa8 100644 --- a/crates/emmylua_code_analysis/src/db_index/signature/signature.rs +++ b/crates/emmylua_code_analysis/src/db_index/signature/signature.rs @@ -6,11 +6,13 @@ use std::{collections::HashMap, sync::Arc}; use emmylua_parser::{LuaAstNode, LuaClosureExpr, LuaDocFuncType}; use rowan::TextSize; -use super::return_rows; use crate::db_index::signature::async_state::AsyncState; use crate::{ FileId, - db_index::{LuaFunctionType, LuaType}, + db_index::{ + LuaFunctionType, LuaType, + return_row::{merge_return_rows, merge_return_rows_shallow, row_to_return_type}, + }, }; use crate::{ LuaAttributeCollectionExt, LuaAttributeUse, LuaBuiltinAttributeKind, SemanticModel, @@ -117,19 +119,31 @@ impl LuaSignature { } pub fn get_return_type(&self) -> LuaType { - return_rows::get_return_type(&self.return_docs, &self.return_overloads) + row_to_return_type(self.get_return_row()) } - pub(crate) fn get_overload_row_slot(row: &[LuaType], idx: usize) -> LuaType { - return_rows::get_overload_row_slot(row, idx) - } - - pub(crate) fn row_to_return_type(row: Vec) -> LuaType { - return_rows::row_to_return_type(row) - } + pub fn get_return_row(&self) -> Vec { + let return_docs_row = self + .return_docs + .iter() + .map(|info| info.type_ref.clone()) + .collect::>(); + if self.return_overloads.is_empty() { + return return_docs_row; + } - pub(crate) fn return_type_to_row(return_type: LuaType) -> Vec { - return_rows::return_type_to_row(return_type) + let overload_return_row = merge_return_rows( + &self + .return_overloads + .iter() + .map(|overload| overload.type_refs.as_slice()) + .collect::>(), + ); + if self.return_docs.is_empty() { + overload_return_row + } else { + merge_return_rows_shallow(&[overload_return_row.as_slice(), return_docs_row.as_slice()]) + } } pub fn is_method(&self, semantic_model: &SemanticModel, owner_type: Option<&LuaType>) -> bool { @@ -163,17 +177,13 @@ impl LuaSignature { } pub fn to_doc_func_type(&self) -> Arc { - let params = self.get_type_params(); - let return_type = self.get_return_type(); - let is_vararg = self.is_vararg; - let func_type = LuaFunctionType::new( + Arc::new(LuaFunctionType::new( self.async_state, self.is_colon_define, - is_vararg, - params, - return_type, - ); - Arc::new(func_type) + self.is_vararg, + self.get_type_params(), + self.get_return_row(), + )) } pub fn to_call_operator_func_type(&self) -> Arc { @@ -182,9 +192,13 @@ impl LuaSignature { params.remove(0); } - let return_type = self.get_return_type(); - let func_type = - LuaFunctionType::new(self.async_state, false, self.is_vararg, params, return_type); + let func_type = LuaFunctionType::new( + self.async_state, + false, + self.is_vararg, + params, + self.get_return_row(), + ); Arc::new(func_type) } } diff --git a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs index f26732263..6afbeff78 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/humanize_type.rs @@ -571,8 +571,8 @@ impl<'a> TypeHumanizer<'a> { self.level = saved; w.write_char(')')?; - let ret_type = lua_func.get_ret(); - let return_nil = match ret_type { + let ret_type = lua_func.get_return_type(); + let return_nil = match &ret_type { LuaType::Variadic(variadic) => matches!(variadic.get_type(0), Some(LuaType::Nil)), _ => ret_type.is_nil(), }; @@ -584,7 +584,7 @@ impl<'a> TypeHumanizer<'a> { w.write_str(" -> ")?; let saved = self.level; self.level = self.child_level(); - self.write_type(ret_type, w)?; + self.write_type(&ret_type, w)?; self.level = saved; Ok(()) } diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs index 1e7750c37..1ccbfcfbe 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/complex.rs @@ -2,9 +2,9 @@ use hashbrown::{HashMap, HashSet}; use internment::ArcIntern; use rowan::TextRange; use smol_str::SmolStr; -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; -use crate::db_index::LuaMemberKey; +use crate::db_index::{LuaMemberKey, return_row::row_to_return_type}; use crate::{AsyncState, DbIndex, InFiled, SemanticModel, first_param_may_not_self}; use super::super::basic_union::{BasicTypeKind, BasicTypeUnion}; @@ -108,7 +108,7 @@ pub struct LuaFunctionType { is_colon_define: bool, is_variadic: bool, params: Vec<(String, Option)>, - ret: LuaType, + ret: Vec, } impl LuaFunctionType { @@ -117,7 +117,7 @@ impl LuaFunctionType { is_colon_define: bool, is_variadic: bool, params: Vec<(String, Option)>, - ret: LuaType, + ret: Vec, ) -> Self { Self { async_state, @@ -140,20 +140,25 @@ impl LuaFunctionType { &self.params } - pub fn get_ret(&self) -> &LuaType { + pub fn get_return_row(&self) -> &[LuaType] { &self.ret } + pub fn get_return_type(&self) -> LuaType { + row_to_return_type(self.ret.clone()) + } + pub fn is_variadic(&self) -> bool { self.is_variadic } pub fn get_variadic_ret(&self) -> VariadicType { - if let LuaType::Variadic(variadic) = &self.ret { - return variadic.deref().clone(); + match self.ret.as_slice() { + [] => VariadicType::Base(LuaType::Nil), + [LuaType::Variadic(variadic)] => variadic.as_ref().clone(), + [ret] => VariadicType::Base(ret.clone()), + row => VariadicType::Multi(row.to_vec()), } - - VariadicType::Base(self.ret.clone()) } pub fn contain_tpl(&self) -> bool { @@ -166,7 +171,7 @@ impl LuaFunctionType { .params .iter() .any(|(name, t)| name == "self" || t.as_ref().is_some_and(|t| t.is_self_infer())) - || self.ret.is_self_infer() + || self.ret.iter().any(|t| t.is_self_infer()) } pub fn is_method(&self, semantic_model: &SemanticModel, owner_type: Option<&LuaType>) -> bool { diff --git a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs index 9cb73bdcc..5b3aaa361 100644 --- a/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs +++ b/crates/emmylua_code_analysis/src/db_index/type/types/traverse.rs @@ -130,7 +130,9 @@ impl LuaTypeNode for LuaTupleType { impl LuaTypeNode for LuaFunctionType { fn push_direct_children<'a>(&'a self, stack: &mut Vec<&'a LuaType>) { - stack.push(self.get_ret()); + for ty in self.get_return_row().iter().rev() { + stack.push(ty); + } for (_, ty) in self.get_params().iter().rev() { if let Some(ty) = ty { stack.push(ty); diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs index 24a0ea141..1f48cdf46 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/assign_type_mismatch.rs @@ -10,6 +10,7 @@ use crate::{ DiagnosticCode, LuaBuiltinAttributeKind, LuaDeclExtra, LuaDeclId, LuaLspOptimizationCode, LuaMemberKey, LuaSemanticDeclId, LuaType, SemanticDeclLevel, SemanticModel, TypeCheckFailReason, TypeCheckResult, VariadicType, infer_index_expr, + semantic::{adjusted_result_slot_type, assignment_rhs_source}, }; use super::{Checker, DiagnosticContext, humanize_lint_type}; @@ -50,7 +51,7 @@ fn check_assign_stat( semantic_model, index_expr, exprs.get(idx).cloned(), - value_types.get(idx)?.0.clone(), + value_type_to_check(semantic_model, &value_types, &exprs, idx)?, ); } LuaVarExpr::NameExpr(name_expr) => { @@ -59,7 +60,7 @@ fn check_assign_stat( semantic_model, name_expr, exprs.get(idx).cloned(), - value_types.get(idx)?.0.clone(), + value_type_to_check(semantic_model, &value_types, &exprs, idx)?, ); } } @@ -67,6 +68,33 @@ fn check_assign_stat( Some(()) } +/// Returns the RHS type this diagnostic should compare with target `idx`. +/// +/// ```lua +/// ---@return +/// local function none() end +/// +/// local a, b = none() +/// ``` +/// +/// A call in tail-list position can contribute no values to `value_types`, but +/// Lua assignment still initializes unmatched targets to `nil`. +fn value_type_to_check( + semantic_model: &SemanticModel, + value_types: &[(LuaType, TextRange)], + value_exprs: &[LuaExpr], + idx: usize, +) -> Option { + if let Some((ty, _)) = value_types.get(idx) { + return Some(ty.clone()); + } + + let (expr_idx, slot) = assignment_rhs_source(value_exprs.len(), idx)?; + let expr = value_exprs.get(expr_idx)?; + let ty = semantic_model.infer_expr(expr.clone()).ok()?; + Some(adjusted_result_slot_type(&ty, slot)) +} + fn check_name_expr( context: &mut DiagnosticContext, semantic_model: &SemanticModel, @@ -182,7 +210,7 @@ fn check_local_stat( .get_type_index() .get_type_cache(&decl_id.into()) .map(|cache| cache.as_type().clone())?; - let value_type = value_types.get(idx)?.0.clone(); + let value_type = value_type_to_check(semantic_model, &value_types, &value_exprs, idx)?; check_assign_type_mismatch( context, semantic_model, diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs index 2e89c8f43..46cb5a248 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_param_count.rs @@ -95,7 +95,18 @@ fn check_call_expr( let func = semantic_model.infer_call_expr_func(call_expr.clone(), None)?; let mut fake_params = func.get_params().to_vec(); let call_args = call_expr.get_args_list()?.get_args().collect::>(); - let mut call_args_count = call_args.len(); + // List inference preserves empty tail-call rows; the direct tail check keeps + // unbounded variadic calls from looking like a single known argument. + let inferred_call_args = semantic_model.infer_expr_list_types(&call_args, None); + let mut min_call_args_count = inferred_call_args.len(); + let mut max_call_args_count = Some(inferred_call_args.len()); + if let Some((tail_arg, fixed_args)) = call_args.split_last() + && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(tail_arg.clone()) + { + min_call_args_count = fixed_args.len() + variadic.get_min_len().unwrap_or(0); + max_call_args_count = variadic.get_max_len().map(|max| fixed_args.len() + max); + } + let tail_min_count = min_call_args_count.saturating_sub(call_args.len().saturating_sub(1)); let last_arg_is_dots = call_args.last().is_some_and(is_dots_expr); // 根据冒号定义与冒号调用的情况来调整调用参数的数量 let colon_call = call_expr.is_colon_call(); @@ -106,36 +117,17 @@ fn check_call_expr( fake_params.insert(0, ("self".to_string(), Some(LuaType::SelfInfer))); } (true, false) => { - call_args_count += 1; + min_call_args_count += 1; + if let Some(max_call_args_count) = &mut max_call_args_count { + *max_call_args_count += 1; + } } } // Check for missing parameters - if call_args_count < fake_params.len() { - // 调用参数包含 `...` - for arg in call_args.iter() { - if let LuaExpr::LiteralExpr(literal_expr) = arg - && let Some(LuaLiteralToken::Dots(_)) = literal_expr.get_literal() - { - return Some(()); - } - } - // 对调用参数的最后一个参数进行特殊处理 - if let Some(last_arg) = call_args.last() - && let Ok(LuaType::Variadic(variadic)) = semantic_model.infer_expr(last_arg.clone()) - { - let len = match variadic.get_max_len() { - Some(len) => len, - None => { - return Some(()); - } - }; - call_args_count = call_args_count + len - 1; - if call_args_count >= fake_params.len() { - return Some(()); - } - } - + if let Some(call_args_count) = max_call_args_count + && call_args_count < fake_params.len() + { let mut miss_parameter_info = Vec::new(); for i in call_args_count..fake_params.len() { @@ -177,11 +169,6 @@ fn check_call_expr( return Some(()); } - let mut min_call_args_count = call_args_count; - if last_arg_is_dots { - min_call_args_count = min_call_args_count.saturating_sub(1); - } - if min_call_args_count <= fake_params.len() { return Some(()); } @@ -203,6 +190,10 @@ fn check_call_expr( continue; } + if i + 1 == call_args.len() && tail_min_count == 0 { + continue; + } + let param_index = i as isize + adjusted_index; if param_index < 0 || param_index < fake_params.len() as isize { diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs index b9a3709b5..409c82b78 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/check_return_count.rs @@ -6,6 +6,7 @@ use emmylua_parser::{ use crate::{ DiagnosticCode, LuaSignatureId, LuaType, SemanticModel, SignatureReturnStatus, compilation::analyze_func_body_missing_return_flags_with, + db_index::return_row::{return_row_max_len, return_row_min_len}, }; use super::{Checker, DiagnosticContext, get_return_stats}; @@ -28,25 +29,25 @@ impl Checker for CheckReturnCount { } } -// 获取(是否doc标注过返回值, 返回值类型) +// 获取(是否doc标注过返回值, 返回值行) fn get_function_return_info( context: &mut DiagnosticContext, semantic_model: &SemanticModel, closure_expr: &LuaClosureExpr, -) -> Option<(bool, LuaType)> { +) -> Option<(bool, Vec)> { let typ = semantic_model .infer_bind_value_type(closure_expr.clone().into()) .unwrap_or(LuaType::Unknown); match typ { LuaType::DocFunction(func_type) => { - return Some((true, func_type.get_ret().clone())); + return Some((true, func_type.get_return_row().to_vec())); } LuaType::Signature(signature) => { let signature = context.db.get_signature_index().get(&signature)?; return Some(( signature.resolve_return == SignatureReturnStatus::DocResolve, - signature.get_return_type(), + signature.get_return_row(), )); } _ => {} @@ -57,7 +58,7 @@ fn get_function_return_info( Some(( signature.resolve_return == SignatureReturnStatus::DocResolve, - signature.get_return_type(), + signature.get_return_row(), )) } @@ -66,7 +67,7 @@ fn check_missing_return( semantic_model: &SemanticModel, closure_expr: &LuaClosureExpr, ) -> Option<()> { - let (is_doc_resolve_return, return_type) = + let (is_doc_resolve_return, return_row) = get_function_return_info(context, semantic_model, closure_expr)?; // 如果返回状态不是 DocResolve, 则跳过检查 @@ -75,36 +76,16 @@ fn check_missing_return( } // 最小返回值数 - let min_expected_return_count = match &return_type { - LuaType::Variadic(variadic) => { - let min_len = variadic.get_min_len()?; - let mut real_min_len = min_len; - // 逆序检查 - if min_len > 0 { - for i in (0..min_len).rev() { - if let Some(ty) = variadic.get_type(i) { - if ty.is_optional() { - real_min_len -= 1; - } else { - break; - } - } - } - } - real_min_len - } - LuaType::Nil | LuaType::Any | LuaType::Unknown => 0, - _ if return_type.is_nullable() => 0, - _ => 1, - }; + let min_expected_return_count = return_row_min_len(&return_row)?; + let max_expected_return_count = return_row_max_len(&return_row); for return_stat in get_return_stats(closure_expr) { check_return_count( context, semantic_model, &return_stat, - &return_type, min_expected_return_count, + max_expected_return_count, ); } @@ -159,47 +140,27 @@ fn check_return_count( context: &mut DiagnosticContext, semantic_model: &SemanticModel, return_stat: &LuaReturnStat, - return_type: &LuaType, min_expected_return_count: usize, + max_expected_return_count: Option, ) -> Option<()> { - let max_expected_return_count = match return_type { - LuaType::Variadic(variadic) => variadic.get_max_len(), - LuaType::Any | LuaType::Unknown => Some(1), - LuaType::Nil => Some(0), - _ => Some(1), - }; - // 计算实际返回的表达式数量并记录多余的范围 let expr_list = return_stat.get_expr_list().collect::>(); - let mut total_return_count = 0; - let mut tail_return_nil = false; - let mut redundant_ranges = Vec::new(); - - for (index, expr) in expr_list.iter().enumerate() { - let expr_type = semantic_model - .infer_expr(expr.clone()) - .unwrap_or(LuaType::Unknown); - match expr_type { - LuaType::Variadic(variadic) => { - total_return_count += variadic.get_max_len()?; - } - LuaType::Nil => { - if index == expr_list.len() - 1 { - tail_return_nil = true; - } - total_return_count += 1; - } - _ => total_return_count += 1, - }; - - if max_expected_return_count.is_some() && total_return_count > max_expected_return_count? { - if tail_return_nil && total_return_count - 1 == max_expected_return_count? { - continue; - } - redundant_ranges.push(expr.get_range()); - } + let tail_expr_type = expr_list + .last() + .and_then(|expr| semantic_model.infer_expr(expr.clone()).ok()); + if let Some(LuaType::Variadic(variadic)) = &tail_expr_type + && variadic.get_max_len().is_none() + { + // An unbounded variadic tail such as `string...` is a possible return row, + // not a proven count, so count diagnostics would be guesswork. + return None; } + // Count the adjusted return row so a no-return tail call contributes zero + // slots instead of the scalar `nil` used in single-value expression contexts. + let return_infos = semantic_model.infer_expr_list_types(&expr_list, None); + let total_return_count = return_infos.len(); + // 检查缺失的返回值 if total_return_count < min_expected_return_count { context.add_diagnostic( @@ -216,18 +177,32 @@ fn check_return_count( } // 检查多余的返回值 - for range in redundant_ranges { - context.add_diagnostic( - DiagnosticCode::RedundantReturnValue, - range, - t!( - "Annotations specify that at most %{max} return value(s) are required, found %{rmax} returned here instead.", - max = max_expected_return_count?, - rmax = total_return_count - ) - .to_string(), - None, - ); + if let Some(max_expected_return_count) = max_expected_return_count + && total_return_count > max_expected_return_count + { + let mut last_redundant_range = None; + for (index, (_, range)) in return_infos.iter().enumerate() { + if index < max_expected_return_count { + continue; + } + + if last_redundant_range == Some(*range) { + continue; + } + + context.add_diagnostic( + DiagnosticCode::RedundantReturnValue, + *range, + t!( + "Annotations specify that at most %{max} return value(s) are required, found %{rmax} returned here instead.", + max = max_expected_return_count, + rmax = total_return_count + ) + .to_string(), + None, + ); + last_redundant_range = Some(*range); + } } Some(()) diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs index d024a6863..0508af075 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/incomplete_signature_doc.rs @@ -6,7 +6,7 @@ use emmylua_parser::{ use crate::{ DiagnosticCode, LuaSemanticDeclId, LuaSignature, LuaSignatureId, LuaType, SemanticDeclLevel, - SemanticModel, SignatureReturnStatus, + SemanticModel, SignatureReturnStatus, db_index::return_row::return_row_max_len, }; use super::{Checker, DiagnosticContext, get_closure_expr_comment, get_return_stats}; @@ -199,37 +199,43 @@ fn check_returns( function_name: &str, ) -> Option<()> { for return_stat in get_return_stats(closure_expr) { - let mut return_stat_len: usize = 0; - - for (i, expr) in return_stat.get_expr_list().enumerate() { - let Some(infer_type) = semantic_model.infer_expr(expr.clone()).ok() else { - continue; - }; - - let expr_return_count = match infer_type { - LuaType::Variadic(variadic) => variadic.get_min_len()?, - _ => 1, - }; + let expr_list = return_stat.get_expr_list().collect::>(); + let tail_expr_type = expr_list + .last() + .and_then(|expr| semantic_model.infer_expr(expr.clone()).ok()); + if let Some(LuaType::Variadic(variadic)) = &tail_expr_type + && variadic.get_max_len().is_none() + { + continue; + } - return_stat_len += expr_return_count; + let return_infos = semantic_model.infer_expr_list_types(&expr_list, None); + let mut last_missing_range = None; + for (i, (_, range)) in return_infos.iter().enumerate() { + let return_stat_len = i + 1; if let Some(doc_return_len) = doc_return_len && return_stat_len > doc_return_len { + if last_missing_range == Some(*range) { + continue; + } + let message = if is_global { t!( "Missing @return annotation at index `%{index}` in global function `%{function_name}`.", - index = i + 1, + index = return_stat_len, function_name = function_name ) } else { t!( "Incomplete signature. Missing @return annotation at index `%{index}`.", - index = i + 1 + index = return_stat_len ) }; - context.add_diagnostic(code, expr.get_range(), message.to_string(), None); + context.add_diagnostic(code, *range, message.to_string(), None); + last_missing_range = Some(*range); } } } @@ -241,12 +247,6 @@ fn get_doc_return_max_len(signature: &LuaSignature) -> Option> { if signature.resolve_return != SignatureReturnStatus::DocResolve { return None; } - let return_type = signature.get_return_type(); - - Some(match return_type { - LuaType::Variadic(variadic) => variadic.get_max_len(), - LuaType::Any | LuaType::Unknown => Some(1), - LuaType::Nil => Some(0), - _ => Some(1), - }) + + Some(return_row_max_len(&signature.get_return_row())) } diff --git a/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs b/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs index 3b37026e2..5dfd29f40 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/checker/return_type_mismatch.rs @@ -60,6 +60,9 @@ fn check_return_stat( ) -> Option<()> { let return_exprs = return_stat.get_expr_list().collect::>(); let (return_expr_types, return_expr_ranges) = { + // This checker compares slot types, not row shape. A `string...` tail + // call is not the same row as `string, string`, but each proven slot is + // still `string`; any uncertain count belongs to return-count checks. let infos = semantic_model.infer_expr_list_types(&return_exprs, None); let mut return_expr_types = infos.iter().map(|(typ, _)| typ.clone()).collect::>(); // 解决 setmetatable 的返回值类型问题 diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs index b12e8daa8..41491753a 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs @@ -42,6 +42,43 @@ mod tests { )); } + #[test] + fn test_empty_return_call_local_decl_checks_nil_assignment() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@return + local function none() + end + + ---@type string + local value = none() + "# + )); + } + + #[test] + fn test_empty_return_call_assignment_checks_nil_assignment() { + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@return + local function none() + end + + ---@type nil + local first + + ---@type string + local second + + first, second = none() + "# + )); + } + // #[test] // fn test_3() { // let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs index 49442418f..1d5242b72 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/check_return_count_test.rs @@ -32,6 +32,26 @@ mod tests { )); } + #[test] + fn test_callback_function_type_without_return_list_rejects_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::RedundantReturnValue, + r#" + ---@param cb fun() + local function call(cb) + end + + call(function() + while true do + return 1 + end + end) + "# + )); + } + #[test] fn test_2() { let mut ws = VirtualWorkspace::new(); @@ -75,6 +95,25 @@ mod tests { )); } + #[test] + fn test_tail_empty_return_call_counts_as_missing_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingReturnValue, + r#" + ---@return + local function none() + end + + ---@return string + local function value() + return none() + end + "# + )); + } + #[test] fn test_missing_return_value_variadic() { let mut ws = VirtualWorkspace::new(); @@ -153,6 +192,82 @@ mod tests { )); } + #[test] + fn test_tail_call_trailing_nil_counts_as_redundant_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::RedundantReturnValue, + r#" + ---@return number, nil + local function source() + return 1, nil + end + + ---@return number + local function target() + return source() + end + "# + )); + } + + #[test] + fn test_tail_call_nil_return_counts_as_redundant_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::RedundantReturnValue, + r#" + ---@return nil + local function source() + return nil + end + + ---@param cb fun() + local function call(cb) + end + + call(function() + return source() + end) + "# + )); + } + + #[test] + fn test_nil_return_annotation_requires_nil_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingReturnValue, + r#" + ---@return nil + local function target() + return + end + "# + )); + } + + #[test] + fn test_zero_return_callback_rejects_nil_return_value() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::RedundantReturnValue, + r#" + ---@param cb fun() + local function call(cb) + end + + call(function() + return nil + end) + "# + )); + } + #[test] fn test_not_return_anno() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs index a4d5398b5..440a9c9f4 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/generic_constraint_mismatch_test.rs @@ -58,6 +58,39 @@ mod test { )); } + #[test] + fn test_empty_return_tail_call_does_not_check_missing_generic_arg() { + let mut ws = VirtualWorkspace::new(); + assert!(ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@return + local function none() + end + + ---@generic T: string + ---@param value T + local function takes(value) + end + + takes(none()) + "# + )); + + let mut ws = VirtualWorkspace::new(); + assert!(!ws.has_no_diagnostic( + DiagnosticCode::GenericConstraintMismatch, + r#" + ---@generic T: string + ---@param value T + local function takes(value) + end + + takes(nil) + "# + )); + } + #[test] fn test_4() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs index 19e7d3750..6f393dd67 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/incomplete_signature_doc_test.rs @@ -122,6 +122,42 @@ mod tests { )); } + #[test] + fn test_nil_return_annotation_counts_as_documented_return() { + let mut ws = VirtualWorkspace::new(); + ws.enable_full_diagnostic(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::IncompleteSignatureDoc, + r#" + ---@return nil + local function f() + return nil + end + "# + )); + } + + #[test] + fn test_empty_tail_call_return_counts_as_documented_zero_returns() { + let mut ws = VirtualWorkspace::new(); + ws.enable_full_diagnostic(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::IncompleteSignatureDoc, + r#" + ---@return + local function none() + end + + ---@return + local function f() + return none() + end + "# + )); + } + #[test] fn test_variadic_return_overload_does_not_trigger_incomplete_signature_doc() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs index 501fd975d..9b3405262 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/missing_parameter_test.rs @@ -148,6 +148,47 @@ mod test { )); } + #[test] + fn test_empty_return_tail_call_counts_as_missing_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(!ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@return + local function none() + end + + ---@param value string + local function takes(value) + end + + takes(none()) + "# + )); + } + + #[test] + fn test_unbounded_tail_call_does_not_count_as_missing_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::MissingParameter, + r#" + ---@return string... + local function many() + end + + ---@param first string + ---@param second string + local function takes(first, second) + end + + takes(many()) + "# + )); + } + #[test] fn test_alias() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs index b1ae08140..4036c4cdc 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/redundant_parameter_test.rs @@ -106,6 +106,44 @@ mod test { )); } + #[test] + fn test_empty_return_tail_call_does_not_count_as_redundant_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::RedundantParameter, + r#" + ---@return + local function none() + end + + local function takes_none() + end + + takes_none(none()) + "# + )); + } + + #[test] + fn test_unbounded_tail_call_does_not_count_as_redundant_parameter() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::RedundantParameter, + r#" + ---@return string... + local function many() + end + + local function takes_none() + end + + takes_none(many()) + "# + )); + } + #[test] fn test_function_param() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs index 26b8d527a..8a4c94f0f 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs @@ -298,6 +298,25 @@ mod tests { )); } + #[test] + fn test_variadic_tail_call_return_slots_do_not_type_mismatch() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.has_no_diagnostic( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@return string... + local function many() + end + + ---@return string, string + local function test() + return many() + end + "# + )); + } + #[test] fn test_issue_146() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs index 1ad1c18ef..39120513b 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/call_constraint.rs @@ -1,4 +1,4 @@ -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; use emmylua_parser::{LuaAstNode, LuaAstToken, LuaCallExpr, LuaExpr, LuaIndexExpr}; use hashbrown::HashSet; @@ -6,7 +6,7 @@ use rowan::TextRange; use crate::{ DbIndex, DocTypeInferContext, GenericTpl, GenericTplId, LuaFunctionType, LuaSemanticDeclId, - LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, VariadicType, + LuaType, LuaTypeNode, SemanticDeclLevel, SemanticModel, TypeOps, TypeSubstitutor, infer_doc_type, }; @@ -356,15 +356,16 @@ fn get_arg_infos( call_expr: &LuaCallExpr, ) -> Option> { let arg_exprs = call_expr.get_args_list()?.get_args().collect::>(); - let arg_infos = infer_expr_list_types(semantic_model, &arg_exprs) + let arg_infos = semantic_model + .infer_expr_list_types(&arg_exprs, None) .into_iter() - .map(|(raw_type, expr)| { + .map(|(raw_type, range)| { let check_type = get_constraint_type(semantic_model, &raw_type, 0) .unwrap_or_else(|| raw_type.clone()); CallConstraintArg { raw_type, check_type, - range: expr.get_range(), + range, } }) .collect(); @@ -427,33 +428,6 @@ fn get_constraint_type( } } -// 将多个表达式推导为具体类型列表 -fn infer_expr_list_types( - semantic_model: &SemanticModel, - exprs: &[LuaExpr], -) -> Vec<(LuaType, LuaExpr)> { - let mut value_types = Vec::new(); - for expr in exprs.iter() { - let expr_type = semantic_model - .infer_expr(expr.clone()) - .unwrap_or(LuaType::Unknown); - match expr_type { - LuaType::Variadic(variadic) => match variadic.deref() { - VariadicType::Base(base) => { - value_types.push((base.clone(), expr.clone())); - } - VariadicType::Multi(vecs) => { - for typ in vecs { - value_types.push((typ.clone(), expr.clone())); - } - } - }, - _ => value_types.push((expr_type.clone(), expr.clone())), - } - } - value_types -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs index 8ac1a2644..ecb720b8c 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/complete_generic_args.rs @@ -306,7 +306,15 @@ fn complete_doc_function( (name.clone(), completed.map(|completed| completed.ty)) }) .collect(); - let ret = complete_type_generic_args_in_type_inner(db, func.get_ret(), visiting); + let returns = func + .get_return_row() + .iter() + .map(|ty| { + let completed = complete_type_generic_args_in_type_inner(db, ty, visiting); + cycled |= completed.cycled; + completed.ty + }) + .collect(); CompletedType::new( LuaType::DocFunction( LuaFunctionType::new( @@ -314,11 +322,11 @@ fn complete_doc_function( func.is_colon_define(), func.is_variadic(), params, - ret.ty, + returns, ) .into(), ), - cycled || ret.cycled, + cycled, ) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs index cd739e582..53c36c4f3 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_conditional_generic.rs @@ -335,15 +335,10 @@ fn collect_infer_assignments( } } - let pattern_ret = pattern_func.get_ret(); - if contains_conditional_infer(pattern_ret) { - collect_infer_assignments( - db, - source_func.get_ret(), - pattern_ret, - assignments, - variance, - ) + let pattern_ret = pattern_func.get_return_type(); + if contains_conditional_infer(&pattern_ret) { + let source_ret = source_func.get_return_type(); + collect_infer_assignments(db, &source_ret, &pattern_ret, assignments, variance) } else { true } @@ -672,7 +667,11 @@ fn actualize_unresolved_templates(ty: LuaType) -> LuaType { (name.clone(), ty.clone().map(actualize_unresolved_templates)) }) .collect(), - actualize_unresolved_templates(func.get_ret().clone()), + func.get_return_row() + .iter() + .cloned() + .map(actualize_unresolved_templates) + .collect(), ) .into(), ), diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index 1c8e39996..2e0ee84a5 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -7,7 +7,10 @@ use crate::semantic::infer::infer_expr_list_types; use crate::{ DocTypeInferContext, FileId, GenericParam, GenericTplId, LuaFunctionType, LuaGenericType, LuaTypeNode, - db_index::{DbIndex, LuaType}, + db_index::{ + DbIndex, LuaType, + return_row::{merge_return_rows, row_to_multi_return_type}, + }, infer_doc_type, semantic::{ LuaInferCache, @@ -124,14 +127,14 @@ fn infer_callable_return_from_arg_types( context: &mut TplContext, callable_type: &LuaType, call_arg_types: &[LuaType], -) -> Result, InferFailReason> { +) -> Result>, InferFailReason> { let mut overload_groups = Vec::new(); collect_callable_overload_groups(context.db, callable_type, &mut overload_groups)?; if overload_groups.is_empty() { return Ok(None); } - let mut member_returns = Vec::new(); + let mut member_return_rows = Vec::new(); for overloads in &overload_groups { let instantiated_overloads = overloads .iter() @@ -159,12 +162,11 @@ fn infer_callable_return_from_arg_types( .iter() .any(|arg_type| arg_type.is_any() || arg_type.is_unknown()); if unresolved_arg_match { - member_returns.push(LuaType::from_vec( - overloads_to_resolve - .iter() - .map(|callable| callable.get_ret().clone()) - .collect(), - )); + let rows = overloads_to_resolve + .iter() + .map(|callable| callable.get_return_row()) + .collect::>(); + member_return_rows.push(merge_return_rows(&rows)); continue; } @@ -176,13 +178,18 @@ fn infer_callable_return_from_arg_types( None, &[], ); - member_returns.push(callable?.get_ret().clone()); + member_return_rows.push(callable?.get_return_row().to_vec()); } - if member_returns.is_empty() { + if member_return_rows.is_empty() { return Ok(None); } - Ok(Some(LuaType::from_vec(member_returns))) + Ok(Some(merge_return_rows( + &member_return_rows + .iter() + .map(|row| row.as_slice()) + .collect::>(), + ))) } fn uses_erased_function_param(callable: &LuaFunctionType, call_arg_types: &[LuaType]) -> bool { @@ -201,7 +208,7 @@ pub fn infer_callable_return_from_remaining_args( context: &mut TplContext, callable_type: &LuaType, arg_exprs: &[LuaExpr], -) -> Result, InferFailReason> { +) -> Result>, InferFailReason> { if arg_exprs.is_empty() { return Ok(None); } @@ -275,13 +282,15 @@ fn instantiate_callable_from_arg_types( }; let unresolved_return_tpls = { let mut tpl_ids = HashSet::new(); - instantiated.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty - && callable_tpls.contains(&generic_tpl.get_tpl_id()) - { - tpl_ids.insert(generic_tpl.get_tpl_id()); - } - }); + for ret in instantiated.get_return_row() { + ret.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + && callable_tpls.contains(&generic_tpl.get_tpl_id()) + { + tpl_ids.insert(generic_tpl.get_tpl_id()); + } + }); + } if tpl_ids.is_empty() { return Some(instantiated); } @@ -346,14 +355,16 @@ fn collect_callback_return_tpls( let Ok(Some(param_func)) = as_doc_function_type(db, param_type) else { continue; }; - param_func.get_ret().visit_type(&mut |ty| { - if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { - let tpl_id = generic_tpl.get_tpl_id(); - if unresolved_return_tpls.contains(&tpl_id) { - callback_return_tpls.insert(tpl_id); + for ret in param_func.get_return_row() { + ret.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { + let tpl_id = generic_tpl.get_tpl_id(); + if unresolved_return_tpls.contains(&tpl_id) { + callback_return_tpls.insert(tpl_id); + } } - } - }); + }); + } } callback_return_tpls @@ -513,11 +524,12 @@ fn infer_generic_types_from_call( }; if let Some(return_pattern) = - as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_ret().clone()) + as_doc_function_type(context.db, func_param_type)?.map(|func| func.get_return_type()) { - if let Some(inferred_return_type) = + if let Some(inferred_return_row) = infer_callable_return_from_remaining_args(context, &arg_type, &arg_exprs[i + 1..])? { + let inferred_return_type = row_to_multi_return_type(inferred_return_row); return_type_pattern_match_target_type( context, &return_pattern, @@ -530,11 +542,19 @@ fn infer_generic_types_from_call( match (func_param_type, &arg_type) { (LuaType::Variadic(variadic), _) => { - let mut arg_types = vec![]; - for arg_expr in &arg_exprs[i..] { - let arg_type = infer_expr(db, context.cache, arg_expr.clone())?; - arg_types.push(arg_type); - } + let arg_types = infer_expr_list_types( + context.db, + context.cache, + &arg_exprs[i..], + None, + |db, cache, expr| match infer_expr(db, cache, expr) { + Err(InferFailReason::FieldNotFound) => Ok(LuaType::Nil), + result => result, + }, + )? + .into_iter() + .map(|(ty, _)| ty) + .collect::>(); variadic_tpl_pattern_match(context, variadic, &arg_types)?; break; } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs index ae9d07e77..02f938386 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/mod.rs @@ -243,7 +243,6 @@ fn instantiate_doc_function_with_context( doc_func: &LuaFunctionType, ) -> LuaType { let tpl_func_params = doc_func.get_params(); - let tpl_ret = doc_func.get_ret(); let async_state = doc_func.get_async_state(); let colon_define = doc_func.is_colon_define(); @@ -342,19 +341,31 @@ fn instantiate_doc_function_with_context( } } - let mut inst_ret_type = instantiate_type_generic_with_context(context, tpl_ret); - // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple - if let LuaType::Variadic(_) = &&tpl_ret - && let LuaType::Tuple(tuple) = &inst_ret_type - { - match tuple.len() { - 0 => {} - 1 => inst_ret_type = tuple.get_types()[0].clone(), - _ => { - inst_ret_type = - LuaType::Variadic(VariadicType::Multi(tuple.get_types().to_vec()).into()) + let mut inst_ret = Vec::new(); + for origin_ret_type in doc_func.get_return_row() { + let mut inst_ret_type = instantiate_type_generic_with_context(context, origin_ret_type); + // 对于可变返回值, 如果实例化是 tuple, 那么我们将展开 tuple + if let LuaType::Variadic(_) = origin_ret_type + && let LuaType::Tuple(tuple) = &inst_ret_type + { + match tuple.len() { + 0 => continue, + 1 => inst_ret_type = tuple.get_types()[0].clone(), + _ => { + inst_ret_type = + LuaType::Variadic(VariadicType::Multi(tuple.get_types().to_vec()).into()) + } } } + if let LuaType::Variadic(_) = origin_ret_type + && let LuaType::Variadic(variadic) = &inst_ret_type + && let VariadicType::Multi(types) = variadic.deref() + { + // A variadic return can instantiate to an empty row. + inst_ret.extend(types.iter().cloned()); + continue; + } + inst_ret.push(inst_ret_type); } // 重新判断是否是可变参数 let is_variadic = new_params @@ -369,14 +380,7 @@ fn instantiate_doc_function_with_context( }); LuaType::DocFunction( - LuaFunctionType::new( - async_state, - colon_define, - is_variadic, - new_params, - inst_ret_type, - ) - .into(), + LuaFunctionType::new(async_state, colon_define, is_variadic, new_params, inst_ret).into(), ) } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs index 77900f2d4..2e60cdcb5 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/tpl_pattern/mod.rs @@ -620,9 +620,17 @@ fn func_tpl_pattern_match_doc_func( param_type_list_pattern_match_type_list(context, &tpl_func_params, &target_func_params)?; - let tpl_return = tpl_func.get_ret(); - let target_return = target_func.get_ret(); - return_type_pattern_match_target_type(context, tpl_return, target_return)?; + let target_returns = target_func.get_return_row(); + for (i, tpl_return) in tpl_func.get_return_row().iter().enumerate() { + if let LuaType::Variadic(variadic) = tpl_return { + let target_rest = target_returns.get(i..).unwrap_or(&[]); + variadic_tpl_pattern_match(context, variadic, target_rest)?; + break; + } + + let target_return = target_returns.get(i).unwrap_or(&LuaType::Nil); + return_type_pattern_match_target_type(context, tpl_return, target_return)?; + } Ok(()) } @@ -804,13 +812,15 @@ pub fn variadic_tpl_pattern_match( tpl: &VariadicType, target_rest_types: &[LuaType], ) -> TplPatternMatchResult { + // Return rows preserve arity: `R...` matched against no values binds an + // empty row, not `nil`. match tpl { VariadicType::Base(base) => match base { LuaType::TplRef(tpl_ref) => { let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { // If the single argument is itself a multi-return (e.g. a function call @@ -820,7 +830,7 @@ pub fn variadic_tpl_pattern_match( LuaType::Variadic(variadic) => match variadic.deref() { VariadicType::Multi(types) => match types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, true); + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { context.substitutor.insert_type( @@ -863,7 +873,7 @@ pub fn variadic_tpl_pattern_match( let tpl_id = tpl_ref.get_tpl_id(); match target_rest_types.len() { 0 => { - context.substitutor.insert_type(tpl_id, LuaType::Nil, false); + context.substitutor.insert_multi_types(tpl_id, Vec::new()); } 1 => { context.substitutor.insert_type( @@ -1017,14 +1027,14 @@ fn try_handle_pairs_metamethod( .get_signature_index() .get(signature_id) .map(|s| s.get_return_type()), - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + LuaType::DocFunction(doc_func) => Some(doc_func.get_return_type()), _ => None, } .ok_or(InferFailReason::None)?; // 解析出迭代函数返回类型 let final_return_type = match meta_return { - LuaType::DocFunction(doc_func) => Some(doc_func.get_ret().clone()), + LuaType::DocFunction(doc_func) => Some(doc_func.get_return_type()), LuaType::Signature(signature_id) => context .db .get_signature_index() diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs index cd1360cf2..4bc0a15c7 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_call/mod.rs @@ -13,6 +13,7 @@ use crate::{ AsyncState, CacheEntry, DbIndex, InFiled, LuaFunctionType, LuaGenericType, LuaInstanceType, LuaIntersectionType, LuaOperatorMetaMethod, LuaOperatorOwner, LuaSignature, LuaSignatureId, LuaType, LuaTypeDeclId, LuaUnionType, TypeVisitTrait, VariadicType, + db_index::return_row::{merge_return_rows, return_type_to_row, row_to_multi_return_type}, }; use crate::{ InferGuardRef, @@ -33,6 +34,71 @@ mod infer_setmetatable; pub type InferCallFuncResult = Result, InferFailReason>; +pub(crate) fn infer_call_expr_list_tail( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: LuaCallExpr, + scalar_type: LuaType, +) -> InferResult { + // Keep the normal scalar path first so inline `---@as` bindings and casts + // still win. Only a scalar nil can be the placeholder for an empty row. + if !matches!(scalar_type, LuaType::Nil) { + return Ok(scalar_type); + } + + let tail_type = infer_call_expr_inner(db, cache, call_expr, true)?; + if let LuaType::Variadic(variadic) = &tail_type + && let VariadicType::Multi(types) = variadic.as_ref() + && types.is_empty() + { + Ok(tail_type) + } else { + Ok(scalar_type) + } +} + +fn infer_call_expr_inner( + db: &DbIndex, + cache: &mut LuaInferCache, + call_expr: LuaCallExpr, + preserve_empty_row: bool, +) -> InferResult { + if call_expr.is_require() { + return infer_require_call(db, cache, call_expr); + } else if call_expr.is_setmetatable() { + return infer_setmetatable_call(db, cache, call_expr); + } + + check_can_infer(db, cache, &call_expr)?; + + let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?; + let prefix_type = infer_expr(db, cache, prefix_expr)?; + let func_ty = infer_call_expr_func( + db, + cache, + call_expr.clone(), + prefix_type, + &InferGuard::new(), + None, + )?; + let ret_type = if preserve_empty_row { + row_to_multi_return_type(func_ty.get_return_row().to_vec()) + } else { + func_ty.get_return_type() + }; + + if !cache.is_no_flow() + && let Some(tree) = db.get_flow_index().get_flow_tree(&cache.get_file_id()) + && let Some(flow_id) = tree.get_flow_id(call_expr.get_syntax_id()) + && let Some(flow_ret_type) = + get_type_at_call_expr_inline_cast(db, cache, tree, call_expr, flow_id, ret_type.clone()) + { + return Ok(flow_ret_type); + } + + Ok(ret_type) +} + pub fn infer_call_expr_func( db: &DbIndex, cache: &mut LuaInferCache, @@ -112,7 +178,9 @@ pub fn infer_call_expr_func( false, true, vec![("...".to_string(), Some(LuaType::Unknown))], - LuaType::Variadic(VariadicType::Base(LuaType::Unknown).into()), + vec![LuaType::Variadic( + VariadicType::Base(LuaType::Unknown).into(), + )], ))), LuaType::Intersection(intersection) => infer_intersection( db, @@ -127,39 +195,44 @@ pub fn infer_call_expr_func( false, true, vec![], - LuaType::Any, + vec![LuaType::Any], ))), LuaType::Union(union) => infer_union(db, cache, union, call_expr.clone(), args_count), _ => Err(InferFailReason::None), }; - let result = if let Ok(func_ty) = result { - let func_ty = match func_ty.get_ret() { - LuaType::Call(_) => { - match instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) { - Ok(func_ty) => Arc::new(func_ty), - Err(_) => func_ty, + let result = match result { + Ok(func_ty) if func_ty.get_return_row().is_empty() => Ok(func_ty), + Ok(func_ty) => { + let func_ty = match func_ty.get_return_type() { + LuaType::Call(_) => { + instantiate_func_generic(db, cache, func_ty.as_ref(), call_expr.clone()) + .map(Arc::new) + .unwrap_or(func_ty) + } + _ => func_ty, + }; + + if func_ty.get_return_row().is_empty() { + Ok(func_ty) + } else { + let func_ret = func_ty.get_return_type(); + match func_ret { + LuaType::TypeGuard(_) => Ok(func_ty), + _ => unwrapp_return_type(db, cache, func_ret, call_expr).map(|new_ret| { + LuaFunctionType::new( + func_ty.get_async_state(), + func_ty.is_colon_define(), + func_ty.is_variadic(), + func_ty.get_params().to_vec(), + return_type_to_row(new_ret), + ) + .into() + }), } } - _ => func_ty, - }; - - let func_ret = func_ty.get_ret(); - match func_ret { - LuaType::TypeGuard(_) => Ok(func_ty), - _ => unwrapp_return_type(db, cache, func_ret.clone(), call_expr).map(|new_ret| { - LuaFunctionType::new( - func_ty.get_async_state(), - func_ty.is_colon_define(), - func_ty.is_variadic(), - func_ty.get_params().to_vec(), - new_ret, - ) - .into() - }), } - } else { - result + Err(err) => Err(err), }; match &result { @@ -532,8 +605,7 @@ fn infer_union( call_expr: LuaCallExpr, args_count: Option, ) -> InferCallFuncResult { - let mut returns = Vec::new(); - let mut first_func = None; + let mut matching_funcs = Vec::new(); let mut fallback_overloads = Vec::new(); let mut need_resolve = None; @@ -564,10 +636,7 @@ fn infer_union( args_count, ) { Ok(func) => { - returns.push(func.get_ret().clone()); - if first_func.is_none() { - first_func = Some(func); - } + matching_funcs.push(func); } Err(InferFailReason::RecursiveInfer) => { return Err(InferFailReason::RecursiveInfer); @@ -582,7 +651,7 @@ fn infer_union( } } - let Some(first_func) = first_func else { + if matching_funcs.is_empty() { if !fallback_overloads.is_empty() { let contains_tpl = fallback_overloads.iter().any(|func| func.contain_tpl()); let fallback_overloads = filter_callable_overloads_by_call_args( @@ -606,12 +675,18 @@ fn infer_union( return Err(need_resolve.unwrap_or(InferFailReason::None)); }; + let first_func = &matching_funcs[0]; + let return_rows = matching_funcs + .iter() + .map(|func| func.get_return_row()) + .collect::>(); + Ok(Arc::new(LuaFunctionType::new( first_func.get_async_state(), first_func.is_colon_define(), first_func.is_variadic(), first_func.get_params().to_vec(), - LuaType::from_vec(returns), + merge_return_rows(&return_rows), ))) } @@ -706,7 +781,8 @@ fn unwrapp_return_type( return Ok(ty.clone()); } - return Ok(ty.get_result_slot_type(0).unwrap_or(LuaType::Nil)); + let return_type = super::adjusted_result_slot_type(&ty, 0); + return unwrapp_return_type(db, cache, return_type, call_expr); } LuaType::SelfInfer => { if let Some(self_type) = infer_self_type(db, cache, &call_expr) { @@ -760,37 +836,7 @@ pub fn infer_call_expr( cache: &mut LuaInferCache, call_expr: LuaCallExpr, ) -> InferResult { - if call_expr.is_require() { - return infer_require_call(db, cache, call_expr); - } else if call_expr.is_setmetatable() { - return infer_setmetatable_call(db, cache, call_expr); - } - - check_can_infer(db, cache, &call_expr)?; - - let prefix_expr = call_expr.get_prefix_expr().ok_or(InferFailReason::None)?; - let prefix_type = infer_expr(db, cache, prefix_expr)?; - let ret_type = infer_call_expr_func( - db, - cache, - call_expr.clone(), - prefix_type, - &InferGuard::new(), - None, - )? - .get_ret() - .clone(); - - if !cache.is_no_flow() - && let Some(tree) = db.get_flow_index().get_flow_tree(&cache.get_file_id()) - && let Some(flow_id) = tree.get_flow_id(call_expr.get_syntax_id()) - && let Some(flow_ret_type) = - get_type_at_call_expr_inline_cast(db, cache, tree, call_expr, flow_id, ret_type.clone()) - { - return Ok(flow_ret_type); - } - - Ok(ret_type) + infer_call_expr_inner(db, cache, call_expr, false) } fn check_can_infer( @@ -903,6 +949,27 @@ mod tests { assert_eq!(ws.expr_ty("payload"), ws.ty("string")); } + #[test] + fn test_non_last_multi_return_unwraps_first_slot() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@return table, string + local function make() + end + + first, second = make(), 1 + "#, + ); + + let first = ws.expr_ty("first"); + assert!( + matches!(first, LuaType::TableConst(_)), + "expected first return slot to be materialized as a fresh table, got {:?}", + first + ); + } + #[test] fn test_union_call_ignores_unresolved_alias_member() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs index 288617c82..3c19fc0fb 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/infer_doc_type.rs @@ -482,21 +482,13 @@ fn infer_func_type(ctx: DocTypeInferContext<'_>, func: &LuaDocFuncType) -> LuaTy // since we don't have the same context as the analyzer. This is a simplification. let is_colon = false; - let return_type = if return_types.len() == 1 { - return_types[0].clone() - } else if return_types.len() > 1 { - LuaType::Variadic(VariadicType::Multi(return_types).into()) - } else { - LuaType::Nil - }; - LuaType::DocFunction( LuaFunctionType::new( async_state, is_colon, is_variadic, params_result, - return_type, + return_types, ) .into(), ) diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index eed7743ff..1a1659ef5 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -16,8 +16,8 @@ use emmylua_parser::{ LuaSyntaxId, LuaTableExpr, LuaVarExpr, NumberResult, }; use infer_binary::infer_binary_expr; -use infer_call::infer_call_expr; pub use infer_call::infer_call_expr_func; +use infer_call::{infer_call_expr, infer_call_expr_list_tail}; pub use infer_doc_type::{DocTypeInferContext, infer_doc_type}; pub use infer_fail_reason::InferFailReason; pub use infer_index::infer_index_expr; @@ -259,11 +259,47 @@ fn get_custom_type_operator( } } -pub fn infer_expr_list_types( +/// Infers a Lua expression list using assignment/call result adjustment. +/// +/// Only the final expression may contribute multiple values. Earlier values +/// are adjusted to one slot, and `var_count` caps how many final-expression +/// slots are pulled. +pub(crate) fn infer_expr_list_types( db: &DbIndex, cache: &mut LuaInferCache, exprs: &[LuaExpr], var_count: Option, + infer: F, +) -> Result, InferFailReason> +where + F: FnMut(&DbIndex, &mut LuaInferCache, LuaExpr) -> InferResult, +{ + infer_expr_list_types_with(db, cache, exprs, var_count, false, infer) +} + +/// Infers a Lua return expression list. +/// +/// Return rows use normal expression-list adjustment, but a final unbounded +/// variadic expression remains a variadic tail instead of being collapsed to +/// one scalar slot. +pub(crate) fn infer_return_expr_list_types( + db: &DbIndex, + cache: &mut LuaInferCache, + exprs: &[LuaExpr], + infer: F, +) -> Result, InferFailReason> +where + F: FnMut(&DbIndex, &mut LuaInferCache, LuaExpr) -> InferResult, +{ + infer_expr_list_types_with(db, cache, exprs, None, true, infer) +} + +fn infer_expr_list_types_with( + db: &DbIndex, + cache: &mut LuaInferCache, + exprs: &[LuaExpr], + var_count: Option, + preserve_variadic_tail: bool, mut infer: F, ) -> Result, InferFailReason> where @@ -277,45 +313,83 @@ where break; } + let expr_range = expr.get_range(); let expr_type = infer(db, cache, expr.clone())?; - if let Some(var_count) = var_count - && expr_type.contain_multi_return() + let expr_type = if idx + 1 == exprs.len() + && let LuaExpr::CallExpr(call_expr) = expr { - if idx < var_count { - for i in idx..var_count { - if let Some(typ) = expr_type.get_result_slot_type(i - idx) { - value_types.push((typ, expr.get_range())); - } else { - break; + infer_call_expr_list_tail(db, cache, call_expr.clone(), expr_type)? + } else { + expr_type + }; + if idx + 1 == exprs.len() && expr_type.contain_multi_return() { + if let Some(var_count) = var_count { + let value_count = value_types.len(); + if value_count < var_count { + for i in value_count..var_count { + if let Some(typ) = expr_type.get_result_slot_type(i - value_count) { + value_types.push((typ, expr_range)); + } else { + break; + } } } - } - - break; - } - - match expr_type { - LuaType::Variadic(variadic) => { - match variadic.deref() { - VariadicType::Base(base) => { - value_types.push((base.clone(), expr.get_range())); + } else { + match expr_type { + LuaType::Variadic(variadic) + if preserve_variadic_tail + && matches!(variadic.as_ref(), VariadicType::Base(_)) => + { + value_types.push((LuaType::Variadic(variadic), expr_range)); } - VariadicType::Multi(vecs) => { - for typ in vecs { - value_types.push((typ.clone(), expr.get_range())); + LuaType::Variadic(variadic) => match variadic.deref() { + VariadicType::Base(base) => { + value_types.push((base.clone(), expr_range)); } - } + VariadicType::Multi(vecs) => { + for typ in vecs { + value_types.push((typ.clone(), expr_range)); + } + } + }, + _ => value_types.push((expr_type, expr_range)), } - - break; } - _ => value_types.push((expr_type, expr.get_range())), + + break; } + + let expr_type = adjusted_result_slot_type(&expr_type, 0); + value_types.push((expr_type, expr_range)); } Ok(value_types) } +/// Returns the RHS expression index and result slot for assignment target +/// `target_idx`. +/// +/// In `a, b, c = x, y()`, target `c` is supplied by RHS expression `y()` at +/// result slot `1`. +pub(crate) fn assignment_rhs_source( + expr_count: usize, + target_idx: usize, +) -> Option<(usize, usize)> { + let last_expr_idx = expr_count.checked_sub(1)?; + Some(( + target_idx.min(last_expr_idx), + target_idx.saturating_sub(last_expr_idx), + )) +} + +/// Returns a concrete type for `slot` after Lua value adjustment. +/// +/// Multi-return values may have fewer slots than a caller asks for. In fixed +/// result positions, those exhausted slots are adjusted to `nil`. +pub(crate) fn adjusted_result_slot_type(expr_type: &LuaType, slot: usize) -> LuaType { + expr_type.get_result_slot_type(slot).unwrap_or(LuaType::Nil) +} + /// 推断值已经绑定的类型(不是推断值的类型). 例如从右值推断左值类型, 从调用参数推断函数参数类型参数类型 pub fn infer_bind_value_type( db: &DbIndex, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index e915915b1..02edb38e4 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -104,7 +104,7 @@ pub(super) fn get_type_at_call_expr_by_func( ) -> Result { match maybe_func { LuaType::DocFunction(f) => { - let return_type = f.get_ret(); + let return_type = f.get_return_type(); match return_type { LuaType::TypeGuard(_) => get_type_at_call_expr_by_type_guard( db, @@ -119,7 +119,7 @@ pub(super) fn get_type_at_call_expr_by_func( cache, var_ref_id, call_expr, - call, + &call, condition_flow, ), _ => Ok(ConditionFlowAction::Continue), @@ -223,14 +223,14 @@ fn get_type_guard_call_info( return Ok(None); }; - let mut return_type = func_type.get_ret().clone(); + let mut return_type = func_type.get_return_type(); if return_type.contain_tpl() { let Ok(inst_func) = cache.with_no_flow(|cache| { instantiate_func_generic(db, cache, func_type.as_ref(), call_expr) }) else { return Ok(None); }; - return_type = inst_func.get_ret().clone(); + return_type = inst_func.get_return_type(); } let LuaType::TypeGuard(guard) = return_type else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 78c464852..9e90631e5 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -5,6 +5,7 @@ use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk}; use crate::{ DbIndex, FlowId, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, LuaInferCache, LuaSignature, LuaType, TypeOps, + db_index::return_row::get_overload_row_slot, semantic::{ infer::{InferResult, VarRefId, narrow::narrow_down_type, try_infer_expr_no_flow}, instantiate_func_generic, @@ -484,20 +485,19 @@ fn collect_matching_correlated_types( } correlated_discriminant_call_expr_ids.insert(discriminant_call_expr_id); correlated_target_call_expr_ids.insert(target_ref.call_expr.get_syntax_id()); - correlated_candidate_types.extend(overload_rows.iter().map(|overload| { - LuaSignature::get_overload_row_slot(overload, target_ref.return_index) - })); + correlated_candidate_types.extend( + overload_rows + .iter() + .map(|overload| get_overload_row_slot(overload, target_ref.return_index)), + ); matching_target_types.extend(overload_rows.iter().filter_map(|overload| { let discriminant_type = - LuaSignature::get_overload_row_slot(overload, discriminant_ref.return_index); + get_overload_row_slot(overload, discriminant_ref.return_index); if !TypeOps::Intersect .apply(db, &discriminant_type, narrowed_discriminant_type) .is_never() { - return Some(LuaSignature::get_overload_row_slot( - overload, - target_ref.return_index, - )); + return Some(get_overload_row_slot(overload, target_ref.return_index)); } None @@ -521,7 +521,7 @@ fn collect_matching_correlated_types( unmatched_target_types.extend( return_rows .iter() - .map(|row| LuaSignature::get_overload_row_slot(row, target_ref.return_index)), + .map(|row| get_overload_row_slot(row, target_ref.return_index)), ); } @@ -566,9 +566,9 @@ fn instantiate_return_rows( call_expr: LuaCallExpr, signature: &LuaSignature, ) -> Vec> { - let mut instantiate_return_type = |return_type: LuaType| { - if !return_type.contain_tpl() { - return return_type; + let mut instantiate_return_row = |row: Vec| { + if !row.iter().any(|return_type| return_type.contain_tpl()) { + return row; } let func = LuaFunctionType::new( @@ -576,28 +576,23 @@ fn instantiate_return_rows( signature.is_colon_define, signature.is_vararg, signature.get_type_params(), - return_type.clone(), + row.clone(), ); match cache .with_no_flow(|cache| instantiate_func_generic(db, cache, &func, call_expr.clone())) { - Ok(instantiated) => instantiated.get_ret().clone(), - Err(_) => return_type, + Ok(instantiated) => instantiated.get_return_row().to_vec(), + Err(_) => row, } }; if signature.return_overloads.is_empty() { - let instantiated_return_type = instantiate_return_type(signature.get_return_type()); - return vec![LuaSignature::return_type_to_row(instantiated_return_type)]; + return vec![instantiate_return_row(signature.get_return_row())]; } - let mut rows = Vec::with_capacity(signature.return_overloads.len()); - for overload in &signature.return_overloads { - let type_refs = &overload.type_refs; - let overload_return_type = LuaSignature::row_to_return_type(type_refs.to_vec()); - let instantiated_return_type = instantiate_return_type(overload_return_type); - rows.push(LuaSignature::return_type_to_row(instantiated_return_type)); - } - - rows + signature + .return_overloads + .iter() + .map(|overload| instantiate_return_row(overload.type_refs.clone())) + .collect() } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 702b2ab73..885886a5f 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -9,6 +9,7 @@ use crate::{ CacheEntry, DbIndex, FlowId, FlowNode, FlowNodeKind, FlowTree, InferFailReason, LuaDeclId, LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, check_type_compact, semantic::{ + adjusted_result_slot_type, assignment_rhs_source, cache::{FlowAssignmentInfo, FlowMode, FlowVarCache}, infer::{ InferResult, VarRefId, @@ -654,10 +655,7 @@ impl<'a> FlowTypeEngine<'a> { Err(err) => return self.fail_query(&query, err), }; - let Some(init_type) = expr_type.get_result_slot_type(0) else { - return self.fail_query(&query, fail_reason); - }; - + let init_type = adjusted_result_slot_type(&expr_type, 0); Ok(self.finish_walk(walk, init_type)) } FlowExprReplay::Condition { @@ -988,9 +986,8 @@ impl<'a> FlowTypeEngine<'a> { .filter(|tc| tc.is_doc()) .map(|tc| tc.as_type().clone()); - if let Some(last_expr_idx) = assignment_info.exprs.len().checked_sub(1) { - let expr_idx = i.min(last_expr_idx); - let result_slot = i.saturating_sub(last_expr_idx); + if let Some((expr_idx, result_slot)) = assignment_rhs_source(assignment_info.exprs.len(), i) + { let expr = assignment_info.exprs[expr_idx].clone(); let replay_query = FlowReplayQuery::new( self.db, @@ -1023,9 +1020,7 @@ impl<'a> FlowTypeEngine<'a> { replay_query: FlowReplayQuery, ) -> Result { let expr_type = match replay_query.replay_type(self.db, self.cache) { - Ok(Some(expr_type)) => expr_type - .get_result_slot_type(result_slot) - .unwrap_or(LuaType::Nil), + Ok(Some(expr_type)) => adjusted_result_slot_type(&expr_type, result_slot), Ok(None) => LuaType::Unknown, Err(err) => { return self.finish_assignment_expr_error( diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs index fdf9f6087..6565043fa 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/var_ref_id.rs @@ -128,7 +128,7 @@ fn get_call_expr_var_ref_id( let maybe_func = try_infer_expr_no_flow(db, cache, prefix_expr.clone()).ok()??; let ret = match maybe_func { - LuaType::DocFunction(f) => f.get_ret().clone(), + LuaType::DocFunction(f) => f.get_return_type(), LuaType::Signature(signature_id) => db .get_signature_index() .get(&signature_id)? diff --git a/crates/emmylua_code_analysis/src/semantic/mod.rs b/crates/emmylua_code_analysis/src/semantic/mod.rs index b9cd692f9..437e8995d 100644 --- a/crates/emmylua_code_analysis/src/semantic/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/mod.rs @@ -57,7 +57,8 @@ pub use infer::InferFailReason; pub use infer::infer_call_expr_func; pub use infer::infer_param; pub(crate) use infer::try_infer_expr_for_index; -pub(crate) use infer::{infer_expr, try_infer_expr_no_flow}; +pub(crate) use infer::{adjusted_result_slot_type, assignment_rhs_source}; +pub(crate) use infer::{infer_expr, infer_return_expr_list_types, try_infer_expr_no_flow}; use overload_resolve::resolve_signature; pub use semantic_info::SemanticDeclLevel; pub use type_check::{TypeCheckFailReason, TypeCheckResult}; diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs index 7e2217f27..815da136c 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/resolve_signature_by_args.rs @@ -229,7 +229,7 @@ pub fn resolve_signature_by_args( return Err(InferFailReason::None); }; - if remaining_funcs.all(|func| func.get_ret() == first.get_ret()) { + if remaining_funcs.all(|func| func.get_return_row() == first.get_return_row()) { Ok(first) } else { Err(InferFailReason::None) diff --git a/crates/emmylua_doc_cli/src/markdown_generator/render.rs b/crates/emmylua_doc_cli/src/markdown_generator/render.rs index 44682ebf1..6bc8fc4bb 100644 --- a/crates/emmylua_doc_cli/src/markdown_generator/render.rs +++ b/crates/emmylua_doc_cli/src/markdown_generator/render.rs @@ -77,9 +77,9 @@ fn render_doc_function_type( }) .collect::>(); - let ret_type = lua_func.get_ret(); + let ret_type = lua_func.get_return_type(); - let ret_strs = render_typ(db, ret_type, RenderLevel::Documentation); + let ret_strs = render_typ(db, &ret_type, RenderLevel::Documentation); let mut result = String::new(); result.push_str("```lua\n"); diff --git a/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs b/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs index d824fb4a8..69881e4f2 100644 --- a/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs +++ b/crates/emmylua_ls/src/handlers/completion/add_completions/mod.rs @@ -130,13 +130,13 @@ pub fn get_detail( } _ => {} } - let ret_type = f.get_ret(); - let rets_detail = match ret_type { + let ret_type = f.get_return_type(); + let rets_detail = match &ret_type { LuaType::Nil => "".to_string(), _ => { let type_detail = humanize_type( builder.semantic_model.get_db(), - ret_type, + &ret_type, RenderLevel::Minimal, ); format!("-> {}", type_detail) diff --git a/crates/emmylua_ls/src/handlers/hover/function/mod.rs b/crates/emmylua_ls/src/handlers/hover/function/mod.rs index 2402e078e..7a3f3dcd7 100644 --- a/crates/emmylua_ls/src/handlers/hover/function/mod.rs +++ b/crates/emmylua_ls/src/handlers/hover/function/mod.rs @@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc, vec}; use emmylua_code_analysis::{ AsyncState, DbIndex, InferGuard, LuaDocReturnInfo, LuaDocReturnOverloadInfo, LuaFunctionType, LuaMember, LuaMemberOwner, LuaSemanticDeclId, LuaSignature, LuaType, RenderLevel, - TypeSubstitutor, VariadicType, humanize_type, infer_call_expr_func, instantiate_doc_function, + TypeSubstitutor, humanize_type, infer_call_expr_func, instantiate_doc_function, instantiate_func_generic, try_extract_signature_id_from_field, }; @@ -102,7 +102,7 @@ fn build_function_call_hover( signature.is_colon_define, signature.is_vararg, signature.get_type_params(), - signature.get_return_type(), + signature.get_return_row(), ); let instantiated_signature = instantiate_func_generic( db, @@ -283,7 +283,7 @@ fn process_function_type( signature.is_colon_define, signature.is_vararg, signature.get_type_params(), - signature.get_return_type(), + signature.get_return_row(), )); new_overloads.insert(0, fake_doc_function.clone()); let mut contents = Vec::with_capacity(new_overloads.len()); @@ -473,28 +473,17 @@ fn instantiate_call_return_overloads( .return_overloads .iter() .map(|row| { - let row_return_type = match row.type_refs.len() { - 0 => LuaType::Nil, - 1 => row.type_refs[0].clone(), - _ => LuaType::Variadic(VariadicType::Multi(row.type_refs.clone()).into()), - }; let row_function = LuaFunctionType::new( signature.async_state, signature.is_colon_define, signature.is_vararg, signature.get_type_params(), - row_return_type, + row.type_refs.clone(), ); let instantiated_row = instantiate_func_generic(db, &mut cache, &row_function, call_expr.clone()) .ok() - .map(|func| match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Multi(types) => types.clone(), - VariadicType::Base(_) => vec![LuaType::Variadic(variadic.clone())], - }, - typ => vec![typ.clone()], - }) + .map(|func| func.get_return_row().to_vec()) .unwrap_or_else(|| row.type_refs.clone()); LuaDocReturnOverloadInfo { @@ -506,31 +495,15 @@ fn instantiate_call_return_overloads( } fn convert_function_return_to_docs(func: &LuaFunctionType) -> Vec { - match func.get_ret() { - LuaType::Variadic(variadic) => match variadic.as_ref() { - VariadicType::Base(base) => vec![LuaDocReturnInfo { - name: None, - type_ref: base.clone(), - description: None, - attributes: None, - }], - VariadicType::Multi(types) => types - .iter() - .map(|ty| LuaDocReturnInfo { - name: None, - type_ref: ty.clone(), - description: None, - attributes: None, - }) - .collect(), - }, - _ => vec![LuaDocReturnInfo { + func.get_return_row() + .iter() + .map(|ty| LuaDocReturnInfo { name: None, - type_ref: func.get_ret().clone(), + type_ref: ty.clone(), description: None, attributes: None, - }], - } + }) + .collect() } fn format_function_type( diff --git a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs index 889f7f52d..fd47b1927 100644 --- a/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs +++ b/crates/emmylua_ls/src/handlers/inlay_hint/build_inlay_hint.rs @@ -542,7 +542,7 @@ fn set_meta_call_part( LuaStat::can_cast(parent.kind().into()) && !matches!(call_expr.get_prefix_expr()?, LuaExpr::CallExpr(_)) && semantic_model - .type_check(call_func.get_ret(), &target_type) + .type_check(&call_func.get_return_type(), &target_type) .is_ok() }; diff --git a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs index 4e5cef58f..e1bcdfa3d 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/build_signature_helper.rs @@ -136,11 +136,12 @@ fn build_doc_function_signature_help( current_idx = params.len() - 1; } + let return_type = func_type.get_return_type(); let label = build_function_label( builder, ¶m_infos, func_type.is_method(builder.semantic_model, None), - func_type.get_ret(), + &return_type, ); let documentation = description.map(|description| { diff --git a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs index b508f9748..5b09c587f 100644 --- a/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs +++ b/crates/emmylua_ls/src/handlers/signature_helper/signature_helper_builder.rs @@ -162,11 +162,12 @@ impl<'a> SignatureHelperBuilder<'a> { } _ => {} } + let return_type = func.get_return_type(); self.best_call_function_label = build_function_label( self, &self.params_info, func.is_method(self.semantic_model, None), - func.get_ret(), + &return_type, ); Some(()) diff --git a/crates/emmylua_ls/src/handlers/test/completion_test.rs b/crates/emmylua_ls/src/handlers/test/completion_test.rs index 5ee0ad588..f7aeaee39 100644 --- a/crates/emmylua_ls/src/handlers/test/completion_test.rs +++ b/crates/emmylua_ls/src/handlers/test/completion_test.rs @@ -1241,7 +1241,7 @@ mod tests { VirtualCompletionItem { label: "set".to_string(), kind: CompletionItemKind::FUNCTION, - label_detail: Some("(self) -> nil".to_string()), + label_detail: Some("(self)".to_string()), }, ], )); @@ -1843,7 +1843,7 @@ mod tests { VirtualCompletionItem { label: "init".to_string(), kind: CompletionItemKind::FUNCTION, - label_detail: Some("(self) -> nil".to_string()), + label_detail: Some("(self)".to_string()), }, ], ));