diff --git a/docs/source/cli-versions.md b/docs/source/cli-versions.md index fcbcdb49..d8c7b80a 100644 --- a/docs/source/cli-versions.md +++ b/docs/source/cli-versions.md @@ -1,6 +1,7 @@ # kernels versions -Use `kernels versions` to list all available versions of a kernel on the Hub. +Use `kernels versions` to list all available versions of a kernel on the Hub +and marks compatible versions. ## Usage @@ -19,7 +20,7 @@ kernels versions kernels-community/activation ## Example Output ```text -Version 1: torch210-cu128-x86_64-windows, torch210-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch210-metal-aarch64-darwin ✅, torch27-cxx11-cu118-x86_64-linux, torch27-cxx11-cu126-x86_64-linux, torch27-cxx11-cu128-aarch64-linux, torch27-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu126-x86_64-linux, torch28-cxx11-cu128-aarch64-linux, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu129-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, torch29-cxx11-cu126-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu128-x86_64-linux, torch29-cxx11-cu130-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch29-metal-aarch64-darwin +Version 1: torch210-metal-aarch64-darwin, torch28-cxx11-cu126-aarch64-linux, torch28-cxx11-cu129-aarch64-linux, torch28-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-x86_64-linux, torch27-cxx11-cu118-x86_64-linux, torch210-cxx11-cu130-x86_64-linux, torch29-cxx11-cu128-aarch64-linux, torch29-cxx11-cu130-aarch64-linux, torch27-cxx11-cu126-x86_64-linux, ✅ torch29-cxx11-cu126-x86_64-linux (compatible), torch27-cxx11-cu128-x86_64-linux, torch210-cxx11-cu126-x86_64-linux, torch29-metal-aarch64-darwin, torch27-cxx11-cu128-aarch64-linux, torch210-cu128-x86_64-windows, torch28-cxx11-cu128-x86_64-linux, torch28-cxx11-cu126-x86_64-linux, torch210-cxx11-cu128-x86_64-linux, torch29-cxx11-cu126-aarch64-linux, ✅ torch29-cxx11-cu128-x86_64-linux (preferred), torch28-cxx11-cu129-x86_64-linux ``` ## Use Cases diff --git a/kernels/src/kernels/backends.py b/kernels/src/kernels/backends.py index 51b322aa..cc14ba5e 100644 --- a/kernels/src/kernels/backends.py +++ b/kernels/src/kernels/backends.py @@ -1,8 +1,9 @@ import ctypes import ctypes.util +import re import warnings from dataclasses import dataclass -from typing import Optional, Protocol +from typing import ClassVar, Optional, Protocol from packaging.version import Version @@ -18,7 +19,7 @@ def name(self) -> str: ... @property - def variant(self) -> str: + def variant_str(self) -> str: """ The name of the backend as used in a build variant, e.g. `cu128` for CUDA 12.8. @@ -26,8 +27,10 @@ def variant(self) -> str: ... -@dataclass +@dataclass(unsafe_hash=True) class CANN: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cann(\d+)(\d+)") + version: Version @property @@ -35,23 +38,38 @@ def name(self) -> str: return "cann" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"cann{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "CANN": + m = CANN._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid CANN variant string: {s!r}") + return CANN(version=Version(f"{m.group(1)}.{m.group(2)}")) + -@dataclass +@dataclass(unsafe_hash=True) class CPU: @property def name(self) -> str: return "cpu" @property - def variant(self) -> str: + def variant_str(self) -> str: return "cpu" + @staticmethod + def parse(s: str) -> "CPU": + if s != "cpu": + raise ValueError(f"Invalid CPU variant string: {s!r}") + return CPU() + -@dataclass +@dataclass(unsafe_hash=True) class CUDA: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"cu(\d+)(\d+)") + version: Version @property @@ -59,34 +77,55 @@ def name(self) -> str: return "cuda" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"cu{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "CUDA": + m = CUDA._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid CUDA variant string: {s!r}") + return CUDA(version=Version(f"{m.group(1)}.{m.group(2)}")) + -@dataclass +@dataclass(unsafe_hash=True) class Metal: @property def name(self) -> str: return "metal" @property - def variant(self) -> str: + def variant_str(self) -> str: return "metal" + @staticmethod + def parse(s: str) -> "Metal": + if s != "metal": + raise ValueError(f"Invalid Metal variant string: {s!r}") + return Metal() + -@dataclass +@dataclass(unsafe_hash=True) class Neuron: @property def name(self) -> str: return "neuron" @property - def variant(self) -> str: + def variant_str(self) -> str: return "neuron" + @staticmethod + def parse(s: str) -> "Neuron": + if s != "neuron": + raise ValueError(f"Invalid Neuron variant string: {s!r}") + return Neuron() -@dataclass + +@dataclass(unsafe_hash=True) class ROCm: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"rocm(\d+)(\d+)") + version: Version @property @@ -94,12 +133,21 @@ def name(self) -> str: return "rocm" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"rocm{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "ROCm": + m = ROCm._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid ROCm variant string: {s!r}") + return ROCm(version=Version(f"{m.group(1)}.{m.group(2)}")) + -@dataclass +@dataclass(unsafe_hash=True) class XPU: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"xpu(\d+)(\d+)") + version: Version @property @@ -107,9 +155,36 @@ def name(self) -> str: return "xpu" @property - def variant(self) -> str: + def variant_str(self) -> str: return f"xpu{self.version.major}{self.version.minor}" + @staticmethod + def parse(s: str) -> "XPU": + m = XPU._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid XPU variant string: {s!r}") + return XPU(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +def parse_backend(s: str) -> Backend: + """Parse a backend variant string (e.g. 'cu128', 'rocm61', 'cpu') into a Backend.""" + if s == "cpu": + return CPU.parse(s) + elif s == "metal": + return Metal.parse(s) + elif s == "neuron": + return Neuron.parse(s) + elif s.startswith("cu"): + return CUDA.parse(s) + elif s.startswith("rocm"): + return ROCm.parse(s) + elif s.startswith("xpu"): + return XPU.parse(s) + elif s.startswith("cann"): + return CANN.parse(s) + else: + raise ValueError(f"Unknown backend variant string: {s!r}") + def _backend() -> Backend: if has_torch: diff --git a/kernels/src/kernels/cli/versions.py b/kernels/src/kernels/cli/versions.py index a752d789..daa4656c 100644 --- a/kernels/src/kernels/cli/versions.py +++ b/kernels/src/kernels/cli/versions.py @@ -1,44 +1,31 @@ -from importlib.util import find_spec -from pathlib import Path - -from huggingface_hub import HfApi - from kernels._versions import _get_available_versions -from kernels.utils import _build_variants, _get_hf_api -from kernels.variants import BUILD_VARIANT_REGEX +from kernels.utils import _get_hf_api +from kernels.variants import ( + get_variants, + resolve_variants, +) def print_kernel_versions(repo_id: str): api = _get_hf_api() - if find_spec("torch") is None: - # Do not mark compatible variants when Torch is not available. - compatible_variants = set() - else: - compatible_variants = set(_build_variants(None)) - versions = _get_available_versions(repo_id).items() if not versions: print(f"Repository does not support kernel versions: {repo_id}") return for version, ref in sorted(versions, key=lambda x: x[0]): + variants = get_variants(api, repo_id=repo_id, revision=ref.ref) + resolved = resolve_variants(variants, None) + best = resolved[0] if resolved else None + resolved_set = set(resolved) print(f"Version {version}: ", end="") - variants = [ - f"{variant} ✅" if variant in compatible_variants else f"{variant}" - for variant in _get_build_variants(api, repo_id, ref.ref) + variant_strs = [ + ( + f"✅ {variant.variant_str} ({'compatible, preferred' if variant == best else 'compatible'})" + if variant in resolved_set + else f"{variant.variant_str}" + ) + for variant in variants ] - print(", ".join(variants)) - - -def _get_build_variants(api: "HfApi", repo_id: str, revision: str) -> list[str]: - variants = set() - for filename in api.list_repo_files(repo_id, revision=revision): - path = Path(filename) - if len(path.parts) < 2 or path.parts[0] != "build": - continue - - match = BUILD_VARIANT_REGEX.match(path.parts[1]) - if match: - variants.add(path.parts[1]) - return sorted(variants) + print(", ".join(variant_strs)) diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 56c166f0..729510a7 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -17,13 +17,18 @@ from kernels._system import glibc_version from kernels._versions import select_revision_or_version -from kernels.backends import _backend +from kernels.backends import _backend, _select_backend from kernels.compat import has_torch, has_tvm_ffi from kernels.deps import validate_dependencies from kernels.lockfile import KernelLock, VariantLock from kernels.metadata import Metadata from kernels.status import resolve_status -from kernels.variants import _build_variants +from kernels.variants import ( + Variant, + get_variants, + get_variants_local, + resolve_variant, +) KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} @@ -128,7 +133,16 @@ def install_kernel( repo_id, revision = resolve_status(api, repo_id, revision) package_name = package_name_from_repo_id(repo_id) - allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)] + + variants = get_variants(api, repo_id=repo_id, revision=revision) + variant = resolve_variant(variants, backend) + + if variant is None: + raise FileNotFoundError( + f"Cannot find a build variant for this system in {repo_id} (revision: {revision}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + ) + + allow_patterns = [f"build/{variant.variant_str}/*"] repo_path = Path( str( api.snapshot_download( @@ -143,7 +157,10 @@ def install_kernel( try: return _find_kernel_in_repo_path( - repo_path, package_name, backend=backend, variant_locks=variant_locks + repo_path, + package_name, + variant=variant, + variant_locks=variant_locks, ) except FileNotFoundError: raise FileNotFoundError( @@ -155,30 +172,21 @@ def _find_kernel_in_repo_path( repo_path: Path, package_name: str, *, - backend: str | None = None, + variant: Variant, variant_locks: dict[str, VariantLock] | None = None, ) -> tuple[str, Path]: - variants = _build_variants(backend) - variant = None - variant_path = None - for candidate_variant in variants: - variant_path = repo_path / "build" / candidate_variant - if variant_path.exists(): - variant = candidate_variant - break - - if variant is None: - raise FileNotFoundError( - f"Kernel at path `{repo_path}` does not have one of build variants: {', '.join(variants)}" - ) - - assert variant_path is not None + variant_str = variant.variant_str + variant_path = repo_path / "build" / variant_str + if not variant_path.exists(): + raise FileNotFoundError(f"Variant path does not exist: `{variant_path}`") if variant_locks is not None: - variant_lock = variant_locks.get(variant) + variant_lock = variant_locks.get(variant_str) if variant_lock is None: raise ValueError(f"No lock found for build variant: {variant}") - validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash) + validate_kernel( + repo_path=repo_path, variant=variant_str, hash=variant_lock.hash + ) module_init_path = variant_path / "__init__.py" if not os.path.exists(module_init_path): @@ -297,12 +305,12 @@ def get_local_kernel( Returns: `ModuleType`: The imported kernel module. """ - # Presume we were given the top level path of the kernel repository. for base_path in [repo_path, repo_path / "build"]: - for v in _build_variants(backend): - variant_path = base_path / v - if variant_path.exists(): - return _import_from_path(package_name, variant_path) + variants = get_variants_local(base_path) + variant = resolve_variant(variants, backend) + + if variant is not None: + return _import_from_path(package_name, base_path / variant.variant_str) # If we didn't find the package in the repo we may have a explicit # package path. @@ -341,12 +349,19 @@ def has_kernel( package_name = package_name_from_repo_id(repo_id) api = _get_hf_api() - for variant in _build_variants(backend): - for init_file in ["__init__.py", f"{package_name}/__init__.py"]: - if api.file_exists( - repo_id, revision=revision, filename=f"build/{variant}/{init_file}" - ): - return True + variants = get_variants(api, repo_id=repo_id, revision=revision) + variant = resolve_variant(variants, backend) + + if variant is None: + return False + + for init_file in ["__init__.py", f"{package_name}/__init__.py"]: + if api.file_exists( + repo_id, + revision=revision, + filename=f"build/{variant.variant_str}/{init_file}", + ): + return True return False @@ -388,7 +403,15 @@ def load_kernel( package_name = package_name_from_repo_id(repo_id) api = _get_hf_api() - allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)] + variants = get_variants(api, repo_id=repo_id, revision=locked_sha) + variant = resolve_variant(variants, backend) + + if variant is None: + raise FileNotFoundError( + f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}). Available variants: {', '.join([variant.variant_str for variant in variants])}" + ) + + allow_patterns = [f"build/{variant.variant_str}/*"] repo_path = Path( str( api.snapshot_download( @@ -403,7 +426,10 @@ def load_kernel( try: package_name, variant_path = _find_kernel_in_repo_path( - repo_path, package_name, backend=backend, variant_locks=None + repo_path, + package_name, + variant=variant, + variant_locks=None, ) return _import_from_path(package_name, variant_path) except FileNotFoundError: @@ -538,6 +564,18 @@ def package_name_from_repo_id(repo_id: str) -> str: return repo_id.split("/")[-1].replace("-", "_") +def _platform() -> str: + cpu = platform.machine() + os = platform.system().lower() + + if os == "darwin": + cpu = "aarch64" if cpu == "arm64" else cpu + elif os == "windows": + cpu = "x86_64" if cpu == "AMD64" else cpu + + return f"{cpu}-{os}" + + def _get_hf_api(user_agent: str | dict | None = None) -> HfApi: """Returns an instance of HfApi with proper settings.""" @@ -553,8 +591,8 @@ def _get_hf_api(user_agent: str | dict | None = None) -> HfApi: # System info python = ".".join(platform.python_version_tuple()[:2]) - variants = ":".join(_build_variants(None)) - user_agent_str += f"; kernels/{__version__}; python/{python}; build_variant/{variants}; file_type/kernel" + backend = _select_backend(None).variant_str + user_agent_str += f"; kernels/{__version__}; python/{python}; backend/{backend}; flatform/{_platform()}; file_type/kernel" if has_torch: import torch diff --git a/kernels/src/kernels/variants.py b/kernels/src/kernels/variants.py index 33f6ee76..c40d6e69 100644 --- a/kernels/src/kernels/variants.py +++ b/kernels/src/kernels/variants.py @@ -1,9 +1,23 @@ +import logging import platform import re +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar -from packaging.version import parse +from huggingface_hub import HfApi +from huggingface_hub.hf_api import RepoFolder +from packaging.version import Version, parse -from kernels.backends import _select_backend +from kernels.backends import ( + CANN, + CUDA, + XPU, + Backend, + ROCm, + _select_backend, + parse_backend, +) from kernels.compat import has_torch, has_tvm_ffi BUILD_VARIANT_REGEX = re.compile( @@ -11,75 +25,318 @@ ) -def _torch_build_variant(backend: str | None) -> list[str]: - if not has_torch: +@dataclass(unsafe_hash=True) +class Torch: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"torch(\d+?)(\d+)") + + version: Version | None + + @property + def variant_str(self) -> str: + if self.version is None: + return "torch" + return f"torch{self.version.major}{self.version.minor}" + + @staticmethod + def parse(s: str) -> "Torch": + if s == "torch": + return Torch(version=None) + m = Torch._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid Torch variant string: {s!r}") + return Torch(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +@dataclass(unsafe_hash=True) +class TvmFfi: + _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile(r"tvm-ffi(\d+?)(\d+)") + + version: Version + + @property + def variant_str(self) -> str: + return f"tvm-ffi{self.version.major}{self.version.minor}" + + @staticmethod + def parse(s: str) -> "TvmFfi": + m = TvmFfi._VARIANT_REGEX.fullmatch(s) + if not m: + raise ValueError(f"Invalid TvmFfi variant string: {s!r}") + return TvmFfi(version=Version(f"{m.group(1)}.{m.group(2)}")) + + +@dataclass(unsafe_hash=True) +class Arch: + """Aarch kernel information.""" + + backend: Backend + platform: str + os: str + cxx11_abi: bool | None + + @property + def variant_str(self) -> str: + if self.cxx11_abi is None: + return f"{self.backend.variant_str}-{self.platform}-{self.os}" + else: + return f"{'cxx11' if self.cxx11_abi else 'cxx98'}-{self.backend.variant_str}-{self.platform}-{self.os}" + + @staticmethod + def parse(parts: list[str]) -> "Arch": + # Handle Linux with cxx11 marker. + if len(parts) == 4: + # In the future, we want to remove the marker and use cxx11 as + # the default. We check on cxx98 for this reason. + cxx11_abi = parts[0] != "cxx98" + parts = parts[1:] + elif len(parts) == 3: + cxx11_abi = None + else: + raise ValueError(f"Invalid arch variant parts: {parts!r}") + + backend = parse_backend(parts[0]) + platform = parts[1] + os = parts[2] + + return Arch(backend=backend, platform=platform, os=os, cxx11_abi=cxx11_abi) + + +@dataclass(unsafe_hash=True) +class Noarch: + """Noarch kernel information.""" + + backend_name: str + + @property + def variant_str(self) -> str: + return self.backend_name + + @staticmethod + def parse(s: str) -> "Noarch": + return Noarch(backend_name=s) + + +@dataclass(unsafe_hash=True) +class Variant: + """Kernel build variant.""" + + framework: Torch | TvmFfi + arch: Arch | Noarch + + @property + def variant_str(self) -> str: + return f"{self.framework.variant_str}-{self.arch.variant_str}" + + @staticmethod + def parse(variant_str: str) -> "Variant": + parts = variant_str.split("-") + + arch: Arch | Noarch + framework: Torch | TvmFfi + + if parts[0] == "torch": + # noarch: e.g. "torch-cpu" + framework = Torch.parse(parts[0]) + arch = Noarch.parse("-".join(parts[1:])) + elif parts[0].startswith("torch"): + framework = Torch.parse(parts[0]) + arch = Arch.parse(parts[1:]) + elif parts[0] == "tvm" and parts[1].startswith("ffi"): + framework = TvmFfi.parse(f"tvm-{parts[1]}") + arch = Arch.parse(parts[2:]) + else: + raise ValueError(f"Unknown framework in variant string: {variant_str!r}") + + return Variant(framework=framework, arch=arch) + + +def get_variants(api: HfApi, *, repo_id: str, revision: str) -> list[Variant]: + """Get all the build variants available from a kernel repository.""" + + tree = api.list_repo_tree(repo_id, path_in_repo="build", revision=revision) + variant_strs = { + item.path.split("/")[-1] for item in tree if isinstance(item, RepoFolder) + } + + variants = [] + for variant_str in variant_strs: + try: + variants.append(Variant.parse(variant_str)) + except ValueError: + logging.warning( + f"Repository {repo_id} (revision: {revision}) contains invalid build variant variant: {variant_str!r}" + ) + return variants + + +def get_variants_local(repo_path: Path) -> list[Variant]: + """Get all the build variants available in a local directory.""" + + try: + variant_strs = {entry.name for entry in repo_path.iterdir() if entry.is_dir()} + except Exception: return [] - selected_backend = _select_backend(backend) + variants = [] + for variant_str in variant_strs: + try: + variants.append(Variant.parse(variant_str)) + except ValueError: + pass + return variants - backend_variant = selected_backend.variant - import torch +def resolve_variant( + variants: list[Variant], backend: str | None = None +) -> Variant | None: + """Return the best matching variant for the current system.""" + resolved = resolve_variants(variants, backend) + return resolved[0] if resolved else None + + +def resolve_variants( + variants: list[Variant], backend: str | None = None +) -> list[Variant]: + """Return the matching variants for the current system, sorted + by decreasing order of preference.""" + selected_backend = _select_backend(backend) - torch_version = parse(torch.__version__) cpu = platform.machine() os = platform.system().lower() if os == "darwin": cpu = "aarch64" if cpu == "arm64" else cpu - return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" - ] elif os == "windows": cpu = "x86_64" if cpu == "AMD64" else cpu - return [ - f"torch{torch_version.major}{torch_version.minor}-{backend_variant}-{cpu}-{os}" - ] - cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" - return [ - f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{backend_variant}-{cpu}-{os}" - ] + torch_version = None + torch_cxx11_abi = None + if has_torch: + import torch + # Parse Torch version and strip patch/tags. + torch_version = parse(torch.__version__) + torch_version = Version(f"{torch_version.major}.{torch_version.minor}") -def _tvm_ffi_build_variant(backend: str | None) -> list[str]: - if not has_tvm_ffi: - return [] + torch_cxx11_abi = torch.compiled_with_cxx11_abi() if os == "linux" else None - selected_backend = _select_backend(backend) + tvm_ffi_version = None + if has_tvm_ffi: + import tvm_ffi - backend_variant = selected_backend.variant + # Parse Torch version and strip patch/tags. + tvm_ffi_version = parse(tvm_ffi.__version__) + tvm_ffi_version = Version(f"{tvm_ffi_version.major}.{tvm_ffi_version.minor}") - import tvm_ffi + return _resolve_variant_for_system( + variants=variants, + selected_backend=selected_backend, + cpu=cpu, + os=os, + torch_version=torch_version, + torch_cxx11_abi=torch_cxx11_abi, + tvm_ffi_version=tvm_ffi_version, + ) - tvm_ffi_version = parse(tvm_ffi.__version__) - cpu = platform.machine() - os = platform.system().lower() - return [ - f"tvm-ffi{tvm_ffi_version.major}{tvm_ffi_version.minor}-{backend_variant}-{cpu}-{os}" - ] +def _resolve_variant_for_system( + variants: list[Variant], + selected_backend: Backend, + cpu: str, + os: str, + torch_version: Version | None, + torch_cxx11_abi: bool | None, + tvm_ffi_version: Version | None, +) -> list[Variant]: + """Resolve the best matching variant given explicit system parameters.""" + applicable = _filter_variants( + variants, + selected_backend, + cpu, + os, + torch_version, + torch_cxx11_abi, + tvm_ffi_version, + ) + return _sort_variants(applicable) -def _build_variant_noarch(backend: str | None) -> list[str]: - selected_backend = _select_backend(backend) +def _filter_variants( + variants: list[Variant], + selected_backend: Backend, + cpu: str, + os: str, + torch_version: Version | None, + torch_cxx11_abi: bool | None, + tvm_ffi_version: Version | None, +) -> list[Variant]: + """Return only the variants applicable to the current system.""" + result = [] + for v in variants: + if isinstance(v.arch, Arch): + # Skip non-matching CPU or OS. + if v.arch.platform != cpu or v.arch.os != os: + continue + # If the variant is a Torch or tvm-ffi variant, check that it has the + # correct version and ABI. + if isinstance(v.framework, Torch): + if v.framework.version != torch_version: + continue + if v.arch.cxx11_abi != torch_cxx11_abi: + continue + elif isinstance(v.framework, TvmFfi): + if v.framework.version != tvm_ffi_version: + continue + # Given a system CUDA version of x.y, only CUDA versions x.z, + # where z <= y qualify. Otherwise, the backend + version (if present) + # must match. + if isinstance(selected_backend, CUDA) and isinstance(v.arch.backend, CUDA): + if ( + v.arch.backend.version.major != selected_backend.version.major + or v.arch.backend.version.minor > selected_backend.version.minor + ): + continue + elif v.arch.backend.variant_str != selected_backend.variant_str: + continue + else: + assert isinstance(v.arch, Noarch) + # Only noarch variants with a matching backend or "universal" + # are applicable. + noarch_backend_name = ( + "npu" if selected_backend.name == "cann" else selected_backend.name + ) + if ( + v.arch.backend_name != noarch_backend_name + and v.arch.backend_name != "universal" + ): + continue + result.append(v) + return result - if selected_backend.name == "cann": - return ["torch-npu"] - else: - return [f"torch-{selected_backend.name}"] +def _sort_variants( + variants: list[Variant], +) -> list[Variant]: + """Sort variants in preference order: -def _build_variant_universal() -> list[str]: - # Once we support other frameworks, detection goes here. - return ["torch-universal"] if has_torch else [] + 1. Torch arch kernels with with the highest compatible CUDA version. + 2. tvm-ffi arch kernels with with the highest compatible CUDA version. + 3. Torch noarch kernels. + 4. Old Torch universal kernels. + """ + def sort_key(v: Variant) -> tuple: + if isinstance(v.arch, Arch): + framework_order = 0 if isinstance(v.framework, Torch) else 1 + if isinstance(v.arch.backend, (CUDA, ROCm, XPU, CANN)): + # Order by backend version in reverse (higher is better). + backend_order = -v.arch.backend.version.minor + else: + backend_order = 0 + return (framework_order, backend_order) + else: + assert isinstance(v.arch, Noarch) + universal_order = 1 if v.arch.backend_name == "universal" else 0 + return (2, universal_order) -def _build_variants(backend: str | None) -> list[str]: - """Return compatible build variants in preferred order.""" - return [ - *_torch_build_variant(backend), - *_tvm_ffi_build_variant(backend), - *_build_variant_noarch(backend), - *_build_variant_universal(), - ] + return sorted(variants, key=sort_key) diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index d82b4cbe..7486134c 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -120,7 +120,6 @@ def test_relu_metal(metal_kernel, dtype): # Repo only contains Torch 2.4 kernels (and we don't # support/test against this version). ("kernels-test/only-torch-2.4", "main", False), - ("google-bert/bert-base-uncased", "87565a309", False), ("kernels-test/flattened-build", "main", True), ("kernels-test/flattened-build", "without-compat-module", True), ], diff --git a/kernels/tests/test_variants.py b/kernels/tests/test_variants.py new file mode 100644 index 00000000..b0d17616 --- /dev/null +++ b/kernels/tests/test_variants.py @@ -0,0 +1,361 @@ +import pytest +from huggingface_hub import HfApi +from packaging.version import Version + +from kernels.backends import CPU, CUDA, ROCm +from kernels.variants import Variant, _resolve_variant_for_system, get_variants + +VARIANT_STRINGS = [ + "torch25-cxx98-cu118-aarch64-linux", + "torch25-cxx98-cpu-x86_64-linux", + "torch29-cpu-aarch64-darwin", + "torch29-cxx11-cpu-aarch64-linux", + "torch29-cxx11-cpu-x86_64-linux", + "torch29-cxx11-cu126-aarch64-linux", + "torch29-cxx11-cu126-x86_64-linux", + "torch29-cxx11-cu128-aarch64-linux", + "torch29-cxx11-cu128-x86_64-linux", + "torch29-cxx11-cu130-aarch64-linux", + "torch29-cxx11-cu130-x86_64-linux", + "torch29-cxx11-rocm63-x86_64-linux", + "torch29-cxx11-rocm64-x86_64-linux", + "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-metal-aarch64-darwin", + "torch210-cpu-aarch64-darwin", + "torch210-cu128-x86_64-windows", + "torch210-cxx11-cpu-aarch64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cxx11-cu126-aarch64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-aarch64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-aarch64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-rocm71-x86_64-linux", + "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-metal-aarch64-darwin", + "torch210-xpu20253-x86_64-windows", + "tvm-ffi01-cpu-x86_64-linux", + "tvm-ffi01-cu126-x86_64-linux", + "tvm-ffi01-cu128-x86_64-linux", + "tvm-ffi01-cu130-x86_64-linux", + "tvm-ffi01-metal-aarch64-darwin", + "tvm-ffi01-xpu20253-x86_64-linux", +] + + +NOARCH_VARIANT_STRINGS = [ + "torch-cpu", + "torch-cuda", + "torch-metal", + "torch-neuron", + "torch-rocm", + "torch-xpu", + "torch-npu", + "torch-universal", +] + +SUPERSET_VARIANT_STRINGS = [ + "torch29-cpu-aarch64-darwin", + "torch29-cxx11-cpu-aarch64-linux", + "torch29-cxx11-cpu-x86_64-linux", + "torch29-cxx11-cu126-aarch64-linux", + "torch29-cxx11-cu126-x86_64-linux", + "torch29-cxx11-cu128-aarch64-linux", + "torch29-cxx11-cu128-x86_64-linux", + "torch29-cxx11-cu130-aarch64-linux", + "torch29-cxx11-cu130-x86_64-linux", + "torch29-cxx11-rocm63-x86_64-linux", + "torch29-cxx11-rocm64-x86_64-linux", + "torch29-cxx11-xpu20252-x86_64-linux", + "torch29-metal-aarch64-darwin", + "torch210-cpu-aarch64-darwin", + "torch210-cu128-x86_64-windows", + "torch210-cxx11-cpu-aarch64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cxx11-cu126-aarch64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-aarch64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-aarch64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-rocm71-x86_64-linux", + "torch210-cxx11-xpu20253-x86_64-linux", + "torch210-metal-aarch64-darwin", + "torch210-xpu20253-x86_64-windows", +] + + +@pytest.mark.parametrize("variant_str", VARIANT_STRINGS) +def test_arch_variants(variant_str: str): + # Roundtrip parse and generate variant string. + assert Variant.parse(variant_str).variant_str == variant_str + + +@pytest.mark.parametrize("variant_str", NOARCH_VARIANT_STRINGS) +def test_noarch_variants(variant_str: str): + # Roundtrip parse and generate variant string. + assert Variant.parse(variant_str).variant_str == variant_str + + +def test_get_variants(): + api = HfApi() + variants = get_variants(api, repo_id="kernels-community/relu", revision="v1") + variant_strs = {v.variant_str for v in variants} + # Superset because new variants may be added in the future. + assert variant_strs.issuperset(SUPERSET_VARIANT_STRINGS) + + +RESOLVE_VARIANTS = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu130-x86_64-linux", + "torch210-cxx11-rocm70-x86_64-linux", + "torch210-cxx11-cpu-x86_64-linux", + "torch210-cpu-aarch64-darwin", + "torch210-metal-aarch64-darwin", + "torch-cuda", + "torch-cpu", + ] +] + + +def test_resolve_cuda_exact(): + # CUDA 12.8 should resolve to cu128. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_cuda_best_older_minor(): + # CUDA 12.9 is not available, should fall back to cu128 (highest <= 12.9). + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.9")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_cuda_no_newer_minor(): + # CUDA 12.5 is older than all the variants, fall back to noarch. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.5")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch-cuda" + + +def test_resolve_cuda_no_different_major(): + # Different major version must not match. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("11.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch-cuda" + + +def test_resolve_rocm(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cxx11-rocm70-x86_64-linux" + + +def test_resolve_cpu_linux(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cpu-x86_64-linux" + + +def test_resolve_cpu_darwin(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="aarch64", + os="darwin", + torch_version=Version("2.10"), + torch_cxx11_abi=None, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cpu-aarch64-darwin" + + +def test_resolve_metal_darwin(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CPU(), + cpu="aarch64", + os="darwin", + torch_version=Version("2.10"), + torch_cxx11_abi=None, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cpu-aarch64-darwin" + + +def test_resolve_noarch_fallback(): + # With no matching arch variant, should fall back to torch noarch. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=CUDA(Version("12.8")), + cpu="aarch64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch-cuda" + + +def test_resolve_no_match(): + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result == [] + + +RESOLVE_VARIANTS_UNIVERSAL = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu128-x86_64-linux", + "torch-universal", + ] +] + + +def test_resolve_universal_matches_any_backend(): + # Universal works with every backend. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_UNIVERSAL, + selected_backend=ROCm(Version("7.0")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch-universal" + + +def test_resolve_universal_is_last_resort(): + # Specific match is preferred over universal. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_UNIVERSAL, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch210-cxx11-cu128-x86_64-linux" + + +def test_resolve_specific_noarch_preferred_over_universal(): + # Backend-specific noarch is preferred over universal. + variants = [Variant.parse(s) for s in ["torch-universal", "torch-cuda"]] + result = _resolve_variant_for_system( + variants=variants, + selected_backend=CUDA(Version("12.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.9"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result != [] + assert result[0].variant_str == "torch-cuda" + + +RESOLVE_VARIANTS_NO_NOARCH = [ + Variant.parse(s) + for s in [ + "torch210-cxx11-cu126-x86_64-linux", + "torch210-cxx11-cu128-x86_64-linux", + "torch210-cxx11-cu130-x86_64-linux", + ] +] + + +def test_resolve_cuda_no_newer_minor_no_noarch(): + # No compatible variant for 12.5. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_NO_NOARCH, + selected_backend=CUDA(Version("12.5")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result == [] + + +def test_resolve_cuda_no_different_major_no_noarch(): + # 11.8 has a different major, so there is no compatible fallback. + result = _resolve_variant_for_system( + variants=RESOLVE_VARIANTS_NO_NOARCH, + selected_backend=CUDA(Version("11.8")), + cpu="x86_64", + os="linux", + torch_version=Version("2.10"), + torch_cxx11_abi=True, + tvm_ffi_version=None, + ) + assert result == []