diff --git a/mypy/checker.py b/mypy/checker.py index 8775f1ddef29..aade05dfc3d5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6558,6 +6558,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa operands = [collapse_walrus(x) for x in node.operands] operand_types = [] narrowable_operand_index_to_hash = {} + narrowable_operand_hash_to_index = {} for i, expr in enumerate(operands): if not self.has_type(expr): return {}, {} @@ -6582,6 +6583,7 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa h = literal_hash(expr) if h is not None: narrowable_operand_index_to_hash[i] = h + narrowable_operand_hash_to_index[h] = i # Step 2: Group operands chained by either the 'is' or '==' operands # together. For all other operands, we keep them in groups of size 2. @@ -6673,6 +6675,18 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa partial_type_maps.append((if_map, else_map)) + # Chained comparisons are conjunctions evaluated left-to-right. Feed what we learned + # from earlier true comparisons into later comparisons, similarly to `and`. + if len(simplified_operator_list) > 1: + for expr, expr_type in if_map.items(): + h = literal_hash(expr) + if h is None or h not in narrowable_operand_hash_to_index: + continue + operand_index = narrowable_operand_hash_to_index[h] + operand_types[operand_index] = meet_types( + operand_types[operand_index], expr_type + ) + # If we have found non-trivial restrictions from the regular comparisons, # then return soon. Otherwise try to infer restrictions involving `len(x)`. # TODO: support regular and len() narrowing in the same chain. diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 982a86e38edd..5ca6d31cce30 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -3264,7 +3264,7 @@ def bad_but_should_pass(has_key: bool, key: bool, s: tuple[bool, ...]) -> None: reveal_type(key) # N: Revealed type is "builtins.bool" [builtins fixtures/primitives.pyi] -[case testNarrowChainedComparisonMeet] +[case testNarrowChainedComparisonMeetAndForwardPropagation] # flags: --strict-equality --warn-unreachable from __future__ import annotations from typing import Any @@ -3272,7 +3272,7 @@ from typing import Any def f1(a: str | None, b: str | None) -> None: if None is not a == b: reveal_type(a) # N: Revealed type is "builtins.str" - reveal_type(b) # N: Revealed type is "builtins.str | None" + reveal_type(b) # N: Revealed type is "builtins.str" if (None is not a) and (a == b): reveal_type(a) # N: Revealed type is "builtins.str" @@ -3290,11 +3290,33 @@ def f2(a: Any | None, b: str | None) -> None: def f3(a: str | None, b: Any | None) -> None: if None is not a == b: reveal_type(a) # N: Revealed type is "builtins.str" - reveal_type(b) # N: Revealed type is "Any | builtins.str | None" + reveal_type(b) # N: Revealed type is "Any | builtins.str" if (None is not a) and (a == b): reveal_type(a) # N: Revealed type is "builtins.str" reveal_type(b) # N: Revealed type is "Any | builtins.str" + +def f4(a: str | None, b: str | None, c: str | None) -> None: + if None is not a == b == c: + reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(c) # N: Revealed type is "builtins.str" + + if (None is not a) and (a == b) and (b == c): + reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(c) # N: Revealed type is "builtins.str" + +def f5(pair: tuple[None, int] | tuple[str, str], other: str | None) -> None: + if None is not pair[0] == other: + reveal_type(pair[0]) # N: Revealed type is "builtins.str" + reveal_type(pair) # N: Revealed type is "tuple[builtins.str, builtins.str]" + reveal_type(other) # N: Revealed type is "builtins.str" + + if (None is not pair[0]) and (pair[0] == other): + reveal_type(pair[0]) # N: Revealed type is "builtins.str" + reveal_type(pair) # N: Revealed type is "tuple[builtins.str, builtins.str]" + reveal_type(other) # N: Revealed type is "builtins.str" [builtins fixtures/primitives.pyi] [case testNarrowTypeObject]