diff --git a/tests/test___version__.py b/tests/test___version__.py index e87716de7..1f6fbda31 100644 --- a/tests/test___version__.py +++ b/tests/test___version__.py @@ -19,6 +19,17 @@ class VersionAPI(APIBase): + """Equivalence check for ``torch.__version__`` / ``paddle.__version__``. + + ``torch.torch_version.TorchVersion`` is a ``str`` subclass whose + comparison operators dispatch through ``packaging.version.Version``, + falling back to plain string comparison on invalid versions. The Paddle + equivalent must mirror that surface. Each framework parses ``Version`` + with its own copy of ``packaging`` (torch vendors it), so this test + compares only against ``str`` / ``tuple`` inputs -- both implementations + convert those through their own ``Version`` internally. + """ + def compare( self, name, @@ -31,8 +42,40 @@ def compare( rtol=1.0e-6, atol=0.0, ): - # torch return: torch.torch_version.TorchVersion - assert type(str(pytorch_result)) == type(paddle_result) + assert isinstance( + pytorch_result, str + ), f"pytorch result should be a str (subclass), got {type(pytorch_result)}" + assert isinstance( + paddle_result, str + ), f"paddle result should be a str (subclass), got {type(paddle_result)}" + + assert str(pytorch_result), "pytorch version string is empty" + assert str(paddle_result), "paddle version string is empty" + + assert isinstance(pytorch_result.split("."), list) + assert isinstance(paddle_result.split("."), list) + assert pytorch_result.startswith(str(pytorch_result)[0]) + assert paddle_result.startswith(str(paddle_result)[0]) + + assert (pytorch_result == str(pytorch_result)) is True + assert (paddle_result == str(paddle_result)) is True + + assert (pytorch_result > "0.0.1") is True + assert (paddle_result > "0.0.1") is True + assert (pytorch_result > (0, 0, 1)) is True + assert (paddle_result > (0, 0, 1)) is True + assert (pytorch_result >= "0.0.1") is True + assert (paddle_result >= "0.0.1") is True + assert (pytorch_result < "999.0") is True + assert (paddle_result < "999.0") is True + + assert (pytorch_result == "parrot") == (paddle_result == "parrot") + assert (pytorch_result != "parrot") == (paddle_result != "parrot") + + # __eq__ is installed via setattr on both, which does not implicitly + # reset __hash__ to None -- guard against that regression. + hash(pytorch_result) + hash(paddle_result) obj = VersionAPI("torch.__version__")