2323"""
2424
2525import argparse
26+ import ast
2627import shutil
2728import sys
2829from pathlib import Path
2930
30- from lib_updater import PatchSpec , UtMethod , apply_patches
31+ from lib_updater import (
32+ COMMENT ,
33+ PatchSpec ,
34+ UtMethod ,
35+ apply_patches ,
36+ )
3137
3238
3339def parse_args ():
@@ -61,15 +67,18 @@ def __str__(self):
6167class TestResult :
6268 tests_result : str = ""
6369 tests = []
70+ unexpected_successes = [] # Tests that passed but were marked as expectedFailure
6471 stdout = ""
6572
6673 def __str__ (self ):
67- return f"TestResult(tests_result={ self .tests_result } ,tests={ len (self .tests )} )"
74+ return f"TestResult(tests_result={ self .tests_result } ,tests={ len (self .tests )} ,unexpected_successes= { len ( self . unexpected_successes ) } )"
6875
6976
7077def parse_results (result ):
7178 lines = result .stdout .splitlines ()
7279 test_results = TestResult ()
80+ test_results .tests = []
81+ test_results .unexpected_successes = []
7382 test_results .stdout = result .stdout
7483 in_test_results = False
7584 for line in lines :
@@ -107,6 +116,19 @@ def parse_results(result):
107116 res = line .split ("== Tests result: " )[1 ]
108117 res = res .split (" " )[0 ]
109118 test_results .tests_result = res
119+ # Parse: "UNEXPECTED SUCCESS: test_name (path)"
120+ elif line .startswith ("UNEXPECTED SUCCESS: " ):
121+ rest = line [len ("UNEXPECTED SUCCESS: " ) :]
122+ # Format: "test_name (path)"
123+ first_space = rest .find (" " )
124+ if first_space > 0 :
125+ test = Test ()
126+ test .name = rest [:first_space ]
127+ path_part = rest [first_space :].strip ()
128+ if path_part .startswith ("(" ) and path_part .endswith (")" ):
129+ test .path = path_part [1 :- 1 ]
130+ test .result = "unexpected_success"
131+ test_results .unexpected_successes .append (test )
110132 return test_results
111133
112134
@@ -117,6 +139,47 @@ def path_to_test(path) -> list[str]:
117139 return parts [- 2 :] # Get class name and method name
118140
119141
142+ def remove_expected_failures (contents : str , tests_to_remove : set [tuple [str , str ]]) -> str :
143+ """Remove @unittest.expectedFailure decorators from tests that now pass."""
144+ if not tests_to_remove :
145+ return contents
146+
147+ tree = ast .parse (contents )
148+ lines = contents .splitlines ()
149+ lines_to_remove = set ()
150+
151+ for node in ast .walk (tree ):
152+ if not isinstance (node , ast .ClassDef ):
153+ continue
154+ class_name = node .name
155+ for item in node .body :
156+ if not isinstance (item , (ast .FunctionDef , ast .AsyncFunctionDef )):
157+ continue
158+ method_name = item .name
159+ if (class_name , method_name ) not in tests_to_remove :
160+ continue
161+
162+ # Find and mark expectedFailure decorators for removal
163+ for dec in item .decorator_list :
164+ dec_line = dec .lineno - 1 # 0-indexed
165+ line_content = lines [dec_line ]
166+
167+ # Check if it's @unittest.expectedFailure with TODO: RUSTPYTHON
168+ if "expectedFailure" in line_content and COMMENT in line_content :
169+ lines_to_remove .add (dec_line )
170+ # Also check the line before for a standalone TODO comment
171+ if dec_line > 0 :
172+ prev_line = lines [dec_line - 1 ].strip ()
173+ if prev_line .startswith ("#" ) and COMMENT in prev_line :
174+ lines_to_remove .add (dec_line - 1 )
175+
176+ # Remove lines in reverse order to maintain line numbers
177+ for line_idx in sorted (lines_to_remove , reverse = True ):
178+ del lines [line_idx ]
179+
180+ return "\n " .join (lines ) + "\n " if lines else ""
181+
182+
120183def build_patches (test_parts_set : set [tuple [str , str ]]) -> dict :
121184 """Convert failing tests to lib_updater patch format."""
122185 patches = {}
@@ -190,20 +253,38 @@ def run_test(test_name):
190253 f = test_path .read_text (encoding = "utf-8" )
191254
192255 # Collect failing tests (with deduplication for subtests)
193- seen_tests = set () # Track (class_name, method_name) to avoid duplicates
256+ failing_tests = set () # Track (class_name, method_name) to avoid duplicates
194257 for test in tests .tests :
195258 if test .result == "fail" or test .result == "error" :
196259 test_parts = path_to_test (test .path )
197260 if len (test_parts ) == 2 :
198261 test_key = tuple (test_parts )
199- if test_key not in seen_tests :
200- seen_tests .add (test_key )
201- print (f"Marking test: { test_parts [0 ]} .{ test_parts [1 ]} " )
202-
203- # Apply patches using lib_updater
204- if seen_tests :
205- patches = build_patches (seen_tests )
262+ if test_key not in failing_tests :
263+ failing_tests .add (test_key )
264+ print (f"Marking as failing: { test_parts [0 ]} .{ test_parts [1 ]} " )
265+
266+ # Collect unexpected successes (tests that now pass but have expectedFailure)
267+ unexpected_successes = set ()
268+ for test in tests .unexpected_successes :
269+ test_parts = path_to_test (test .path )
270+ if len (test_parts ) == 2 :
271+ test_key = tuple (test_parts )
272+ if test_key not in unexpected_successes :
273+ unexpected_successes .add (test_key )
274+ print (f"Removing expectedFailure: { test_parts [0 ]} .{ test_parts [1 ]} " )
275+
276+ # Remove expectedFailure from tests that now pass
277+ if unexpected_successes :
278+ f = remove_expected_failures (f , unexpected_successes )
279+
280+ # Apply patches for failing tests
281+ if failing_tests :
282+ patches = build_patches (failing_tests )
206283 f = apply_patches (f , patches )
284+
285+ # Write changes if any modifications were made
286+ if failing_tests or unexpected_successes :
207287 test_path .write_text (f , encoding = "utf-8" )
208288
209- print (f"Modified { len (seen_tests )} tests" )
289+ print (f"Added expectedFailure to { len (failing_tests )} tests" )
290+ print (f"Removed expectedFailure from { len (unexpected_successes )} tests" )
0 commit comments