Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions kernels/src/kernels/_versions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
import warnings

from huggingface_hub.hf_api import GitRefInfo
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version

logger = logging.getLogger(__name__)


def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
Expand Down Expand Up @@ -55,6 +58,16 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int | str) -> GitRef
raise ValueError(
f"Version {version_spec} not found, available versions: {', '.join(sorted(str(v) for v in versions.keys()))}"
)

latest_version = max(versions.keys())
if version_spec < latest_version:
logger.warning(
"You are using version %d of '%s', but version %d is available.",
version_spec,
repo_id,
latest_version,
)

return ref
else:
warnings.warn(
Expand Down
18 changes: 18 additions & 0 deletions kernels/tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

import pytest
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -161,6 +163,22 @@ def test_version():
kernel = get_kernel("kernels-test/versions", version=0)


def test_version_outdated_warning(caplog):
with caplog.at_level(logging.WARNING, logger="kernels._versions"):
kernel = get_kernel("kernels-test/versions", version=1)
assert kernel.version() == "1"
assert (
"You are using version 1 of 'kernels-test/versions', but version 2 is available."
in caplog.text
)

caplog.clear()
with caplog.at_level(logging.WARNING, logger="kernels._versions"):
kernel = get_kernel("kernels-test/versions", version=2)
assert kernel.version() == "2"
assert "but version" not in caplog.text


@pytest.mark.cuda_only
def test_universal_kernel(universal_kernel):
torch.manual_seed(0)
Expand Down
Loading