Skip to content

Commit 36d9705

Browse files
committed
fix custom_text_test_runner
1 parent 313ea72 commit 36d9705

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

extra_tests/custom_text_test_runner.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@
3838
from unittest.runner import registerResult, result
3939

4040

41+
def _get_method_dict(test):
42+
"""Get the __dict__ of the underlying function for a test method.
43+
44+
Works for both bound methods (__func__.__dict__) and plain functions.
45+
"""
46+
method = getattr(test, test._testMethodName)
47+
func = getattr(method, "__func__", method)
48+
return func.__dict__
49+
50+
4151
class TablePrinter(object):
4252
# Modified from https://github.com/agramian/table-printer, same license as above
4353
"Print a list of dicts as a table"
@@ -325,9 +335,7 @@ def startTest(self, test):
325335
self.stream.writeln("TEST SUITE: %s" % self.suite)
326336
self.stream.writeln("Description: %s" % self.getSuiteDescription(test))
327337
try:
328-
name_override = getattr(test, test._testMethodName).__func__.__dict__[
329-
"test_case_name"
330-
]
338+
name_override = _get_method_dict(test)["test_case_name"]
331339
except:
332340
name_override = None
333341
self.case = name_override if name_override else self.case
@@ -345,7 +353,11 @@ def startTest(self, test):
345353
self.results["suites"][self.suite_number] = {
346354
"name": self.suite,
347355
"class": test.__class__.__name__,
348-
"module": re.compile(".* \((.*)\)").match(str(test)).group(1),
356+
"module": (
357+
m.group(1)
358+
if (m := re.compile(r".* \((.*)\)").match(str(test)))
359+
else str(test)
360+
),
349361
"description": self.getSuiteDescription(test),
350362
"cases": {},
351363
"used_case_names": {},
@@ -380,34 +392,22 @@ def startTest(self, test):
380392
if "test_type" in getattr(
381393
test, test._testMethodName
382394
).__func__.__dict__ and set([s.lower() for s in self.test_types]) == set(
383-
[
384-
s.lower()
385-
for s in getattr(test, test._testMethodName).__func__.__dict__[
386-
"test_type"
387-
]
388-
]
395+
[s.lower() for s in _get_method_dict(test)["test_type"]]
389396
):
390397
pass
391398
else:
392-
getattr(test, test._testMethodName).__func__.__dict__[
393-
"__unittest_skip_why__"
394-
] = 'Test run specified to only run tests of type "%s"' % ",".join(
395-
self.test_types
399+
_get_method_dict(test)["__unittest_skip_why__"] = (
400+
'Test run specified to only run tests of type "%s"'
401+
% ",".join(self.test_types)
396402
)
397-
getattr(test, test._testMethodName).__func__.__dict__[
398-
"__unittest_skip__"
399-
] = True
400-
if "skip_device" in getattr(test, test._testMethodName).__func__.__dict__:
401-
for device in getattr(test, test._testMethodName).__func__.__dict__[
402-
"skip_device"
403-
]:
403+
_get_method_dict(test)["__unittest_skip__"] = True
404+
if "skip_device" in _get_method_dict(test):
405+
for device in _get_method_dict(test)["skip_device"]:
404406
if self.config and device.lower() in self.config["device_name"].lower():
405-
getattr(test, test._testMethodName).__func__.__dict__[
406-
"__unittest_skip_why__"
407-
] = "Test is marked to be skipped on %s" % device
408-
getattr(test, test._testMethodName).__func__.__dict__[
409-
"__unittest_skip__"
410-
] = True
407+
_get_method_dict(test)["__unittest_skip_why__"] = (
408+
"Test is marked to be skipped on %s" % device
409+
)
410+
_get_method_dict(test)["__unittest_skip__"] = True
411411
break
412412

413413
def stopTest(self, test):

0 commit comments

Comments
 (0)