Skip to content

Commit baf7455

Browse files
committed
[update_lib] Fix async func auto-mark
1 parent 568f24c commit baf7455

File tree

3 files changed

+239
-19
lines changed

3 files changed

+239
-19
lines changed

scripts/update_lib/cmd_auto_mark.py

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -71,53 +71,89 @@ def run_test(test_name: str, skip_build: bool = False) -> TestResult:
7171
return parse_results(result)
7272

7373

74+
def _try_parse_test_info(test_info: str) -> tuple[str, str] | None:
75+
"""Try to extract (name, path) from 'test_name (path)' or 'test_name (path) [subtest]'."""
76+
first_space = test_info.find(" ")
77+
if first_space > 0:
78+
name = test_info[:first_space]
79+
rest = test_info[first_space:].strip()
80+
if rest.startswith("("):
81+
end_paren = rest.find(")")
82+
if end_paren > 0:
83+
return name, rest[1:end_paren]
84+
return None
85+
86+
7487
def parse_results(result: subprocess.CompletedProcess) -> TestResult:
7588
"""Parse subprocess result into TestResult."""
7689
lines = result.stdout.splitlines()
7790
test_results = TestResult()
7891
test_results.stdout = result.stdout
7992
in_test_results = False
93+
# For multiline format: "test_name (path)\ndocstring ... RESULT"
94+
pending_test_info = None
8095

8196
for line in lines:
8297
if re.search(r"Run \d+ tests? sequentially", line):
8398
in_test_results = True
84-
elif line.startswith("-----------"):
99+
elif "== Tests result: " in line:
85100
in_test_results = False
86101

87102
if in_test_results and " ... " in line:
88-
line = line.strip()
103+
stripped = line.strip()
89104
# Skip lines that don't look like test results
90-
if line.startswith("tests") or line.startswith("["):
105+
if stripped.startswith("tests") or stripped.startswith("["):
106+
pending_test_info = None
91107
continue
92108
# Parse: "test_name (path) [subtest] ... RESULT"
93-
parts = line.split(" ... ")
109+
parts = stripped.split(" ... ")
94110
if len(parts) >= 2:
95111
test_info = parts[0]
96112
result_str = parts[-1].lower()
97113
# Only process FAIL or ERROR
98114
if result_str not in ("fail", "error"):
115+
pending_test_info = None
99116
continue
100-
# Extract test name (first word)
101-
first_space = test_info.find(" ")
102-
if first_space > 0:
117+
# Try parsing from this line (single-line format)
118+
parsed = _try_parse_test_info(test_info)
119+
if not parsed and pending_test_info:
120+
# Multiline format: previous line had test_name (path)
121+
parsed = _try_parse_test_info(pending_test_info)
122+
if parsed:
103123
test = Test()
104-
test.name = test_info[:first_space]
105-
# Extract path from (path)
106-
rest = test_info[first_space:].strip()
107-
if rest.startswith("("):
108-
end_paren = rest.find(")")
109-
if end_paren > 0:
110-
test.path = rest[1:end_paren]
111-
test.result = result_str
112-
test_results.tests.append(test)
124+
test.name, test.path = parsed
125+
test.result = result_str
126+
test_results.tests.append(test)
127+
pending_test_info = None
128+
129+
elif in_test_results:
130+
# Track test info for multiline format:
131+
# test_name (path)
132+
# docstring ... RESULT
133+
stripped = line.strip()
134+
if (
135+
stripped
136+
and "(" in stripped
137+
and stripped.endswith(")")
138+
and ":" not in stripped.split("(")[0]
139+
):
140+
pending_test_info = stripped
141+
else:
142+
pending_test_info = None
143+
144+
# Also check for Tests result on non-" ... " lines
145+
if "== Tests result: " in line:
146+
res = line.split("== Tests result: ")[1]
147+
res = res.split(" ")[0]
148+
test_results.tests_result = res
113149

114150
elif "== Tests result: " in line:
115151
res = line.split("== Tests result: ")[1]
116152
res = res.split(" ")[0]
117153
test_results.tests_result = res
118154

119155
# Parse: "UNEXPECTED SUCCESS: test_name (path)"
120-
elif line.startswith("UNEXPECTED SUCCESS: "):
156+
if line.startswith("UNEXPECTED SUCCESS: "):
121157
rest = line[len("UNEXPECTED SUCCESS: ") :]
122158
# Format: "test_name (path)"
123159
first_space = rest.find(" ")
@@ -232,13 +268,16 @@ def build_patches(
232268

233269

234270
def _is_super_call_only(func_node: ast.FunctionDef | ast.AsyncFunctionDef) -> bool:
235-
"""Check if the method body is just 'return super().method_name()'."""
271+
"""Check if the method body is just 'return super().method_name()' or 'return await super().method_name()'."""
236272
if len(func_node.body) != 1:
237273
return False
238274
stmt = func_node.body[0]
239275
if not isinstance(stmt, ast.Return) or stmt.value is None:
240276
return False
241277
call = stmt.value
278+
# Unwrap await for async methods
279+
if isinstance(call, ast.Await):
280+
call = call.value
242281
if not isinstance(call, ast.Call):
243282
return False
244283
if not isinstance(call.func, ast.Attribute):

scripts/update_lib/patch_spec.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,14 @@ def _iter_patch_lines(
247247

248248
# Build cache of all classes (for Phase 2 to find classes without methods)
249249
cache = {}
250+
# Build set of async method names (for Phase 2 to generate correct override)
251+
async_methods: set[str] = set()
250252
for node in tree.body:
251253
if isinstance(node, ast.ClassDef):
252254
cache[node.name] = node.end_lineno
255+
for item in node.body:
256+
if isinstance(item, ast.AsyncFunctionDef):
257+
async_methods.add(item.name)
253258

254259
# Phase 1: Iterate and mark existing tests
255260
for cls_node, fn_node in iter_tests(tree):
@@ -274,7 +279,15 @@ def _iter_patch_lines(
274279

275280
for test_name, specs in tests.items():
276281
decorators = "\n".join(spec.as_decorator() for spec in specs)
277-
patch_lines = f"""
282+
is_async = test_name in async_methods
283+
if is_async:
284+
patch_lines = f"""
285+
{decorators}
286+
async def {test_name}(self):
287+
{DEFAULT_INDENT}return await super().{test_name}()
288+
""".rstrip()
289+
else:
290+
patch_lines = f"""
278291
{decorators}
279292
def {test_name}(self):
280293
{DEFAULT_INDENT}return super().{test_name}()

scripts/update_lib/tests/test_auto_mark.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,78 @@ def test_parse_error_message(self):
118118
self.assertEqual(len(result.tests), 1)
119119
self.assertEqual(result.tests[0].error_message, "AssertionError: 1 != 2")
120120

121+
def test_parse_directory_test_multiple_submodules(self):
122+
"""Test parsing directory test output with multiple submodules.
123+
124+
When running a directory test (e.g., test_asyncio), the output contains
125+
multiple submodules separated by '------' lines. Failures in submodules
126+
after the first one must still be detected.
127+
"""
128+
stdout = """\
129+
Run 3 tests sequentially
130+
0:00:00 [ 1/3] test_asyncio.test_buffered_proto
131+
test_ok (test.test_asyncio.test_buffered_proto.TestProto.test_ok) ... ok
132+
133+
----------------------------------------------------------------------
134+
Ran 1 tests in 0.1s
135+
136+
OK
137+
138+
0:00:01 [ 2/3] test_asyncio.test_events
139+
test_create (test.test_asyncio.test_events.TestEvents.test_create) ... FAIL
140+
141+
----------------------------------------------------------------------
142+
Ran 1 tests in 0.2s
143+
144+
FAILED (failures=1)
145+
146+
0:00:02 [ 3/3] test_asyncio.test_tasks
147+
test_gather (test.test_asyncio.test_tasks.TestTasks.test_gather) ... ERROR
148+
149+
----------------------------------------------------------------------
150+
Ran 1 tests in 0.3s
151+
152+
FAILED (errors=1)
153+
154+
== Tests result: FAILURE ==
155+
"""
156+
result = parse_results(self._make_result(stdout))
157+
self.assertEqual(len(result.tests), 2)
158+
names = {t.name for t in result.tests}
159+
self.assertIn("test_create", names)
160+
self.assertIn("test_gather", names)
161+
# Verify results
162+
test_create = next(t for t in result.tests if t.name == "test_create")
163+
test_gather = next(t for t in result.tests if t.name == "test_gather")
164+
self.assertEqual(test_create.result, "fail")
165+
self.assertEqual(test_gather.result, "error")
166+
self.assertEqual(result.tests_result, "FAILURE")
167+
168+
def test_parse_multiline_test_with_docstring(self):
169+
"""Test parsing tests where docstring appears on a separate line.
170+
171+
Some tests have docstrings that cause the output to span two lines:
172+
test_name (path)
173+
docstring ... ERROR
174+
"""
175+
stdout = """\
176+
Run 3 tests sequentially
177+
test_ok (test.test_example.TestClass.test_ok) ... ok
178+
test_with_doc (test.test_example.TestClass.test_with_doc)
179+
Test that something works ... ERROR
180+
test_normal_fail (test.test_example.TestClass.test_normal_fail) ... FAIL
181+
"""
182+
result = parse_results(self._make_result(stdout))
183+
self.assertEqual(len(result.tests), 2)
184+
names = {t.name for t in result.tests}
185+
self.assertIn("test_with_doc", names)
186+
self.assertIn("test_normal_fail", names)
187+
test_doc = next(t for t in result.tests if t.name == "test_with_doc")
188+
self.assertEqual(
189+
test_doc.path, "test.test_example.TestClass.test_with_doc"
190+
)
191+
self.assertEqual(test_doc.result, "error")
192+
121193
def test_parse_multiple_error_messages(self):
122194
"""Test parsing multiple error messages."""
123195
stdout = """
@@ -644,6 +716,102 @@ def test_one(self):
644716
method = self._parse_method(code)
645717
self.assertFalse(_is_super_call_only(method))
646718

719+
def test_async_await_super_call(self):
720+
"""Test async method that awaits super().same_name()."""
721+
code = """
722+
class Foo:
723+
async def test_one(self):
724+
return await super().test_one()
725+
"""
726+
method = self._parse_method(code)
727+
self.assertTrue(_is_super_call_only(method))
728+
729+
def test_async_await_mismatched_super_call(self):
730+
"""Test async method that awaits super().different_name()."""
731+
code = """
732+
class Foo:
733+
async def test_one(self):
734+
return await super().test_two()
735+
"""
736+
method = self._parse_method(code)
737+
self.assertFalse(_is_super_call_only(method))
738+
739+
def test_async_without_await(self):
740+
"""Test async method that calls super() without await (sync super call in async method)."""
741+
code = """
742+
class Foo:
743+
async def test_one(self):
744+
return super().test_one()
745+
"""
746+
method = self._parse_method(code)
747+
self.assertTrue(_is_super_call_only(method))
748+
749+
750+
class TestAsyncInheritedOverride(unittest.TestCase):
751+
"""Tests for async inherited method override generation."""
752+
753+
def test_inherited_async_method_generates_async_override(self):
754+
"""Test that inherited async methods get async def + await override."""
755+
code = """import unittest
756+
757+
class BaseTest:
758+
async def test_async_one(self):
759+
pass
760+
761+
class TestChild(BaseTest, unittest.TestCase):
762+
pass
763+
"""
764+
failing = {("TestChild", "test_async_one")}
765+
result = apply_test_changes(code, failing, set())
766+
767+
self.assertIn("async def test_async_one(self):", result)
768+
self.assertIn("return await super().test_async_one()", result)
769+
self.assertIn("@unittest.expectedFailure", result)
770+
771+
def test_inherited_sync_method_generates_sync_override(self):
772+
"""Test that inherited sync methods get sync def override."""
773+
code = """import unittest
774+
775+
class BaseTest:
776+
def test_sync_one(self):
777+
pass
778+
779+
class TestChild(BaseTest, unittest.TestCase):
780+
pass
781+
"""
782+
failing = {("TestChild", "test_sync_one")}
783+
result = apply_test_changes(code, failing, set())
784+
785+
self.assertIn("def test_sync_one(self):", result)
786+
self.assertIn("return super().test_sync_one()", result)
787+
self.assertNotIn("async def test_sync_one", result)
788+
self.assertNotIn("await", result)
789+
790+
def test_remove_async_super_call_override(self):
791+
"""Test removing async super call override on unexpected success."""
792+
code = f"""import unittest
793+
794+
class BaseTest:
795+
async def test_async_one(self):
796+
pass
797+
798+
class TestChild(BaseTest, unittest.TestCase):
799+
# {COMMENT}
800+
@unittest.expectedFailure
801+
async def test_async_one(self):
802+
return await super().test_async_one()
803+
"""
804+
successes = {("TestChild", "test_async_one")}
805+
result = apply_test_changes(code, set(), successes)
806+
807+
# The override in TestChild should be removed; base class method remains
808+
self.assertNotIn("return await super().test_async_one()", result)
809+
self.assertNotIn("@unittest.expectedFailure", result)
810+
self.assertIn("class TestChild", result)
811+
# Base class method should still be present
812+
self.assertIn("class BaseTest", result)
813+
self.assertIn("async def test_async_one(self):", result)
814+
647815

648816
if __name__ == "__main__":
649817
unittest.main()

0 commit comments

Comments
 (0)