Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ where
pub(in crate::compilation::analyzer) fn analyze_func_body_missing_return_flags_with<F>(
body: LuaBlock,
infer_expr_type: &mut F,
) -> Result<(bool, bool, bool), InferFailReason>
) -> Result<(bool, bool), InferFailReason>
where
F: FnMut(&LuaExpr) -> Result<LuaType, InferFailReason>,
{
let flow = analyze_block_returns(body, infer_expr_type)?;
Ok((flow.can_fall_through, flow.can_break, flow.is_infinite))
Ok((flow.can_fall_through, flow.can_break))
}

fn analyze_block_returns<F>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use unresolve::UnResolve;
pub(super) fn analyze_func_body_missing_return_flags_with<F>(
body: LuaBlock,
infer_expr_type: &mut F,
) -> Result<(bool, bool, bool), InferFailReason>
) -> Result<(bool, bool), InferFailReason>
where
F: FnMut(&LuaExpr) -> Result<LuaType, InferFailReason>,
{
Expand Down
2 changes: 1 addition & 1 deletion crates/emmylua_code_analysis/src/compilation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use emmylua_parser::{LuaBlock, LuaExpr};
pub(crate) fn analyze_func_body_missing_return_flags_with<F>(
body: LuaBlock,
infer_expr_type: &mut F,
) -> Result<(bool, bool, bool), InferFailReason>
) -> Result<(bool, bool), InferFailReason>
where
F: FnMut(&LuaExpr) -> Result<LuaType, InferFailReason>,
{
Expand Down
78 changes: 78 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,84 @@ mod test {
assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer"));
}

#[test]
fn test_pcall_narrows_type_guarded_callable_return_after_error_guard() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@param a string|fun(): integer
local function foo(a)
if type(a) == "string" then
return
end

local ok, result = pcall(a)
if not ok then
return
end

narrowed = result
end
"#,
);

let narrowed = ws.expr_ty("narrowed");
assert_eq!(ws.humanize_type(narrowed), "integer");
}

#[test]
fn test_pcall_narrows_type_guarded_callable_return_with_forwarded_arg() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@param a string|fun(value: integer): string
local function foo(a)
if type(a) == "string" then
return
end

local ok, result = pcall(a, 1)
if not ok then
return
end

narrowed = result
end
"#,
);

let narrowed = ws.expr_ty("narrowed");
assert_eq!(ws.humanize_type(narrowed), "string");
}

#[test]
fn test_pcall_narrows_type_guarded_callable_return_with_table_arg() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@param cb string|fun(a: {}): integer
local function foo(cb)
if type(cb) == "string" then
return
end

local ok, result = pcall(cb, {})
if not ok then
return
end

narrowed = result
end
"#,
);

let narrowed = ws.expr_ty("narrowed");
assert_eq!(ws.humanize_type(narrowed), "integer");
}

#[test]
fn test_pcall_any_callable_splits_success_unknown_and_failure_string() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,19 @@ fn check_missing_return(
// 检测缺少返回语句需要处理 if while
if min_expected_return_count > 0 {
let range = if let Some(block) = closure_expr.get_block() {
let (can_fall_through, can_break, is_infinite) =
analyze_func_body_missing_return_flags_with(
block.clone(),
&mut |expr: &LuaExpr| {
Ok(semantic_model
.infer_expr(expr.clone())
.unwrap_or(LuaType::Unknown))
},
)
.ok()?;

// `MissingReturn` currently ignores runtime-dependent divergence if
// a later `return` is still reachable.
if !can_fall_through && !can_break && !is_infinite {
let (can_fall_through, can_break) = analyze_func_body_missing_return_flags_with(
block.clone(),
&mut |expr: &LuaExpr| {
Ok(semantic_model
.infer_expr(expr.clone())
.unwrap_or(LuaType::Unknown))
},
)
.ok()?;

// Non-terminating paths satisfy `MissingReturn`; only paths that
// can leave the function body without returning should warn.
if !can_fall_through && !can_break {
return Some(());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,22 @@ mod tests {
));
}

#[test]
fn test_assert_optional_return_is_not_redundant() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

assert!(ws.has_no_diagnostic(
DiagnosticCode::RedundantReturnValue,
r#"
--- @return string
function foo()
local res --- @type string?
return assert(res)
end
"#
));
}

#[test]
fn test_not_return_anno() {
let mut ws = VirtualWorkspace::new();
Expand Down Expand Up @@ -290,7 +306,7 @@ mod tests {
"#
));

assert!(!ws.has_no_diagnostic(
assert!(ws.has_no_diagnostic(
DiagnosticCode::MissingReturn,
r#"
local A
Expand Down Expand Up @@ -724,6 +740,25 @@ mod tests {
);
}

#[test]
fn test_missing_return_accepts_non_terminating_truthy_while() {
assert_missing_return_ok(
r#"
--- @param ready boolean
--- @return string
function foo(ready)
while true do
if ready then
return 'ready'
end
end

error('unreachable')
end
"#,
);
}

#[test]
fn test_missing_return_accepts_infinite_repeat_with_break_before_return() {
assert_missing_return_ok(
Expand All @@ -743,8 +778,8 @@ mod tests {
}

#[test]
fn test_missing_return_rejects_dynamic_while_with_infinite_body_before_return() {
assert_missing_return_error(
fn test_missing_return_accepts_dynamic_while_with_infinite_body_before_return() {
assert_missing_return_ok(
r#"
---@return number
local function foo(a)
Expand All @@ -760,8 +795,8 @@ mod tests {
}

#[test]
fn test_missing_return_rejects_dynamic_while_with_break_or_infinite_body_before_return() {
assert_missing_return_error(
fn test_missing_return_accepts_dynamic_while_with_break_or_infinite_body_before_return() {
assert_missing_return_ok(
r#"
---@return number
local function foo(a, b)
Expand All @@ -781,8 +816,8 @@ mod tests {
}

#[test]
fn test_missing_return_rejects_stalling_numeric_for_before_return() {
assert_missing_return_error(
fn test_missing_return_accepts_non_terminating_numeric_for_before_return() {
assert_missing_return_ok(
r#"
---@return number
local function foo()
Expand All @@ -798,8 +833,8 @@ mod tests {
}

#[test]
fn test_missing_return_rejects_stalling_generic_for_before_return() {
assert_missing_return_error(
fn test_missing_return_accepts_non_terminating_generic_for_before_return() {
assert_missing_return_ok(
r#"
local function iter(_, done)
if done then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,56 @@ mod tests {
));
}

#[test]
fn test_pcall_return_after_type_guard() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

assert!(ws.has_no_diagnostic(
DiagnosticCode::ReturnTypeMismatch,
r#"
--- @param a string|fun(): integer
--- @return integer?
function foo(a)
if type(a) == 'string' then
return
end

local ok, result = pcall(a)
if not ok then
return
end

return result
end
"#
));
}

#[test]
fn test_pcall_return_after_type_guard_with_table_arg() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

assert!(ws.has_no_diagnostic(
DiagnosticCode::ReturnTypeMismatch,
r#"
--- @param cb string|fun(a: {}): integer
--- @return integer?
function foo(cb)
if type(cb) == 'string' then
return
end

local ok, result = pcall(cb, {})
if not ok then
return
end

return result
end
"#
));
}

#[test]
fn test_variadic_return_type_mismatch() {
let mut ws = VirtualWorkspace::new();
Expand Down Expand Up @@ -689,4 +739,37 @@ mod tests {
"#
));
}

#[test]
fn test_asserted_array_member_return_field() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
let code = r#"
--- @return { a: integer }
function foo()
local arr --- @type integer[]
local i --- @type integer?
local a --- @type integer?
i = _ --[[@as integer]]
a = assert(arr[i])
return { a = a }
end
"#;
assert!(ws.has_no_diagnostic(DiagnosticCode::ReturnTypeMismatch, code));
assert!(ws.has_no_diagnostic(DiagnosticCode::AssignTypeMismatch, code));
}

#[test]
fn test_and_or_function_guard_return() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
assert!(ws.has_no_diagnostic(
DiagnosticCode::ReturnTypeMismatch,
r#"
--- @param f string|(fun():string)
--- @return string
function foo(f)
return type(f) == 'function' and f() or f
end
"#
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -873,4 +873,31 @@ mod test {
"#
));
}

#[test]
fn test_array_index_with_integer_literal_union() {
let mut ws = VirtualWorkspace::new();
let code = r#"
---@alias IntegerPartIndex
---| 1
---| 2

---@alias NumericPartIndex
---| 1
---| number

local parts --- @type string[]
local id --- @type 1|2
local alias_id --- @type IntegerPartIndex
local numeric_id --- @type NumericPartIndex
result = parts[id]
alias_result = parts[alias_id]
numeric_result = parts[numeric_id]
"#;

assert!(ws.has_no_diagnostic(DiagnosticCode::UndefinedField, code));
assert_eq!(ws.expr_ty("result"), ws.ty("string?"));
assert_eq!(ws.expr_ty("alias_result"), ws.ty("string?"));
assert_eq!(ws.expr_ty("numeric_result"), ws.ty("string?"));
}
}
Loading
Loading