From ad7dc0e30c5108d7116121bf5cce1bd24bf667c5 Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Thu, 2 Apr 2026 20:40:16 -0700 Subject: [PATCH] Fix match statement semantics for "or" pattern Fixes https://github.com/mypyc/mypyc/issues/1166 Authored by Codex --- mypyc/irbuild/match.py | 13 ++- mypyc/test-data/irbuild-match.test | 133 ++++++++++++++++++----------- mypyc/test-data/run-match.test | 21 +++++ 3 files changed, 111 insertions(+), 56 deletions(-) diff --git a/mypyc/irbuild/match.py b/mypyc/irbuild/match.py index c2ca9cfd32ff7..7b08cf1437caa 100644 --- a/mypyc/irbuild/match.py +++ b/mypyc/irbuild/match.py @@ -108,20 +108,25 @@ def visit_value_pattern(self, pattern: ValuePattern) -> None: self.builder.add_bool_branch(cond, self.code_block, self.next_block) def visit_or_pattern(self, pattern: OrPattern) -> None: - backup_block = self.next_block - self.next_block = BasicBlock() + code_block = self.code_block + next_block = self.next_block for p in pattern.patterns: + self.code_block = BasicBlock() + self.next_block = BasicBlock() + # Hack to ensure the as pattern is bound to each pattern in the # "or" pattern, but not every subpattern backup = self.as_pattern p.accept(self) self.as_pattern = backup + self.builder.activate_block(self.code_block) + self.builder.goto(code_block) self.builder.activate_block(self.next_block) - self.next_block = BasicBlock() - self.next_block = backup_block + self.code_block = code_block + self.next_block = next_block self.builder.goto(self.next_block) def visit_class_pattern(self, pattern: ClassPattern) -> None: diff --git a/mypyc/test-data/irbuild-match.test b/mypyc/test-data/irbuild-match.test index 1e84c385100a7..4c9287993b311 100644 --- a/mypyc/test-data/irbuild-match.test +++ b/mypyc/test-data/irbuild-match.test @@ -48,13 +48,17 @@ def f(): r8, r9 :: object L0: r0 = int_eq 246, 246 - if r0 goto L3 else goto L1 :: bool + if r0 goto L1 else goto L2 :: bool L1: - r1 = int_eq 246, 912 - if r1 goto L3 else goto L2 :: bool + goto L5 L2: - goto L4 + r1 = int_eq 246, 912 + if r1 goto L3 else goto L4 :: bool L3: + goto L5 +L4: + goto L6 +L5: r2 = 'matched' r3 = builtins :: module r4 = 'print' @@ -63,9 +67,9 @@ L3: r7 = load_address r6 r8 = PyObject_Vectorcall(r5, r7, 1, 0) keep_alive r2 - goto L5 -L4: -L5: + goto L7 +L6: +L7: r9 = box(None, 1) return r9 @@ -86,19 +90,27 @@ def f(): r10, r11 :: object L0: r0 = int_eq 2, 2 - if r0 goto L5 else goto L1 :: bool + if r0 goto L1 else goto L2 :: bool L1: - r1 = int_eq 2, 4 - if r1 goto L5 else goto L2 :: bool + goto L9 L2: - r2 = int_eq 2, 6 - if r2 goto L5 else goto L3 :: bool + r1 = int_eq 2, 4 + if r1 goto L3 else goto L4 :: bool L3: - r3 = int_eq 2, 8 - if r3 goto L5 else goto L4 :: bool + goto L9 L4: - goto L6 + r2 = int_eq 2, 6 + if r2 goto L5 else goto L6 :: bool L5: + goto L9 +L6: + r3 = int_eq 2, 8 + if r3 goto L7 else goto L8 :: bool +L7: + goto L9 +L8: + goto L10 +L9: r4 = 'matched' r5 = builtins :: module r6 = 'print' @@ -107,9 +119,9 @@ L5: r9 = load_address r8 r10 = PyObject_Vectorcall(r7, r9, 1, 0) keep_alive r4 - goto L7 -L6: -L7: + goto L11 +L10: +L11: r11 = box(None, 1) return r11 @@ -280,16 +292,20 @@ L1: r6 = load_address r5 r7 = PyObject_Vectorcall(r4, r6, 1, 0) keep_alive r1 - goto L9 + goto L11 L2: r8 = int_eq 246, 4 - if r8 goto L5 else goto L3 :: bool + if r8 goto L3 else goto L4 :: bool L3: - r9 = int_eq 246, 6 - if r9 goto L5 else goto L4 :: bool + goto L7 L4: - goto L6 + r9 = int_eq 246, 6 + if r9 goto L5 else goto L6 :: bool L5: + goto L7 +L6: + goto L8 +L7: r10 = 'here 2 | 3' r11 = builtins :: module r12 = 'print' @@ -298,11 +314,11 @@ L5: r15 = load_address r14 r16 = PyObject_Vectorcall(r13, r15, 1, 0) keep_alive r10 - goto L9 -L6: + goto L11 +L8: r17 = int_eq 246, 246 - if r17 goto L7 else goto L8 :: bool -L7: + if r17 goto L9 else goto L10 :: bool +L9: r18 = 'here 123' r19 = builtins :: module r20 = 'print' @@ -311,9 +327,9 @@ L7: r23 = load_address r22 r24 = PyObject_Vectorcall(r21, r23, 1, 0) keep_alive r18 - goto L9 -L8: -L9: + goto L11 +L10: +L11: r25 = box(None, 1) return r25 @@ -456,15 +472,19 @@ def f(): r10, r11 :: object L0: r0 = int_eq 2, 2 - if r0 goto L3 else goto L1 :: bool + if r0 goto L1 else goto L2 :: bool L1: + goto L5 +L2: r1 = load_address PyLong_Type r2 = object 1 r3 = CPy_TypeCheck(r2, r1) - if r3 goto L3 else goto L2 :: bool -L2: - goto L4 + if r3 goto L3 else goto L4 :: bool L3: + goto L5 +L4: + goto L6 +L5: r4 = 'matched' r5 = builtins :: module r6 = 'print' @@ -473,9 +493,9 @@ L3: r9 = load_address r8 r10 = PyObject_Vectorcall(r7, r9, 1, 0) keep_alive r4 - goto L5 -L4: -L5: + goto L7 +L6: +L7: r11 = box(None, 1) return r11 @@ -532,15 +552,19 @@ L0: r0 = int_eq 2, 2 r1 = object 1 x = r1 - if r0 goto L3 else goto L1 :: bool + if r0 goto L1 else goto L2 :: bool L1: + goto L5 +L2: r2 = int_eq 2, 4 r3 = object 2 x = r3 - if r2 goto L3 else goto L2 :: bool -L2: - goto L4 + if r2 goto L3 else goto L4 :: bool L3: + goto L5 +L4: + goto L6 +L5: r4 = builtins :: module r5 = 'print' r6 = CPyObject_GetAttr(r4, r5) @@ -548,9 +572,9 @@ L3: r8 = load_address r7 r9 = PyObject_Vectorcall(r6, r8, 1, 0) keep_alive x - goto L5 -L4: -L5: + goto L7 +L6: +L7: r10 = box(None, 1) return r10 @@ -809,7 +833,7 @@ L0: r1 = PyObject_IsInstance(x, r0) r2 = r1 >= 0 :: signed r3 = truncate r1: i32 to builtins.bool - if r3 goto L1 else goto L5 :: bool + if r3 goto L1 else goto L7 :: bool L1: r4 = 'num' r5 = CPyObject_GetAttr(x, r4) @@ -818,17 +842,21 @@ L1: r8 = PyObject_IsTrue(r7) r9 = r8 >= 0 :: signed r10 = truncate r8: i32 to builtins.bool - if r10 goto L4 else goto L2 :: bool + if r10 goto L2 else goto L3 :: bool L2: + goto L6 +L3: r11 = object 2 r12 = PyObject_RichCompare(r5, r11, 2) r13 = PyObject_IsTrue(r12) r14 = r13 >= 0 :: signed r15 = truncate r13: i32 to builtins.bool - if r15 goto L4 else goto L3 :: bool -L3: - goto L5 + if r15 goto L4 else goto L5 :: bool L4: + goto L6 +L5: + goto L7 +L6: r16 = 'matched' r17 = builtins :: module r18 = 'print' @@ -837,11 +865,12 @@ L4: r21 = load_address r20 r22 = PyObject_Vectorcall(r19, r21, 1, 0) keep_alive r16 - goto L6 -L5: -L6: + goto L8 +L7: +L8: r23 = box(None, 1) return r23 + [case testAsPatternDoesntBleedIntoSubPatterns_python3_10] class C: __match_args__ = ("a", "b") diff --git a/mypyc/test-data/run-match.test b/mypyc/test-data/run-match.test index 7b7ad9a4342ce..b24552072a655 100644 --- a/mypyc/test-data/run-match.test +++ b/mypyc/test-data/run-match.test @@ -230,6 +230,27 @@ test 21 ('') test 21 (' as well') test sequence final test final + +[case testMatchOrSequencePattern_python3_10] +def f(x: tuple[str, str]) -> str: + match x: + case ("X", "Y") | ("X", "Z"): + return "THERE" + case _: + return "OTHER" + +[file driver.py] +from native import f + +print(f(("X", "Y"))) +print(f(("X", "Z"))) +print(f(("X", "A"))) + +[out] +THERE +THERE +OTHER + [case testCustomMappingAndSequenceObjects_python3_10] def f(x: object) -> None: match x: