Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 250 additions & 1 deletion tests/repo_utils/test_check_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import inspect
import os
import sys
import tempfile
import textwrap
import unittest


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

from check_docstrings import get_default_description, replace_default_in_arg_description # noqa: E402
from check_docstrings import ( # noqa: E402
_build_ast_indexes,
_find_typed_dict_classes,
_get_auto_docstring_names,
get_default_description,
has_auto_docstring_decorator,
replace_default_in_arg_description,
)


class CheckDostringsTested(unittest.TestCase):
Expand Down Expand Up @@ -96,3 +106,242 @@ def _fake_function(a, b: int, c=1, d: float = 2.0, e: str = "blob"):
assert get_default_description(params["c"]) == "`<fill_type>`, *optional*, defaults to 1"
assert get_default_description(params["d"]) == "`float`, *optional*, defaults to 2.0"
assert get_default_description(params["e"]) == '`str`, *optional*, defaults to `"blob"`'


class TestGetAutoDocstringNames(unittest.TestCase):
    """Exercise `_get_auto_docstring_names` and `has_auto_docstring_decorator` on small temporary modules."""

    def setUp(self):
        # Every test method starts from its own empty cache dict.
        self.cache = {}

    def _write_temp(self, source):
        """Write `source` to a temporary `.py` file and return its path; the file is unlinked at test cleanup."""
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as handle:
            handle.write(source)
        tmp_path = handle.name
        self.addCleanup(os.unlink, tmp_path)
        return tmp_path

    def _collect_names(self, source):
        """Write `source` to a temporary file and return the names reported for it, using this test's cache."""
        return _get_auto_docstring_names(self._write_temp(source), cache=self.cache)

    def test_detects_simple_decorator(self):
        """A class carrying a bare `@auto_docstring` decorator is reported."""
        module_source = textwrap.dedent("""\
            from transformers import auto_docstring

            @auto_docstring
            class Foo:
                pass
        """)
        self.assertEqual(self._collect_names(module_source), {"Foo"})

    def test_detects_decorator_with_call(self):
        """A class decorated with the called form `@auto_docstring(...)` is reported."""
        module_source = textwrap.dedent("""\
            @auto_docstring(custom_args='x')
            class Bar:
                pass
        """)
        self.assertEqual(self._collect_names(module_source), {"Bar"})

    def test_ignores_other_decorators(self):
        """A class whose only decorator is not `auto_docstring` yields an empty result."""
        module_source = textwrap.dedent("""\
            @dataclass
            class Baz:
                pass
        """)
        self.assertEqual(self._collect_names(module_source), set())

    def test_multiple_classes(self):
        """With several definitions in one file, only the decorated class and function are returned."""
        module_source = textwrap.dedent("""\
            @auto_docstring
            class A:
                pass

            class B:
                pass

            @auto_docstring()
            def func_c():
                pass
        """)
        self.assertEqual(self._collect_names(module_source), {"A", "func_c"})

    def test_caching(self):
        """Two lookups of the same path through the same cache return the very same object."""
        tmp_path = self._write_temp(
            textwrap.dedent("""\
                @auto_docstring
                class X:
                    pass
            """)
        )
        first = _get_auto_docstring_names(tmp_path, cache=self.cache)
        second = _get_auto_docstring_names(tmp_path, cache=self.cache)
        self.assertIs(first, second)

    def test_syntax_error_returns_empty(self):
        """A file that cannot be parsed produces an empty set rather than an exception."""
        self.assertEqual(self._collect_names("def broken(\n"), set())

    def test_has_auto_docstring_decorator_uses_cache(self):
        """`has_auto_docstring_decorator` answers by name from a cache entry stored under the object's file path."""
        from unittest import mock

        tmp_path = self._write_temp(
            textwrap.dedent("""\
                @auto_docstring
                class Cached:
                    pass
            """)
        )
        self.cache[tmp_path] = {"Cached"}

        # Throwaway classes: one whose __name__ is in the cached set, one whose __name__ is not.
        listed_cls = type("Cached", (), {})
        unlisted_cls = type("Other", (), {})

        # inspect.getfile is patched so both dynamically created classes resolve to the temp file path.
        with mock.patch.object(inspect, "getfile", return_value=tmp_path):
            self.assertTrue(has_auto_docstring_decorator(listed_cls, cache=self.cache))
            self.assertFalse(has_auto_docstring_decorator(unlisted_cls, cache=self.cache))


class TestBuildAstIndexes(unittest.TestCase):
    """Exercise `_build_ast_indexes`, both with and without a pre-parsed AST tree."""

    def _only(self, items):
        """Assert that `items` holds exactly one entry and return that entry."""
        self.assertEqual(len(items), 1)
        return items[0]

    def test_finds_decorated_items(self):
        """A decorated class is indexed as kind "class" and its `__init__` argument appears in `args`."""
        code = textwrap.dedent("""\
            @auto_docstring
            class MyModel:
                def __init__(self, hidden_size=768):
                    self.hidden_size = hidden_size
        """)
        entry = self._only(_build_ast_indexes(code))
        self.assertEqual(entry.name, "MyModel")
        self.assertEqual(entry.kind, "class")
        self.assertIn("hidden_size", entry.args)

    def test_shared_tree(self):
        """Supplying an `ast.parse` tree gives the same item count and first name as omitting it."""
        code = textwrap.dedent("""\
            @auto_docstring
            class A:
                pass
        """)
        parsed = ast.parse(code)
        from_tree = _build_ast_indexes(code, tree=parsed)
        from_source = _build_ast_indexes(code)
        self.assertEqual(len(from_tree), len(from_source))
        self.assertEqual(from_tree[0].name, from_source[0].name)

    def test_no_decorated_items(self):
        """A class without the decorator produces an empty list."""
        code = textwrap.dedent("""\
            class Plain:
                pass
        """)
        self.assertEqual(_build_ast_indexes(code), [])

    def test_function_decorated(self):
        """A decorated function is indexed as kind "function" with both of its parameters in `args`."""
        code = textwrap.dedent("""\
            @auto_docstring
            def my_func(x, y=10):
                pass
        """)
        entry = self._only(_build_ast_indexes(code))
        self.assertEqual(entry.name, "my_func")
        self.assertEqual(entry.kind, "function")
        self.assertIn("x", entry.args)
        self.assertIn("y", entry.args)

    def test_custom_args_from_variable(self):
        """`custom_args` given as a module-level variable is exposed as that variable's string value."""
        code = textwrap.dedent("""\
            MY_ARGS = "custom param docs"

            @auto_docstring(custom_args=MY_ARGS)
            class WithCustom:
                def __init__(self):
                    pass
        """)
        entry = self._only(_build_ast_indexes(code))
        self.assertEqual(entry.custom_args_text, "custom param docs")


class TestFindTypedDictClasses(unittest.TestCase):
    """Exercise `_find_typed_dict_classes`, both with and without a pre-parsed AST tree."""

    def _only(self, found):
        """Assert that `found` holds exactly one entry and return that entry."""
        self.assertEqual(len(found), 1)
        return found[0]

    def test_finds_typed_dict(self):
        """A `TypedDict` subclass is reported by name with both annotated fields in `all_fields`."""
        code = textwrap.dedent("""\
            from typing import TypedDict

            class MyKwargs(TypedDict):
                field_a: str
                field_b: int
        """)
        entry = self._only(_find_typed_dict_classes(code))
        self.assertEqual(entry["name"], "MyKwargs")
        self.assertIn("field_a", entry["all_fields"])
        self.assertIn("field_b", entry["all_fields"])

    def test_shared_tree(self):
        """Supplying an `ast.parse` tree gives the same entry count and first name as omitting it."""
        code = textwrap.dedent("""\
            class MyKwargs(TypedDict):
                x: int
        """)
        parsed = ast.parse(code)
        from_tree = _find_typed_dict_classes(code, tree=parsed)
        from_source = _find_typed_dict_classes(code)
        self.assertEqual(len(from_tree), len(from_source))
        self.assertEqual(from_tree[0]["name"], from_source[0]["name"])

    def test_skips_standard_kwargs(self):
        """A `TypedDict` named `TextKwargs` is left out of the results."""
        code = textwrap.dedent("""\
            class TextKwargs(TypedDict):
                field: str
        """)
        self.assertEqual(_find_typed_dict_classes(code), [])

    def test_no_typed_dicts(self):
        """Source containing no `TypedDict` subclass produces an empty list."""
        code = textwrap.dedent("""\
            class Regular:
                pass
        """)
        self.assertEqual(_find_typed_dict_classes(code), [])

    def test_skips_private_fields(self):
        """A field whose name starts with an underscore is absent from `all_fields`; the public one is kept."""
        code = textwrap.dedent("""\
            class MyKwargs(TypedDict):
                public: int
                _private: str
        """)
        entry = self._only(_find_typed_dict_classes(code))
        self.assertIn("public", entry["all_fields"])
        self.assertNotIn("_private", entry["all_fields"])
Loading
Loading