Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/source/cli-versions.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# kernels versions

Use `kernels versions` to list all available versions of a kernel on the Hub.
Use `kernels versions` to list all available versions of a kernel on the Hub
and marks compatible versions.

## Usage

Expand All @@ -19,7 +20,7 @@ kernels versions kernels-community/activation
## Example Output

```text
Version 1: torch210-cu128-x86_64-windows, torch210-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch210-metal-aarch64-darwin ✅, torch27-cxx11-cu118-x86_64-linux, torch27-cxx11-cu126-x86_64-linux, torch27-cxx11-cu128-aarch64-linux, torch27-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu126-x86_64-linux, torch28-cxx11-cu128-aarch64-linux, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu129-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, torch29-cxx11-cu126-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu128-x86_64-linux, torch29-cxx11-cu130-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch29-metal-aarch64-darwin
Version 1: torch210-metal-aarch64-darwin, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch27-cxx11-cu118-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-aarch64-linux, torch27-cxx11-cu126-x86_64-linux, ✅ torch29-cxx11-cu126-x86_64-linux (compatible), torch27-cxx11-cu128-x86_64-linux, torch210-cxx11-cu126-x86_64-linux, torch29-metal-aarch64-darwin, torch27-cxx11-cu128-aarch64-linux, torch210-cu128-x86_64-windows, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, torch29-cxx11-cu128-x86_64-linux (preferred), torch28-cxx11-cu129-x86_64-linux
```

## Use Cases
Expand Down
107 changes: 91 additions & 16 deletions kernels/src/kernels/backends.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import ctypes
import ctypes.util
import re
import warnings
from dataclasses import dataclass
from typing import Optional, Protocol
from typing import ClassVar, Optional, Protocol

from packaging.version import Version

Expand All @@ -18,98 +19,172 @@ def name(self) -> str:
...

@property
def variant(self) -> str:
def variant_str(self) -> str:
"""
The name of the backend as used in a build variant, e.g. `cu128`
for CUDA 12.8.
"""
...


@dataclass
@dataclass(unsafe_hash=True)
class CANN:
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah lovely!


version: Version

@property
def name(self) -> str:
return "cann"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return f"cann{self.version.major}{self.version.minor}"

@staticmethod
def parse(s: str) -> "CANN":
m = CANN._VARIANT_REGEX.fullmatch(s)
if not m:
raise ValueError(f"Invalid CANN variant string: {s!r}")
return CANN(version=Version(f"{m.group(1)}.{m.group(2)}"))


@dataclass
@dataclass(unsafe_hash=True)
class CPU:
@property
def name(self) -> str:
return "cpu"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return "cpu"

@staticmethod
def parse(s: str) -> "CPU":
if s != "cpu":
raise ValueError(f"Invalid CPU variant string: {s!r}")
return CPU()


@dataclass
@dataclass(unsafe_hash=True)
class CUDA:
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)")

version: Version

@property
def name(self) -> str:
return "cuda"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return f"cu{self.version.major}{self.version.minor}"

@staticmethod
def parse(s: str) -> "CUDA":
m = CUDA._VARIANT_REGEX.fullmatch(s)
if not m:
raise ValueError(f"Invalid CUDA variant string: {s!r}")
return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}"))


@dataclass
@dataclass(unsafe_hash=True)
class Metal:
@property
def name(self) -> str:
return "metal"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return "metal"

@staticmethod
def parse(s: str) -> "Metal":
if s != "metal":
raise ValueError(f"Invalid Metal variant string: {s!r}")
return Metal()


@dataclass
@dataclass(unsafe_hash=True)
class Neuron:
@property
def name(self) -> str:
return "neuron"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return "neuron"

@staticmethod
def parse(s: str) -> "Neuron":
if s != "neuron":
raise ValueError(f"Invalid Neuron variant string: {s!r}")
return Neuron()

@dataclass

@dataclass(unsafe_hash=True)
class ROCm:
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)")

version: Version

@property
def name(self) -> str:
return "rocm"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return f"rocm{self.version.major}{self.version.minor}"

@staticmethod
def parse(s: str) -> "ROCm":
m = ROCm._VARIANT_REGEX.fullmatch(s)
if not m:
raise ValueError(f"Invalid ROCm variant string: {s!r}")
return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}"))


@dataclass
@dataclass(unsafe_hash=True)
class XPU:
_VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)")

version: Version

@property
def name(self) -> str:
return "xpu"

@property
def variant(self) -> str:
def variant_str(self) -> str:
return f"xpu{self.version.major}{self.version.minor}"

@staticmethod
def parse(s: str) -> "XPU":
m = XPU._VARIANT_REGEX.fullmatch(s)
if not m:
raise ValueError(f"Invalid XPU variant string: {s!r}")
return XPU(version=Version(f"{m.group(1)}.{m.group(2)}"))


def parse_backend(s: str) -> Backend:
"""Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend."""
if s == "cpu":
return CPU.parse(s)
elif s == "metal":
return Metal.parse(s)
elif s == "neuron":
return Neuron.parse(s)
elif s.startswith("cu"):
return CUDA.parse(s)
elif s.startswith("rocm"):
return ROCm.parse(s)
elif s.startswith("xpu"):
return XPU.parse(s)
elif s.startswith("cann"):
return CANN.parse(s)
else:
raise ValueError(f"Unknown backend variant string: {s!r}")


def _backend() -> Backend:
if has_torch:
Expand Down
47 changes: 17 additions & 30 deletions kernels/src/kernels/cli/versions.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,31 @@
from importlib.util import find_spec
from pathlib import Path

from huggingface_hub import HfApi

from kernels._versions import _get_available_versions
from kernels.utils import _build_variants, _get_hf_api
from kernels.variants import BUILD_VARIANT_REGEX
from kernels.utils import _get_hf_api
from kernels.variants import (
get_variants,
resolve_variants,
)


def print_kernel_versions(repo_id: str):
api = _get_hf_api()

if find_spec("torch") is None:
# Do not mark compatible variants when Torch is not available.
compatible_variants = set()
else:
compatible_variants = set(_build_variants(None))

versions = _get_available_versions(repo_id).items()
if not versions:
print(f"Repository does not support kernel versions: {repo_id}")
return

for version, ref in sorted(versions, key=lambda x: x[0]):
variants = get_variants(api, repo_id=repo_id, revision=ref.ref)
resolved = resolve_variants(variants, None)
best = resolved[0] if resolved else None
resolved_set = set(resolved)
print(f"Version {version}: ", end="")
variants = [
f"{variant} ✅" if variant in compatible_variants else f"{variant}"
for variant in _get_build_variants(api, repo_id, ref.ref)
variant_strs = [
(
f"✅ {variant.variant_str} ({'compatible, preferred' if variant == best else 'compatible'})"
if variant in resolved_set
else f"{variant.variant_str}"
)
for variant in variants
]
print(", ".join(variants))


def _get_build_variants(api: "HfApi", repo_id: str, revision: str) -> list[str]:
variants = set()
for filename in api.list_repo_files(repo_id, revision=revision):
path = Path(filename)
if len(path.parts) < 2 or path.parts[0] != "build":
continue

match = BUILD_VARIANT_REGEX.match(path.parts[1])
if match:
variants.add(path.parts[1])
return sorted(variants)
print(", ".join(variant_strs))
Loading
Loading