From 202a142e3e96fe4b3471a0dd1dea5df1a6350771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 12 Mar 2026 13:54:33 +0000 Subject: [PATCH 1/9] Refactor variant handling and add CUDA fallback Build variants were stringly-typed throughout kernels, with custom parsing and serialization sprinkled everywhere. This change adds proper/strong typing to variants adding a `Variant` class. This also centers parsing/serialization in one place and allows code to easily query various parts of of a variant. This also fundamentally changes how we deal with getting variants from the Hub. Rather than casting a wide net with all possible variants and using allow patterns based on that, we query the hub for variants of a kernel, parse them and can decide if there is an applicable variant ahead of time. If there are multiple applicable variants, we can select the best one (e.g. arch before noarch or recent CUDA version before older versions). --- kernels/src/kernels/backends.py | 88 +++++++- kernels/src/kernels/cli/versions.py | 38 +--- kernels/src/kernels/utils.py | 112 ++++++---- kernels/src/kernels/variants.py | 333 ++++++++++++++++++++++++---- kernels/tests/test_variants.py | 322 +++++++++++++++++++++++++++ 5 files changed, 771 insertions(+), 122 deletions(-) create mode 100644 kernels/tests/test_variants.py diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index 51b322aa..f678f7dc 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -1,5 +1,6 @@ import ctypes import ctypes.util +import re import warnings from dataclasses import dataclass from typing import Optional, Protocol @@ -8,6 +9,11 @@ from kernels.compat import has_torch +_CUDA_VARIANT_REGEX = re.compile(r"cu(\d+)(\d+)") +_ROCM_VARIANT_REGEX = re.compile(r"rocm(\d+)(\d+)") +_XPU_VARIANT_REGEX = re.compile(r"xpu(\d+)(\d+)") +_CANN_VARIANT_REGEX = re.compile(r"cann(\d+)(\d+)") + class Backend(Protocol): @property @@ -18,7 +24,7 @@ def name(self) -> str: ... @property - def variant(self) -> str: + def variant_str(self) -> str: """ The name of the backend as used in a build variant, e.g. `cu128` for CUDA 12.8. @@ -35,9 +41,16 @@ def name(self) -> str: return "cann" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"cann{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "CANN": + m = _CANN_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid CANN variant string: {s!r}") + return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) + @dataclass class CPU: @@ -46,9 +59,15 @@ def name(self) -> str: return "cpu" @property - def variant(self) -> str: + def variant_str(self) -> str: return "cpu" + @staticmethod + def parse(s: str) -> "CPU": + if s != "cpu": + raise ValueError(f"Invalid CPU variant string: {s!r}") + return CPU() + @dataclass class CUDA: @@ -59,9 +78,16 @@ def name(self) -> str: return "cuda" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"cu{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "CUDA": + m = _CUDA_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid CUDA variant string: {s!r}") + return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) + @dataclass class Metal: @@ -70,9 +96,15 @@ def name(self) -> str: return "metal" @property - def variant(self) -> str: + def variant_str(self) -> str: return "metal" + @staticmethod + def parse(s: str) -> "Metal": + if s != "metal": + raise ValueError(f"Invalid Metal variant string: {s!r}") + return Metal() + @dataclass class Neuron: @@ -81,9 +113,15 @@ def name(self) -> str: return "neuron" @property - def variant(self) -> str: + def variant_str(self) -> str: return "neuron" + @staticmethod + def parse(s: str) -> "Neuron": + if s != "neuron": + raise ValueError(f"Invalid Neuron variant string: {s!r}") + return Neuron() + @dataclass class ROCm: @@ -94,9 +132,16 @@ def name(self) -> str: return "rocm" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"rocm{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "ROCm": + m = _ROCM_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid ROCm variant string: {s!r}") + return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) + @dataclass class XPU: @@ -107,9 +152,36 @@ def name(self) -> str: return "xpu" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"xpu{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "XPU": + m = _XPU_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid XPU variant string: {s!r}") + return XPU(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +def parse_backend(s: str) -> Backend: + """Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend.""" + if s == "cpu": + return CPU.parse(s) + elif s == "metal": + return Metal.parse(s) + elif s == "neuron": + return Neuron.parse(s) + elif s.startswith("cu"): + return CUDA.parse(s) + elif s.startswith("rocm"): + return ROCm.parse(s) + elif s.startswith("xpu"): + return XPU.parse(s) + elif s.startswith("cann"): + return CANN.parse(s) + else: + raise ValueError(f"Unknown backend variant string: {s!r}") + def _backend() -> Backend: if has_torch: diff --git a/kernels/src/kernels/cli/versions.py b/kernels/src/kernels/cli/versions.py index a752d789..80bc643c 100644 --- a/kernels/src/kernels/cli/versions.py +++ b/kernels/src/kernels/cli/versions.py @@ -1,44 +1,22 @@ -from importlib.util import find_spec -from pathlib import Path - -from huggingface_hub import HfApi - from kernels._versions import _get_available_versions -from kernels.utils import _build_variants, _get_hf_api -from kernels.variants import BUILD_VARIANT_REGEX +from kernels.utils import _get_hf_api +from kernels.variants import get_variants, resolve_variant def print_kernel_versions(repo_id: str): api = _get_hf_api() - if find_spec("torch") is None: - # Do not mark compatible variants when Torch is not available. - compatible_variants = set() - else: - compatible_variants = set(_build_variants(None)) - versions = _get_available_versions(repo_id).items() if not versions: print(f"Repository does not support kernel versions: {repo_id}") return for version, ref in sorted(versions, key=lambda x: x[0]): + variants = get_variants(api, repo_id=repo_id, revision=ref.ref) + best = resolve_variant(variants) print(f"Version {version}: ", end="") - variants = [ - f"{variant} ✅" if variant in compatible_variants else f"{variant}" - for variant in _get_build_variants(api, repo_id, ref.ref) + variant_strs = [ + f"{variant.variant_str} ✅" if variant == best else f"{variant.variant_str}" + for variant in variants ] - print(", ".join(variants)) - - -def _get_build_variants(api: "HfApi", repo_id: str, revision: str) -> list[str]: - variants = set() - for filename in api.list_repo_files(repo_id, revision=revision): - path = Path(filename) - if len(path.parts) < 2 or path.parts[0] != "build": - continue - - match = BUILD_VARIANT_REGEX.match(path.parts[1]) - if match: - variants.add(path.parts[1]) - return sorted(variants) + print(", ".join(variant_strs)) diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 56c166f0..729510a7 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -17,13 +17,18 @@ from kernels._system import glibc_version from kernels._versions import select_revision_or_version -from kernels.backends import _backend +from kernels.backends import _backend, _select_backend from kernels.compat import has_torch, has_tvm_ffi from kernels.deps import validate_dependencies from kernels.lockfile import KernelLock, VariantLock from kernels.metadata import Metadata from kernels.status import resolve_status -from kernels.variants import _build_variants +from kernels.variants import ( + Variant, + get_variants, + get_variants_local, + resolve_variant, +) KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} @@ -128,7 +133,16 @@ def install_kernel( repo_id, revision = resolve_status(api, repo_id, revision) package_name = package_name_from_repo_id(repo_id) - allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)] + + variants = get_variants(api, repo_id=repo_id, revision=revision) + variant = resolve_variant(variants, backend) + + if variant is None: + raise FileNotFoundError( + f"Cannot find a build variant for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + ) + + allow_patterns = [f"build/{variant.variant_str}/*"] repo_path = Path( str( api.snapshot_download( @@ -143,7 +157,10 @@ def install_kernel( try: return _find_kernel_in_repo_path( - repo_path, package_name, backend=backend, variant_locks=variant_locks + repo_path, + package_name, + variant=variant, + variant_locks=variant_locks, ) except FileNotFoundError: raise FileNotFoundError( @@ -155,30 +172,21 @@ def _find_kernel_in_repo_path( repo_path: Path, package_name: str, *, - backend: str | None = None, + variant: Variant, variant_locks: dict[str, VariantLock] | None = None, ) -> tuple[str, Path]: - variants = _build_variants(backend) - variant = None - variant_path = None - for candidate_variant in variants: - variant_path = repo_path / "build" / candidate_variant - if variant_path.exists(): - variant = candidate_variant - break - - if variant is None: - raise FileNotFoundError( - f"Kernel at path `{repo_path}` does not have one of build variants: {', '.join(variants)}" - ) - - assert variant_path is not None + variant_str = variant.variant_str + variant_path = repo_path / "build" / variant_str + if not variant_path.exists(): + raise FileNotFoundError(f"Variant path does not exist: `{variant_path}`") if variant_locks is not None: - variant_lock = variant_locks.get(variant) + variant_lock = variant_locks.get(variant_str) if variant_lock is None: raise ValueError(f"No lock found for build variant: {variant}") - validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash) + validate_kernel( + repo_path=repo_path, variant=variant_str, hash=variant_lock.hash + ) module_init_path = variant_path / "__init__.py" if not os.path.exists(module_init_path): @@ -297,12 +305,12 @@ def get_local_kernel( Returns: `ModuleType`: The imported kernel module. """ - # Presume we were given the top level path of the kernel repository. for base_path in [repo_path, repo_path / "build"]: - for v in _build_variants(backend): - variant_path = base_path / v - if variant_path.exists(): - return _import_from_path(package_name, variant_path) + variants = get_variants_local(base_path) + variant = resolve_variant(variants, backend) + + if variant is not None: + return _import_from_path(package_name, base_path / variant.variant_str) # If we didn't find the package in the repo we may have a explicit # package path. @@ -341,12 +349,19 @@ def has_kernel( package_name = package_name_from_repo_id(repo_id) api = _get_hf_api() - for variant in _build_variants(backend): - for init_file in ["__init__.py", f"{package_name}/__init__.py"]: - if api.file_exists( - repo_id, revision=revision, filename=f"build/{variant}/{init_file}" - ): - return True + variants = get_variants(api, repo_id=repo_id, revision=revision) + variant = resolve_variant(variants, backend) + + if variant is None: + return False + + for init_file in ["__init__.py", f"{package_name}/__init__.py"]: + if api.file_exists( + repo_id, + revision=revision, + filename=f"build/{variant.variant_str}/{init_file}", + ): + return True return False @@ -388,7 +403,15 @@ def load_kernel( package_name = package_name_from_repo_id(repo_id) api = _get_hf_api() - allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)] + variants = get_variants(api, repo_id=repo_id, revision=locked_sha) + variant = resolve_variant(variants, backend) + + if variant is None: + raise FileNotFoundError( + f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + ) + + allow_patterns = [f"build/{variant.variant_str}/*"] repo_path = Path( str( api.snapshot_download( @@ -403,7 +426,10 @@ def load_kernel( try: package_name, variant_path = _find_kernel_in_repo_path( - repo_path, package_name, backend=backend, variant_locks=None + repo_path, + package_name, + variant=variant, + variant_locks=None, ) return _import_from_path(package_name, variant_path) except FileNotFoundError: @@ -538,6 +564,18 @@ def package_name_from_repo_id(repo_id: str) -> str: return repo_id.split("/")[-1].replace("-", "_") +def _platform() -> str: + cpu = platform.machine() + os = platform.system().lower() + + if os == "darwin": + cpu = "aarch64" if cpu == "arm64" else cpu + elif os == "windows": + cpu = "x86_64" if cpu == "AMD64" else cpu + + return f"{cpu}-{os}" + + def _get_hf_api(user_agent: str | dict | None = None) -> HfApi: """Returns an instance of HfApi with proper settings.""" @@ -553,8 +591,8 @@ def _get_hf_api(user_agent: str | dict | None = None) -> HfApi: # System info python = ".".join(platform.python_version_tuple()[:2]) - variants = ":".join(_build_variants(None)) - user_agent_str += f"; kernels/{__version__}; python/{python}; build_variant/{variants}; file_type/kernel" + backend = _select_backend(None).variant_str + user_agent_str += f"; kernels/{__version__}; python/{python}; backend/{backend}; flatform/{_platform()}; file_type/kernel" if has_torch: import torch diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 33f6ee76..8b198bd9 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -1,85 +1,324 @@ import platform import re +from dataclasses import dataclass +from pathlib import Path -from packaging.version import parse +from huggingface_hub import HfApi +from huggingface_hub.hf_api import RepoFolder +from packaging.version import Version, parse -from kernels.backends import _select_backend +from kernels.backends import ( + CANN, + CUDA, + XPU, + Backend, + ROCm, + _select_backend, + parse_backend, +) from kernels.compat import has_torch, has_tvm_ffi BUILD_VARIANT_REGEX = re.compile( r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)" ) +_TORCH_VARIANT_REGEX = re.compile(r"torch(\d+?)(\d+)") +_TVM_FFI_VARIANT_REGEX = re.compile(r"tvm-ffi(\d+?)(\d+)") + + +@dataclass(unsafe_hash=True) +class Torch: + version: Version | None + + @property + def variant_str(self) -> str: + if self.version is None: + return "torch" + return f"torch{self.version.major}{self.version.minor}" + + @staticmethod + def parse(s: str) -> "Torch": + if s == "torch": + return Torch(version=None) + m = _TORCH_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid Torch variant string: {s!r}") + return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +@dataclass(unsafe_hash=True) +class TvmFfi: + version: Version + + @property + def variant_str(self) -> str: + return f"tvm-ffi{self.version.major}{self.version.minor}" + + @staticmethod + def parse(s: str) -> "TvmFfi": + m = _TVM_FFI_VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid TvmFfi variant string: {s!r}") + return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +@dataclass +class Arch: + backend: Backend + platform: str + os: str + cxx11_abi: bool | None + + @property + def variant_str(self) -> str: + if self.cxx11_abi is None: + return f"{self.backend.variant_str}-{self.platform}-{self.os}" + else: + return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}" + + @staticmethod + def parse(parts: list[str]) -> "Arch": + # Handle Linux with cxx11 marker. + if len(parts) == 4: + cxx11_abi = parts[0] == "cxx11" + parts = parts[1:] + elif len(parts) == 3: + cxx11_abi = None + else: + raise ValueError(f"Invalid arch variant parts: {parts!r}") + + backend = parse_backend(parts[0]) + platform = parts[1] + os = parts[2] + + return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) + + +@dataclass +class Noarch: + backend_name: str + + @property + def variant_str(self) -> str: + return self.backend_name + + @staticmethod + def parse(s: str) -> "Noarch": + return Noarch(backend_name=s) + -def _torch_build_variant(backend: str | None) -> list[str]: - if not has_torch: +@dataclass +class Variant: + framework: Torch | TvmFfi + arch: Arch | Noarch + + @property + def variant_str(self) -> str: + return f"{self.framework.variant_str}-{self.arch.variant_str}" + + @staticmethod + def parse(variant_str: str) -> "Variant": + parts = variant_str.split("-") + + arch: Arch | Noarch + framework: Torch | TvmFfi + + if parts[0] == "torch": + # noarch: e.g. "torch-cpu" + framework = Torch.parse(parts[0]) + arch = Noarch.parse("-".join(parts[1:])) + elif parts[0].startswith("torch"): + framework = Torch.parse(parts[0]) + arch = Arch.parse(parts[1:]) + elif parts[0] == "tvm" and parts[1].startswith("ffi"): + framework = TvmFfi.parse(f"tvm-{parts[1]}") + arch = Arch.parse(parts[2:]) + else: + raise ValueError(f"Unknown framework in variant string: {variant_str!r}") + + return Variant(framework=framework, arch=arch) + + +def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: + """Get all the build variants available from a kernel repository.""" + + try: + tree = api.list_repo_tree(repo_id, path_in_repo="build", revision=revision) + variant_strs = { + item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) + } + except Exception: return [] - selected_backend = _select_backend(backend) + variants = [] + for variant_str in variant_strs: + try: + variants.append(Variant.parse(variant_str)) + except ValueError: + pass + return variants + - backend_variant = selected_backend.variant +def get_variants_local(repo_path: Path) -> list[Variant]: + """Get all the build variants available in a local directory.""" - import torch + try: + variant_strs = {entry.name for entry in repo_path.iterdir() if entry.is_dir()} + except Exception: + return [] + + variants = [] + for variant_str in variant_strs: + try: + variants.append(Variant.parse(variant_str)) + except ValueError: + pass + return variants + + +def resolve_variant( + variants: list[Variant], backend: str | None = None +) -> Variant | None: + """Return the best matching variant for the current system.""" + selected_backend = _select_backend(backend) - torch_version = parse(torch.__version__) cpu = platform.machine() os = platform.system().lower() if os == "darwin": cpu = "aarch64" if cpu == "arm64" else cpu - return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" - ] elif os == "windows": cpu = "x86_64" if cpu == "AMD64" else cpu - return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" - ] - cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" - return [ - f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{backend_variant}-{cpu}-{os}" - ] + torch_version = None + torch_cxx11_abi = None + if has_torch: + import torch + # Parse Torch version and strip patch/tags. + torch_version = parse(torch.__version__) + torch_version = Version(f"{torch_version.major}.{torch_version.minor}") -def _tvm_ffi_build_variant(backend: str | None) -> list[str]: - if not has_tvm_ffi: - return [] + torch_cxx11_abi = torch.compiled_with_cxx11_abi() if os == "linux" else None - selected_backend = _select_backend(backend) + tvm_ffi_version = None + if has_tvm_ffi: + import tvm_ffi - backend_variant = selected_backend.variant + # Parse Torch version and strip patch/tags. + tvm_ffi_version = parse(tvm_ffi.__version__) + tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}") - import tvm_ffi + return _resolve_variant_for_system( + variants=variants, + selected_backend=selected_backend, + cpu=cpu, + os=os, + torch_version=torch_version, + torch_cxx11_abi=torch_cxx11_abi, + tvm_ffi_version=tvm_ffi_version, + ) - tvm_ffi_version = parse(tvm_ffi.__version__) - cpu = platform.machine() - os = platform.system().lower() - return [ - f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{backend_variant}-{cpu}-{os}" - ] +def _resolve_variant_for_system( + variants: list[Variant], + selected_backend: Backend, + cpu: str, + os: str, + torch_version: Version | None, + torch_cxx11_abi: bool | None, + tvm_ffi_version: Version | None, +) -> Variant | None: + """Resolve the best matching variant given explicit system parameters.""" + applicable = _filter_variants( + variants, + selected_backend, + cpu, + os, + torch_version, + torch_cxx11_abi, + tvm_ffi_version, + ) + sorted_variants = _sort_variants(applicable) + return sorted_variants[0] if sorted_variants else None -def _build_variant_noarch(backend: str | None) -> list[str]: - selected_backend = _select_backend(backend) +def _filter_variants( + variants: list[Variant], + selected_backend: Backend, + cpu: str, + os: str, + torch_version: Version | None, + torch_cxx11_abi: bool | None, + tvm_ffi_version: Version | None, +) -> list[Variant]: + """Return only the variants applicable to the current system.""" + result = [] + for v in variants: + if isinstance(v.arch, Arch): + # Skip non-matching CPU or OS. + if v.arch.platform != cpu or v.arch.os != os: + continue + # If the variant is a Torch or tvm-ffi variant, check that it has the + # correct version and ABI. + if isinstance(v.framework, Torch): + if v.framework.version != torch_version: + continue + if v.arch.cxx11_abi != torch_cxx11_abi: + continue + elif isinstance(v.framework, TvmFfi): + if v.framework.version != tvm_ffi_version: + continue + # Given a system CUDA version of x.y, only CUDA versions x.z, + # where z <= y qualify. Otherwise, the backend + version (if present) + # must match. + if isinstance(selected_backend, CUDA) and isinstance(v.arch.backend, CUDA): + if ( + v.arch.backend.version.major != selected_backend.version.major + or v.arch.backend.version.minor > selected_backend.version.minor + ): + continue + elif v.arch.backend.variant_str != selected_backend.variant_str: + continue + else: + assert isinstance(v.arch, Noarch) + # Only noarch variants with a matching backend or "universal" + # are applicable. + noarch_backend_name = ( + "npu" if selected_backend.name == "cann" else selected_backend.name + ) + if ( + v.arch.backend_name != noarch_backend_name + and v.arch.backend_name != "universal" + ): + continue + result.append(v) + return result - if selected_backend.name == "cann": - return ["torch-npu"] - else: - return [f"torch-{selected_backend.name}"] +def _sort_variants( + variants: list[Variant], +) -> list[Variant]: + """Sort variants in preference order: -def _build_variant_universal() -> list[str]: - # Once we support other frameworks, detection goes here. - return ["torch-universal"] if has_torch else [] + 1. Torch arch kernels with with the highest compatible CUDA version. + 2. tvm-ffi arch kernels with with the highest compatible CUDA version. + 3. Torch noarch kernels. + 4. Old Torch universal kernels. + """ + def sort_key(v: Variant) -> tuple: + if isinstance(v.arch, Arch): + framework_order = 0 if isinstance(v.framework, Torch) else 1 + if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)): + # Order by backend version in reverse (higher is better). + backend_order = -v.arch.backend.version.minor + else: + backend_order = 0 + return (framework_order, backend_order) + else: + assert isinstance(v.arch, Noarch) + universal_order = 1 if v.arch.backend_name == "universal" else 0 + return (2, universal_order) -def _build_variants(backend: str | None) -> list[str]: - """Return compatible build variants in preferred order.""" - return [ - *_torch_build_variant(backend), - *_tvm_ffi_build_variant(backend), - *_build_variant_noarch(backend), - *_build_variant_universal(), - ] + return sorted(variants, key=sort_key) diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py new file mode 100644 index 00000000..faf29699 --- /dev/null +++ b/kernels/tests/test_variants.py @@ -0,0 +1,322 @@ +import pytest +from huggingface_hub import HfApi +from packaging.version import Version + +from kernels.backends import CPU, CUDA, ROCm +from kernels.variants import Variant, _resolve_variant_for_system, get_variants + +VARIANT_STRINGS = [ + "torch29-cpu-aarch64-darwin", + "torch29-cxx11-cpu-aarch64-linux", + "torch29-cxx11-cpu-x86_64-linux", + "torch29-cxx11-cu126-aarch64-linux", + "torch29-cxx11-cu126-x86_64-linux", + "torch29-cxx11-cu128-aarch64-linux", + "torch29-cxx11-cu128-x86_64-linux", + "torch29-cxx11-cu130-aarch64-linux", + "torch29-cxx11-cu130-x86_64-linux", + "torch29-cxx11-rocm63-x86_64-linux", + "torch29-cxx11-rocm64-x86_64-linux", + "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-metal-aarch64-darwin", + "torch210-cpu-aarch64-darwin", + "torch210-cu128-x86_64-windows", + "torch210-cxx11-cpu-aarch64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cxx11-cu126-aarch64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-aarch64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-aarch64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-rocm71-x86_64-linux", + "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-metal-aarch64-darwin", + "torch210-xpu20253-x86_64-windows", +] + + +NOARCH_VARIANT_STRINGS = [ + "torch-cpu", + "torch-cuda", + "torch-metal", + "torch-neuron", + "torch-rocm", + "torch-xpu", + "torch-npu", + "torch-universal", +] + + +@pytest.mark.parametrize("variant_str", VARIANT_STRINGS) +def test_arch_variants(variant_str: str): + # Roundtrip parse and generate variant string. + assert Variant.parse(variant_str).variant_str == variant_str + + +@pytest.mark.parametrize("variant_str", NOARCH_VARIANT_STRINGS) +def test_noarch_variants(variant_str: str): + # Roundtrip parse and generate variant string. + assert Variant.parse(variant_str).variant_str == variant_str + + +def test_get_variants(): + api = HfApi() + variants = get_variants(api, repo_id="kernels-community/relu", revision="v1") + variant_strs = {v.variant_str for v in variants} + # Superset because new variants may be added in the future. + assert variant_strs.issuperset(VARIANT_STRINGS) + + +RESOLVE_VARIANTS = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cpu-aarch64-darwin", + "torch210-metal-aarch64-darwin", + "torch-cuda", + "torch-cpu", + ] +] + + +def test_resolve_cuda_exact(): + # CUDA 12.8 should resolve to cu128. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_cuda_best_older_minor(): + # CUDA 12.9 is not available, should fall back to cu128 (highest <= 12.9). + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.9")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_cuda_no_newer_minor(): + # CUDA 12.5 is older than all the variants, fall back to noarch. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.5")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch-cuda" + + +def test_resolve_cuda_no_different_major(): + # Different major version must not match. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("11.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch-cuda" + + +def test_resolve_rocm(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cxx11-rocm70-x86_64-linux" + + +def test_resolve_cpu_linux(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cxx11-cpu-x86_64-linux" + + +def test_resolve_cpu_darwin(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="aarch64", + os="darwin", + torch_version=Version("2.10"), + torch_cxx11_abi=None, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cpu-aarch64-darwin" + + +def test_resolve_metal_darwin(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="aarch64", + os="darwin", + torch_version=Version("2.10"), + torch_cxx11_abi=None, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cpu-aarch64-darwin" + + +def test_resolve_noarch_fallback(): + # With no matching arch variant, should fall back to torch noarch. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.8")), + cpu="aarch64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch-cuda" + + +def test_resolve_no_match(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is None + + +RESOLVE_VARIANTS_UNIVERSAL = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu128-x86_64-linux", + "torch-universal", + ] +] + + +def test_resolve_universal_matches_any_backend(): + # Universal works with every backend. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_UNIVERSAL, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch-universal" + + +def test_resolve_universal_is_last_resort(): + # Specific match is preferred over universal. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_UNIVERSAL, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_specific_noarch_preferred_over_universal(): + # Backend-specific noarch is preferred over universal. + variants = [Variant.parse(s) for s in ["torch-universal", "torch-cuda"]] + result = _resolve_variant_for_system( + variants=variants, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is not None + assert result.variant_str == "torch-cuda" + + +RESOLVE_VARIANTS_NO_NOARCH = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-x86_64-linux", + ] +] + + +def test_resolve_cuda_no_newer_minor_no_noarch(): + # No compatible variant for 12.5. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_NO_NOARCH, + selected_backend=CUDA(Version("12.5")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is None + + +def test_resolve_cuda_no_different_major_no_noarch(): + # 11.8 has a different major, so there is no compatible fallback. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_NO_NOARCH, + selected_backend=CUDA(Version("11.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result is None From 38696ec186d431d13ac8ded2d7b35d7f9cbc48da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 12 Mar 2026 13:54:33 +0000 Subject: [PATCH 2/9] Refactor variant handling and add CUDA fallback Build variants were stringly-typed throughout kernels, with custom parsing and serialization sprinkled everywhere. This change adds proper/strong typing to variants adding a `Variant` class. This also centers parsing/serialization in one place and allows code to easily query various parts of of a variant. This also fundamentally changes how we deal with getting variants from the Hub. Rather than casting a wide net with all possible variants and using allow patterns based on that, we query the hub for variants of a kernel, parse them and can decide if there is an applicable variant ahead of time. If there are multiple applicable variants, we can select the best one (e.g. arch before noarch or recent CUDA version before older versions). --- kernels/src/kernels/variants.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 8b198bd9..f9f0389f 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -64,6 +64,8 @@ def parse(s: str) -> "TvmFfi": @dataclass class Arch: + """Aarch kernel information.""" + backend: Backend platform: str os: str @@ -96,6 +98,8 @@ def parse(parts: list[str]) -> "Arch": @dataclass class Noarch: + """Noarch kernel information.""" + backend_name: str @property @@ -109,6 +113,8 @@ def parse(s: str) -> "Noarch": @dataclass class Variant: + """Kernel build variant.""" + framework: Torch | TvmFfi arch: Arch | Noarch From 88ad4e41d0b86df9e9f2805e7fc63d8f42d3f0d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 09:42:59 +0000 Subject: [PATCH 3/9] Switch around cxx11 condition to support tagless build variant --- kernels/src/kernels/variants.py | 4 +++- kernels/tests/test_variants.py | 35 ++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index f9f0389f..61125ee3 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -82,7 +82,9 @@ def variant_str(self) -> str: def parse(parts: list[str]) -> "Arch": # Handle Linux with cxx11 marker. if len(parts) == 4: - cxx11_abi = parts[0] == "cxx11" + # In the future, we want to remove the marker and use cxx11 as + # the default. We check on cxx98 for this reason. + cxx11_abi = parts[0] != "cxx98" parts = parts[1:] elif len(parts) == 3: cxx11_abi = None diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py index faf29699..559883fb 100644 --- a/kernels/tests/test_variants.py +++ b/kernels/tests/test_variants.py @@ -6,6 +6,8 @@ from kernels.variants import Variant, _resolve_variant_for_system, get_variants VARIANT_STRINGS = [ + "torch25-cxx98-cu118-aarch64-linux", + "torch25-cxx98-cpu-x86_64-linux", "torch29-cpu-aarch64-darwin", "torch29-cxx11-cpu-aarch64-linux", "torch29-cxx11-cpu-x86_64-linux", @@ -48,6 +50,37 @@ "torch-universal", ] +SUPERSET_VARIANT_STRINGS = [ + "torch29-cpu-aarch64-darwin", + "torch29-cxx11-cpu-aarch64-linux", + "torch29-cxx11-cpu-x86_64-linux", + "torch29-cxx11-cu126-aarch64-linux", + "torch29-cxx11-cu126-x86_64-linux", + "torch29-cxx11-cu128-aarch64-linux", + "torch29-cxx11-cu128-x86_64-linux", + "torch29-cxx11-cu130-aarch64-linux", + "torch29-cxx11-cu130-x86_64-linux", + "torch29-cxx11-rocm63-x86_64-linux", + "torch29-cxx11-rocm64-x86_64-linux", + "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-metal-aarch64-darwin", + "torch210-cpu-aarch64-darwin", + "torch210-cu128-x86_64-windows", + "torch210-cxx11-cpu-aarch64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cxx11-cu126-aarch64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-aarch64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-aarch64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-rocm71-x86_64-linux", + "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-metal-aarch64-darwin", + "torch210-xpu20253-x86_64-windows", +] + @pytest.mark.parametrize("variant_str", VARIANT_STRINGS) def test_arch_variants(variant_str: str): @@ -66,7 +99,7 @@ def test_get_variants(): variants = get_variants(api, repo_id="kernels-community/relu", revision="v1") variant_strs = {v.variant_str for v in variants} # Superset because new variants may be added in the future. - assert variant_strs.issuperset(VARIANT_STRINGS) + assert variant_strs.issuperset(SUPERSET_VARIANT_STRINGS) RESOLVE_VARIANTS = [ From 81bbbd276ca50d672e7079604c07b4334df8ffc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 10:18:50 +0000 Subject: [PATCH 4/9] Make `kernels versions` more informative --- docs/source/cli-versions.md | 5 +++-- kernels/src/kernels/backends.py | 14 +++++++------- kernels/src/kernels/cli/versions.py | 15 ++++++++++++--- kernels/src/kernels/variants.py | 20 ++++++++++++++------ 4 files changed, 36 insertions(+), 18 deletions(-) diff --git a/docs/source/cli-versions.md b/docs/source/cli-versions.md index fcbcdb49..d8c7b80a 100644 --- a/docs/source/cli-versions.md +++ b/docs/source/cli-versions.md @@ -1,6 +1,7 @@ # kernels versions -Use `kernels versions` to list all available versions of a kernel on the Hub. +Use `kernels versions` to list all available versions of a kernel on the Hub +and marks compatible versions. ## Usage @@ -19,7 +20,7 @@ kernels versions kernels-community/activation ## Example Output ```text -Version 1: torch210-cu128-x86_64-windows, torch210-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch210-metal-aarch64-darwin ✅, torch27-cxx11-cu118-x86_64-linux, torch27-cxx11-cu126-x86_64-linux, torch27-cxx11-cu128-aarch64-linux, torch27-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu126-x86_64-linux, torch28-cxx11-cu128-aarch64-linux, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu129-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, torch29-cxx11-cu126-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu128-x86_64-linux, torch29-cxx11-cu130-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch29-metal-aarch64-darwin +Version 1: torch210-metal-aarch64-darwin, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch27-cxx11-cu118-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-aarch64-linux, torch27-cxx11-cu126-x86_64-linux, ✅ torch29-cxx11-cu126-x86_64-linux (compatible), torch27-cxx11-cu128-x86_64-linux, torch210-cxx11-cu126-x86_64-linux, torch29-metal-aarch64-darwin, torch27-cxx11-cu128-aarch64-linux, torch210-cu128-x86_64-windows, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, ✅ torch29-cxx11-cu128-x86_64-linux (preferred), torch28-cxx11-cu129-x86_64-linux ``` ## Use Cases diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index f678f7dc..22432781 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -32,7 +32,7 @@ def variant_str(self) -> str: ... -@dataclass +@dataclass(unsafe_hash=True) class CANN: version: Version @@ -52,7 +52,7 @@ def parse(s: str) -> "CANN": return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) -@dataclass +@dataclass(unsafe_hash=True) class CPU: @property def name(self) -> str: @@ -69,7 +69,7 @@ def parse(s: str) -> "CPU": return CPU() -@dataclass +@dataclass(unsafe_hash=True) class CUDA: version: Version @@ -89,7 +89,7 @@ def parse(s: str) -> "CUDA": return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) -@dataclass +@dataclass(unsafe_hash=True) class Metal: @property def name(self) -> str: @@ -106,7 +106,7 @@ def parse(s: str) -> "Metal": return Metal() -@dataclass +@dataclass(unsafe_hash=True) class Neuron: @property def name(self) -> str: @@ -123,7 +123,7 @@ def parse(s: str) -> "Neuron": return Neuron() -@dataclass +@dataclass(unsafe_hash=True) class ROCm: version: Version @@ -143,7 +143,7 @@ def parse(s: str) -> "ROCm": return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) -@dataclass +@dataclass(unsafe_hash=True) class XPU: version: Version diff --git a/kernels/src/kernels/cli/versions.py b/kernels/src/kernels/cli/versions.py index 80bc643c..be9d716d 100644 --- a/kernels/src/kernels/cli/versions.py +++ b/kernels/src/kernels/cli/versions.py @@ -1,6 +1,9 @@ from kernels._versions import _get_available_versions from kernels.utils import _get_hf_api -from kernels.variants import get_variants, resolve_variant +from kernels.variants import ( + get_variants, + resolve_variants, +) def print_kernel_versions(repo_id: str): @@ -13,10 +16,16 @@ def print_kernel_versions(repo_id: str): for version, ref in sorted(versions, key=lambda x: x[0]): variants = get_variants(api, repo_id=repo_id, revision=ref.ref) - best = resolve_variant(variants) + resolved = resolve_variants(variants, None) + best = resolved[0] if resolved else None + resolved = set(resolved) print(f"Version {version}: ", end="") variant_strs = [ - f"{variant.variant_str} ✅" if variant == best else f"{variant.variant_str}" + ( + f"✅ {variant.variant_str} ({'compatible, preferred' if variant == best else 'compatible'})" + if variant in resolved + else f"{variant.variant_str}" + ) for variant in variants ] print(", ".join(variant_strs)) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 61125ee3..ee05e149 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -62,7 +62,7 @@ def parse(s: str) -> "TvmFfi": return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) -@dataclass +@dataclass(unsafe_hash=True) class Arch: """Aarch kernel information.""" @@ -98,7 +98,7 @@ def parse(parts: list[str]) -> "Arch": return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) -@dataclass +@dataclass(unsafe_hash=True) class Noarch: """Noarch kernel information.""" @@ -113,7 +113,7 @@ def parse(s: str) -> "Noarch": return Noarch(backend_name=s) -@dataclass +@dataclass(unsafe_hash=True) class Variant: """Kernel build variant.""" @@ -188,6 +188,15 @@ def resolve_variant( variants: list[Variant], backend: str | None = None ) -> Variant | None: """Return the best matching variant for the current system.""" + resolved = resolve_variants(variants, backend) + return resolved[0] if resolved else None + + +def resolve_variants( + variants: list[Variant], backend: str | None = None +) -> list[Variant]: + """Return the matching variants for the current system, sorted + by decreasing order of preference.""" selected_backend = _select_backend(backend) cpu = platform.machine() @@ -236,7 +245,7 @@ def _resolve_variant_for_system( torch_version: Version | None, torch_cxx11_abi: bool | None, tvm_ffi_version: Version | None, -) -> Variant | None: +) -> list[Variant]: """Resolve the best matching variant given explicit system parameters.""" applicable = _filter_variants( variants, @@ -247,8 +256,7 @@ def _resolve_variant_for_system( torch_cxx11_abi, tvm_ffi_version, ) - sorted_variants = _sort_variants(applicable) - return sorted_variants[0] if sorted_variants else None + return _sort_variants(applicable) def _filter_variants( From 0a4b48e11dcc91edeffcffa82af7c4d681b813ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 10:44:16 +0000 Subject: [PATCH 5/9] Move backend/variant regexes to their classes --- kernels/src/kernels/backends.py | 23 +++++++++++++---------- kernels/src/kernels/variants.py | 12 +++++++----- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index 22432781..cc14ba5e 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -3,17 +3,12 @@ import re import warnings from dataclasses import dataclass -from typing import Optional, Protocol +from typing import ClassVar, Optional, Protocol from packaging.version import Version from kernels.compat import has_torch -_CUDA_VARIANT_REGEX = re.compile(r"cu(\d+)(\d+)") -_ROCM_VARIANT_REGEX = re.compile(r"rocm(\d+)(\d+)") -_XPU_VARIANT_REGEX = re.compile(r"xpu(\d+)(\d+)") -_CANN_VARIANT_REGEX = re.compile(r"cann(\d+)(\d+)") - class Backend(Protocol): @property @@ -34,6 +29,8 @@ def variant_str(self) -> str: @dataclass(unsafe_hash=True) class CANN: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)") + version: Version @property @@ -46,7 +43,7 @@ def variant_str(self) -> str: @staticmethod def parse(s: str) -> "CANN": - m = _CANN_VARIANT_REGEX.fullmatch(s) + m = CANN._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid CANN variant string: {s!r}") return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) @@ -71,6 +68,8 @@ def parse(s: str) -> "CPU": @dataclass(unsafe_hash=True) class CUDA: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)") + version: Version @property @@ -83,7 +82,7 @@ def variant_str(self) -> str: @staticmethod def parse(s: str) -> "CUDA": - m = _CUDA_VARIANT_REGEX.fullmatch(s) + m = CUDA._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid CUDA variant string: {s!r}") return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) @@ -125,6 +124,8 @@ def parse(s: str) -> "Neuron": @dataclass(unsafe_hash=True) class ROCm: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)") + version: Version @property @@ -137,7 +138,7 @@ def variant_str(self) -> str: @staticmethod def parse(s: str) -> "ROCm": - m = _ROCM_VARIANT_REGEX.fullmatch(s) + m = ROCm._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid ROCm variant string: {s!r}") return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) @@ -145,6 +146,8 @@ def parse(s: str) -> "ROCm": @dataclass(unsafe_hash=True) class XPU: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)") + version: Version @property @@ -157,7 +160,7 @@ def variant_str(self) -> str: @staticmethod def parse(s: str) -> "XPU": - m = _XPU_VARIANT_REGEX.fullmatch(s) + m = XPU._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid XPU variant string: {s!r}") return XPU(version=Version(f"{m.group(1)}.{m.group(2)}")) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index ee05e149..32f76ecc 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -2,6 +2,7 @@ import re from dataclasses import dataclass from pathlib import Path +from typing import ClassVar from huggingface_hub import HfApi from huggingface_hub.hf_api import RepoFolder @@ -22,12 +23,11 @@ r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)" ) -_TORCH_VARIANT_REGEX = re.compile(r"torch(\d+?)(\d+)") -_TVM_FFI_VARIANT_REGEX = re.compile(r"tvm-ffi(\d+?)(\d+)") - @dataclass(unsafe_hash=True) class Torch: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") + version: Version | None @property @@ -40,7 +40,7 @@ def variant_str(self) -> str: def parse(s: str) -> "Torch": if s == "torch": return Torch(version=None) - m = _TORCH_VARIANT_REGEX.fullmatch(s) + m = Torch._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid Torch variant string: {s!r}") return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) @@ -48,6 +48,8 @@ def parse(s: str) -> "Torch": @dataclass(unsafe_hash=True) class TvmFfi: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") + version: Version @property @@ -56,7 +58,7 @@ def variant_str(self) -> str: @staticmethod def parse(s: str) -> "TvmFfi": - m = _TVM_FFI_VARIANT_REGEX.fullmatch(s) + m = TvmFfi._VARIANT_REGEX.fullmatch(s) if not m: raise ValueError(f"Invalid TvmFfi variant string: {s!r}") return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) From 202712c49d912d6e24b17940c7e6836d97ccfa7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 10:53:29 +0000 Subject: [PATCH 6/9] Type fixes --- kernels/src/kernels/cli/versions.py | 4 +-- kernels/tests/test_variants.py | 54 ++++++++++++++--------------- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/kernels/src/kernels/cli/versions.py b/kernels/src/kernels/cli/versions.py index be9d716d..daa4656c 100644 --- a/kernels/src/kernels/cli/versions.py +++ b/kernels/src/kernels/cli/versions.py @@ -18,12 +18,12 @@ def print_kernel_versions(repo_id: str): variants = get_variants(api, repo_id=repo_id, revision=ref.ref) resolved = resolve_variants(variants, None) best = resolved[0] if resolved else None - resolved = set(resolved) + resolved_set = set(resolved) print(f"Version {version}: ", end="") variant_strs = [ ( f"✅ {variant.variant_str} ({'compatible, preferred' if variant == best else 'compatible'})" - if variant in resolved + if variant in resolved_set else f"{variant.variant_str}" ) for variant in variants diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py index 559883fb..47a3bdfe 100644 --- a/kernels/tests/test_variants.py +++ b/kernels/tests/test_variants.py @@ -129,8 +129,8 @@ def test_resolve_cuda_exact(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" def test_resolve_cuda_best_older_minor(): @@ -144,8 +144,8 @@ def test_resolve_cuda_best_older_minor(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" def test_resolve_cuda_no_newer_minor(): @@ -159,8 +159,8 @@ def test_resolve_cuda_no_newer_minor(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch-cuda" + assert result != [] + assert result[0].variant_str == "torch-cuda" def test_resolve_cuda_no_different_major(): @@ -174,8 +174,8 @@ def test_resolve_cuda_no_different_major(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch-cuda" + assert result != [] + assert result[0].variant_str == "torch-cuda" def test_resolve_rocm(): @@ -188,8 +188,8 @@ def test_resolve_rocm(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cxx11-rocm70-x86_64-linux" + assert result != [] + assert result[0].variant_str == "torch210-cxx11-rocm70-x86_64-linux" def test_resolve_cpu_linux(): @@ -202,8 +202,8 @@ def test_resolve_cpu_linux(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cxx11-cpu-x86_64-linux" + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cpu-x86_64-linux" def test_resolve_cpu_darwin(): @@ -216,8 +216,8 @@ def test_resolve_cpu_darwin(): torch_cxx11_abi=None, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cpu-aarch64-darwin" + assert result != [] + assert result[0].variant_str == "torch210-cpu-aarch64-darwin" def test_resolve_metal_darwin(): @@ -230,8 +230,8 @@ def test_resolve_metal_darwin(): torch_cxx11_abi=None, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cpu-aarch64-darwin" + assert result != [] + assert result[0].variant_str == "torch210-cpu-aarch64-darwin" def test_resolve_noarch_fallback(): @@ -245,8 +245,8 @@ def test_resolve_noarch_fallback(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch-cuda" + assert result != [] + assert result[0].variant_str == "torch-cuda" def test_resolve_no_match(): @@ -259,7 +259,7 @@ def test_resolve_no_match(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is None + assert result == [] RESOLVE_VARIANTS_UNIVERSAL = [ @@ -282,8 +282,8 @@ def test_resolve_universal_matches_any_backend(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch-universal" + assert result != [] + assert result[0].variant_str == "torch-universal" def test_resolve_universal_is_last_resort(): @@ -297,8 +297,8 @@ def test_resolve_universal_is_last_resort(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch210-cxx11-cu128-x86_64-linux" + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" def test_resolve_specific_noarch_preferred_over_universal(): @@ -313,8 +313,8 @@ def test_resolve_specific_noarch_preferred_over_universal(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is not None - assert result.variant_str == "torch-cuda" + assert result != [] + assert result[0].variant_str == "torch-cuda" RESOLVE_VARIANTS_NO_NOARCH = [ @@ -338,7 +338,7 @@ def test_resolve_cuda_no_newer_minor_no_noarch(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is None + assert result == [] def test_resolve_cuda_no_different_major_no_noarch(): @@ -352,4 +352,4 @@ def test_resolve_cuda_no_different_major_no_noarch(): torch_cxx11_abi=True, tvm_ffi_version=None, ) - assert result is None + assert result == [] From 5be720f40fad66c64266382882ce509974dda914 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 10:58:47 +0000 Subject: [PATCH 7/9] Improve error handling --- kernels/src/kernels/variants.py | 16 +++++++++------- kernels/tests/test_basic.py | 1 - 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 32f76ecc..92de7911 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -1,3 +1,4 @@ +import logging import platform import re from dataclasses import dataclass @@ -152,19 +153,20 @@ def parse(variant_str: str) -> "Variant": def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: """Get all the build variants available from a kernel repository.""" - try: - tree = api.list_repo_tree(repo_id, path_in_repo="build", revision=revision) - variant_strs = { - item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) - } - except Exception: - return [] + tree = api.list_repo_tree(repo_id, path_in_repo="build", revision=revision) + variant_strs = { + item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) + } variants = [] for variant_str in variant_strs: try: variants.append(Variant.parse(variant_str)) except ValueError: + logging.warning( + f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}" + ) + log pass return variants diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index d82b4cbe..7486134c 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -120,7 +120,6 @@ def test_relu_metal(metal_kernel, dtype): # Repo only contains Torch 2.4 kernels (and we don't # support/test against this version). ("kernels-test/only-torch-2.4", "main", False), - ("google-bert/bert-base-uncased", "87565a309", False), ("kernels-test/flattened-build", "main", True), ("kernels-test/flattened-build", "without-compat-module", True), ], From 6bcf4e08c063d782b34f82e76d0a9cf1640cd489 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 11:37:45 +0000 Subject: [PATCH 8/9] Add tvm-ffi variant strings for testing --- kernels/tests/test_variants.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py index 47a3bdfe..b0d17616 100644 --- a/kernels/tests/test_variants.py +++ b/kernels/tests/test_variants.py @@ -36,6 +36,12 @@ "torch210-cxx11-xpu20253-x86_64-linux", "torch210-metal-aarch64-darwin", "torch210-xpu20253-x86_64-windows", + "tvm-ffi01-cpu-x86_64-linux", + "tvm-ffi01-cu126-x86_64-linux", + "tvm-ffi01-cu128-x86_64-linux", + "tvm-ffi01-cu130-x86_64-linux", + "tvm-ffi01-metal-aarch64-darwin", + "tvm-ffi01-xpu20253-x86_64-linux", ] From d4cfe1ff4ccfdd0191f6260a8764053da5f46da6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Fri, 13 Mar 2026 11:39:50 +0000 Subject: [PATCH 9/9] Formatting --- kernels/src/kernels/variants.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 92de7911..c40d6e69 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -166,8 +166,6 @@ def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: logging.warning( f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}" ) - log - pass return variants