Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions tests/test___version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@


class VersionAPI(APIBase):
"""Equivalence check for ``torch.__version__`` / ``paddle.__version__``.

``torch.torch_version.TorchVersion`` is a ``str`` subclass whose
comparison operators dispatch through ``packaging.version.Version``,
falling back to plain string comparison on invalid versions. The Paddle
equivalent must mirror that surface. Each framework parses ``Version``
with its own copy of ``packaging`` (torch vendors it), so this test
compares only against ``str`` / ``tuple`` inputs -- both implementations
convert those through their own ``Version`` internally.
"""

def compare(
self,
name,
Expand All @@ -31,8 +42,40 @@ def compare(
rtol=1.0e-6,
atol=0.0,
):
# torch return: torch.torch_version.TorchVersion
assert type(str(pytorch_result)) == type(paddle_result)
assert isinstance(
pytorch_result, str
), f"pytorch result should be a str (subclass), got {type(pytorch_result)}"
assert isinstance(
paddle_result, str
), f"paddle result should be a str (subclass), got {type(paddle_result)}"

assert str(pytorch_result), "pytorch version string is empty"
assert str(paddle_result), "paddle version string is empty"

assert isinstance(pytorch_result.split("."), list)
assert isinstance(paddle_result.split("."), list)
assert pytorch_result.startswith(str(pytorch_result)[0])
assert paddle_result.startswith(str(paddle_result)[0])

assert (pytorch_result == str(pytorch_result)) is True
assert (paddle_result == str(paddle_result)) is True

assert (pytorch_result > "0.0.1") is True
assert (paddle_result > "0.0.1") is True
assert (pytorch_result > (0, 0, 1)) is True
assert (paddle_result > (0, 0, 1)) is True
assert (pytorch_result >= "0.0.1") is True
assert (paddle_result >= "0.0.1") is True
assert (pytorch_result < "999.0") is True
assert (paddle_result < "999.0") is True

assert (pytorch_result == "parrot") == (paddle_result == "parrot")
assert (pytorch_result != "parrot") == (paddle_result != "parrot")

# __eq__ is installed via setattr on both, which does not implicitly
# reset __hash__ to None -- guard against that regression.
hash(pytorch_result)
hash(paddle_result)


obj = VersionAPI("torch.__version__")
Expand Down