diff --git a/kernels/src/kernels/_versions.py b/kernels/src/kernels/_versions.py index 617dabf0..4de8db74 100644 --- a/kernels/src/kernels/_versions.py +++ b/kernels/src/kernels/_versions.py @@ -1,9 +1,12 @@ +import logging import warnings from huggingface_hub.hf_api import GitRefInfo from packaging.specifiers import SpecifierSet from packaging.version import InvalidVersion, Version +logger = logging.getLogger(__name__) + def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]: """Get kernel versions that are available in the repository.""" @@ -55,6 +58,16 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int | str) -> GitRef raise ValueError( f"Version {version_spec} not found, available versions: {', '.join(sorted(str(v) for v in versions.keys()))}" ) + + latest_version = max(versions.keys()) + if version_spec < latest_version: + logger.warning( + "You are using version %d of '%s', but version %d is available.", + version_spec, + repo_id, + latest_version, + ) + return ref else: warnings.warn( diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index 7486134c..dcae8811 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -1,3 +1,5 @@ +import logging + import pytest import torch import torch.nn.functional as F @@ -161,6 +163,22 @@ def test_version(): kernel = get_kernel("kernels-test/versions", version=0) +def test_version_outdated_warning(caplog): + with caplog.at_level(logging.WARNING, logger="kernels._versions"): + kernel = get_kernel("kernels-test/versions", version=1) + assert kernel.version() == "1" + assert ( + "You are using version 1 of 'kernels-test/versions', but version 2 is available." + in caplog.text + ) + + caplog.clear() + with caplog.at_level(logging.WARNING, logger="kernels._versions"): + kernel = get_kernel("kernels-test/versions", version=2) + assert kernel.version() == "2" + assert "but version" not in caplog.text + + @pytest.mark.cuda_only def test_universal_kernel(universal_kernel): torch.manual_seed(0)