From 966c9b4f8d2298d43668e18e78d5f811ac40beb8 Mon Sep 17 00:00:00 2001 From: Xiaodong Ye Date: Sun, 17 May 2026 13:26:20 +0800 Subject: [PATCH] Refactor torchada patching into focused compatibility modules Signed-off-by: Xiaodong Ye --- README.md | 26 +- README_CN.md | 21 +- docs/compat_gap_cuda_introspection.md | 38 + docs/compat_gap_cuda_nccl_attr.md | 27 + docs/compat_gap_cuda_public_aliases.md | 44 + docs/compat_gap_inventory.md | 62 ++ src/torchada/_accelerator_compat.py | 168 ++++ src/torchada/_cpp_ops.py | 150 ++-- src/torchada/_ctypes_compat.py | 81 ++ src/torchada/_cuda_compat.py | 353 ++++++++ src/torchada/_device_compat.py | 301 +++++++ src/torchada/_mapping.py | 231 ++--- src/torchada/_patch.py | 1095 +++--------------------- src/torchada/_platform.py | 48 +- src/torchada/_runtime.py | 134 ++- src/torchada/csrc/musa_ops.mu | 24 +- src/torchada/csrc/ops.cpp | 16 +- src/torchada/csrc/ops.h | 21 +- src/torchada/cuda/__init__.py | 96 ++- src/torchada/cuda/amp.py | 14 +- src/torchada/cuda/nvtx.py | 8 +- src/torchada/cuda/random.py | 37 +- tests/test_cuda_patching.py | 436 ++++++++++ tests/test_mappings.py | 41 + 24 files changed, 2077 insertions(+), 1395 deletions(-) create mode 100644 docs/compat_gap_cuda_introspection.md create mode 100644 docs/compat_gap_cuda_nccl_attr.md create mode 100644 docs/compat_gap_cuda_public_aliases.md create mode 100644 docs/compat_gap_inventory.md create mode 100644 src/torchada/_accelerator_compat.py create mode 100644 src/torchada/_ctypes_compat.py create mode 100644 src/torchada/_cuda_compat.py create mode 100644 src/torchada/_device_compat.py diff --git a/README.md b/README.md index b8c67f4..86109da 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,21 @@ torchada is an adapter that makes [torch_musa](https://github.com/MooreThreads/t Many PyTorch projects are written for NVIDIA GPUs using `torch.cuda.*` APIs. To run these on Moore Threads GPUs, you would normally need to change every `cuda` reference to `musa`. 
torchada eliminates this by automatically translating CUDA API calls to MUSA equivalents at runtime. +## Architecture + +torchada sits between user code and the PyTorch MUSA backend. User applications import `torchada` once and continue calling standard `torch.cuda.*` APIs. The compatibility layer patches CUDA entry points, translates CUDA device references to MUSA, installs CUDA-shaped module shims, and ports CUDA extension sources and symbols for the MUSA toolchain. On MUSA platforms, calls are redirected to `torch.musa`, custom operators use the `PrivateUse1` dispatch key, distributed NCCL requests map to MCCL, and runtime calls target the MUSA runtime. + +Import-time setup follows a small sequence: + +- Detect the active platform (`MUSA`, native `CUDA`, or `CPU`). +- Load optional C++/MUSA operator overrides on MUSA platforms. +- Apply PyTorch compatibility patches through the `_patch.py` registry. +- Configure the bundled Triton/MoE defaults for SGLang and vLLM. + +The main compatibility modules are `_device_compat.py`, `_cuda_compat.py`, `_runtime.py`, `_ctypes_compat.py`, `_accelerator_compat.py`, `utils/cpp_extension.py`, `_mapping.py`, `_cpp_ops.py`, `csrc/`, `cuda/`, and `triton/`. + +For downstream compatibility, `torch.cuda.is_available()` and `torch.version.cuda` are intentionally left unpatched so projects can still distinguish native CUDA environments from MUSA environments. + ## Prerequisites - **torch_musa**: You must have [torch_musa](https://github.com/MooreThreads/torch_musa) installed (this provides MUSA support for PyTorch) @@ -53,9 +68,11 @@ That's it! 
All `torch.cuda.*` APIs are automatically redirected to `torch.musa.* | Device operations | `tensor.cuda()`, `model.cuda()`, `torch.device("cuda")` | | Memory management | `torch.cuda.memory_allocated()`, `empty_cache()` | | Synchronization | `torch.cuda.synchronize()`, `Stream`, `Event` | +| Compatibility aliases | `memory_cached()`, `torch.cuda.streams`, `torch.cuda.sparse` | | Mixed precision | `torch.cuda.amp.autocast()`, `GradScaler()` | | CUDA Graphs | `torch.cuda.CUDAGraph`, `torch.cuda.graph()` | | CUDA Runtime | `torch.cuda.cudart()` → uses MUSA runtime | +| CUDA Introspection | `get_gencode_flags()`, `get_sync_debug_mode()`, `set_sync_debug_mode()` | | Profiler | `ProfilerActivity.CUDA` → uses PrivateUse1 | | Custom Ops | `Library.impl(..., "CUDA")` → uses PrivateUse1 | | Distributed | `dist.init_process_group(backend='nccl')` → uses MCCL | @@ -199,11 +216,10 @@ with torch.accelerator.stream(torch.musa.Stream()): ... ``` -**Forward compatibility:** The wrapper always prefers the real -`torch.accelerator` implementation and only falls back to `torch.musa` when an -attribute is missing, so upgrading to a future PyTorch release that ships -official implementations requires no changes on your side — you will -automatically get the upstream version. +**Forward compatibility:** The wrapper first applies torchada overrides for +MUSA-specific fixes such as synchronization and memory APIs, then prefers the +real `torch.accelerator` implementation, and finally falls back to `torch.musa` +when an attribute is missing. 
## Platform Detection diff --git a/README_CN.md b/README_CN.md index fcfb4ed..25ef885 100644 --- a/README_CN.md +++ b/README_CN.md @@ -16,6 +16,21 @@ torchada 是一个适配器,让 [torch_musa](https://github.com/MooreThreads/t 许多 PyTorch 项目使用 `torch.cuda.*` API 为 NVIDIA GPU 编写。要在摩尔线程 GPU 上运行这些项目,通常需要把每个 `cuda` 引用改成 `musa`。torchada 通过在运行时自动将 CUDA API 调用转换为 MUSA 等效调用来消除这一问题。 +## 架构 + +torchada 位于用户代码和 PyTorch MUSA 后端之间。用户应用只需导入一次 `torchada`,之后继续使用标准 `torch.cuda.*` API。兼容层会修补 CUDA 入口点,将 CUDA 设备引用转换为 MUSA,安装 CUDA 形态的兼容模块,并为 MUSA 工具链转换 CUDA 扩展源码和符号。在 MUSA 平台上,调用会重定向到 `torch.musa`,自定义算子使用 `PrivateUse1` 调度键,分布式 NCCL 请求映射到 MCCL,运行时调用则落到 MUSA 运行时。 + +导入时的初始化流程很短: + +- 检测当前平台(`MUSA`、原生 `CUDA` 或 `CPU`)。 +- 在 MUSA 平台上加载可选的 C++/MUSA 算子覆盖。 +- 通过 `_patch.py` 注册表应用 PyTorch 兼容性补丁。 +- 为 SGLang 和 vLLM 设置内置 Triton/MoE 默认配置。 + +主要兼容模块包括 `_device_compat.py`、`_cuda_compat.py`、`_runtime.py`、`_ctypes_compat.py`、`_accelerator_compat.py`、`utils/cpp_extension.py`、`_mapping.py`、`_cpp_ops.py`、`csrc/`、`cuda/` 和 `triton/`。 + +为了保持下游项目的平台检测逻辑,`torch.cuda.is_available()` 和 `torch.version.cuda` 会有意保持不修补,这样项目仍然可以区分原生 CUDA 环境和 MUSA 环境。 + ## 前置条件 - **torch_musa**:必须安装 [torch_musa](https://github.com/MooreThreads/torch_musa)(提供 PyTorch 的 MUSA 支持) @@ -53,9 +68,11 @@ torch.cuda.synchronize() | 设备操作 | `tensor.cuda()`, `model.cuda()`, `torch.device("cuda")` | | 显存管理 | `torch.cuda.memory_allocated()`, `empty_cache()` | | 同步 | `torch.cuda.synchronize()`, `Stream`, `Event` | +| 兼容别名 | `memory_cached()`、`torch.cuda.streams`、`torch.cuda.sparse` | | 混合精度 | `torch.cuda.amp.autocast()`, `GradScaler()` | | CUDA Graphs | `torch.cuda.CUDAGraph`, `torch.cuda.graph()` | | CUDA 运行时 | `torch.cuda.cudart()` → 使用 MUSA 运行时 | +| CUDA 自省/调试 | `get_gencode_flags()`、`get_sync_debug_mode()`、`set_sync_debug_mode()` | | 性能分析 | `ProfilerActivity.CUDA` → 使用 PrivateUse1 | | 自定义算子 | `Library.impl(..., "CUDA")` → 使用 PrivateUse1 | | 分布式训练 | `dist.init_process_group(backend='nccl')` → 使用 MCCL | @@ -197,8 +214,8 @@ with 
torch.accelerator.stream(torch.musa.Stream()): ... ``` -**前向兼容性:** 包装器始终优先使用真正的 `torch.accelerator` 实现,只有在缺少属性时才回退到 -`torch.musa`,因此升级到提供官方实现的未来 PyTorch 版本时无需任何更改 —— 您将自动获得上游版本。 +**前向兼容性:** 包装器会先应用 torchada 针对 MUSA 的修复(例如同步和显存 API),然后优先使用真正的 +`torch.accelerator` 实现,最后在属性缺失时回退到 `torch.musa`。 ## 平台检测 diff --git a/docs/compat_gap_cuda_introspection.md b/docs/compat_gap_cuda_introspection.md new file mode 100644 index 0000000..17572c1 --- /dev/null +++ b/docs/compat_gap_cuda_introspection.md @@ -0,0 +1,38 @@ +# CUDA Introspection Compatibility Gap + +## Status + +- Fixed in `src/torchada/_patch.py` +- Covered by `tests/test_cuda_patching.py::TestCudaBuildAndDebugIntrospection` + +## Gap + +In the `yeahdongcn1` torch_musa 2.7.1 container, these top-level CUDA APIs exist +on `torch.cuda` but are absent from `torch.musa`: + +- `torch.cuda.get_gencode_flags` +- `torch.cuda.get_sync_debug_mode` +- `torch.cuda.set_sync_debug_mode` + +After torchada redirects `torch.cuda` to `torch.musa`, those calls raised +`AttributeError` instead of preserving CUDA-compatible API access. + +## Fix + +torchada now installs MUSA-safe shims when torch_musa does not provide these +attributes: + +- `get_gencode_flags()` returns `""` because NVCC gencode flags are CUDA-specific + and should not be passed to the MUSA toolchain. +- `get_sync_debug_mode()` and `set_sync_debug_mode()` maintain a process-local + debug mode value so CUDA-oriented code can call the public API without + requiring unavailable CUDA C++ hooks. 
+ +## Verification + +Run in the MUSA test container: + +```bash +docker exec -w /ws yeahdongcn1 python -m pytest \ + tests/test_cuda_patching.py::TestCudaBuildAndDebugIntrospection -v +``` diff --git a/docs/compat_gap_cuda_nccl_attr.md b/docs/compat_gap_cuda_nccl_attr.md new file mode 100644 index 0000000..9e9c5de --- /dev/null +++ b/docs/compat_gap_cuda_nccl_attr.md @@ -0,0 +1,27 @@ +# CUDA NCCL Module Attribute Compatibility Gap + +## Status + +- Fixed in `src/torchada/_patch.py` +- Covered by `tests/test_cuda_patching.py::TestNCCLModule::test_nccl_module_alias_available` + +## Gap + +CUDA exposes `torch.cuda.nccl` as both an importable module and a module +attribute. torchada already registered `torch.cuda.nccl` in `sys.modules`, but +plain attribute access still failed on MUSA because the CUDA wrapper redirected +`torch.cuda.nccl` to missing `torch.musa.nccl` instead of `torch.musa.mccl`. + +## Fix + +torchada now aliases `torch.musa.nccl` to `torch.musa.mccl` when MCCL is +available and also remaps `torch.cuda.nccl` attribute access to `mccl`. + +## Verification + +Run in the MUSA test container: + +```bash +docker exec -w /ws yeahdongcn1 python -m pytest \ + tests/test_cuda_patching.py::TestNCCLModule::test_nccl_module_alias_available -v +``` diff --git a/docs/compat_gap_cuda_public_aliases.md b/docs/compat_gap_cuda_public_aliases.md new file mode 100644 index 0000000..858b756 --- /dev/null +++ b/docs/compat_gap_cuda_public_aliases.md @@ -0,0 +1,44 @@ +# CUDA Public API Alias Compatibility Gap + +## Status + +- Fixed in `src/torchada/_patch.py` +- Covered by `tests/test_cuda_patching.py::TestCudaPublicApiAliases` + +## Gap + +A broader `dir(torch.cuda)` versus `dir(torch.musa)` comparison in the +`yeahdongcn1` torch_musa 2.7.1 container found additional CUDA public attributes +that are commonly used as imports or compatibility aliases but were missing +after torchada redirected `torch.cuda` to `torch.musa`. 
+ +## Fix + +torchada now provides MUSA-backed or safe compatibility aliases for: + +- Deprecated memory aliases: `memory_cached`, `max_memory_cached` +- Host-memory stat APIs with no MUSA counters: `host_memory_stats`, + `host_memory_stats_as_nested_dict`, `reset_accumulated_host_memory_stats`, + `reset_peak_host_memory_stats` +- Static CUDA build flags: `has_half`, `has_magma` +- Top-level `CUDAPluggableAllocator` +- `torch.cuda.streams` +- `torch.cuda.sparse` +- `torch.cuda.init` +- `torch.cuda.default_generators` +- `torch.cuda.get_stream_from_external` + +## Deferred + +Allocator-control APIs such as `caching_allocator_enable` and telemetry APIs +such as `utilization` remain deferred because torch_musa has no equivalent +behavior in the tested build. + +## Verification + +Run in the MUSA test container: + +```bash +docker exec -w /ws yeahdongcn1 python -m pytest \ + tests/test_cuda_patching.py::TestCudaPublicApiAliases -v +``` diff --git a/docs/compat_gap_inventory.md b/docs/compat_gap_inventory.md new file mode 100644 index 0000000..679dca3 --- /dev/null +++ b/docs/compat_gap_inventory.md @@ -0,0 +1,62 @@ +# CUDA to MUSA Compatibility Gap Inventory + +## Method + +Compared selected `torch.cuda` attributes against `torch.musa` in the +`yeahdongcn1` torch_musa 2.7.1 container, then verified behavior after +`import torchada`. 
+ +## Fixed + +- `torch.cuda.get_gencode_flags` +- `torch.cuda.get_sync_debug_mode` +- `torch.cuda.set_sync_debug_mode` +- `torch.cuda.nccl` +- Deprecated memory aliases and host-memory stat APIs +- `torch.cuda.streams` +- `torch.cuda.sparse` +- `torch.cuda.init` +- `torch.cuda.default_generators` +- `torch.cuda.get_stream_from_external` +- `torch.cuda.CUDAPluggableAllocator` + +See: + +- `docs/compat_gap_cuda_introspection.md` +- `docs/compat_gap_cuda_nccl_attr.md` +- `docs/compat_gap_cuda_public_aliases.md` + +## Deferred + +These remaining names are not patched as part of CUDA-to-MUSA runtime +compatibility: + +- Imported helper symbols from the CUDA Python module: `Any`, `Callable`, + `Optional`, `Union`, `cast`, `classproperty`, `importlib`, `lru_cache`, + `threading`, `traceback` +- CUDA-only internal classes or APIs with no MUSA object model equivalent in the + tested build: `CudaError`, `cudaStatus`, `DeferredCudaCallError`, `Device`, + `ComplexFloatStorage`, `ComplexDoubleStorage`, `jiterator` + +The following CUDA telemetry, allocator-control, GDS, and tunable APIs still have no +`torch.musa` equivalent in the tested torch_musa build: + +- `torch.cuda.list_gpu_processes` +- `torch.cuda.utilization` +- `torch.cuda.memory_usage` +- `torch.cuda.temperature` +- `torch.cuda.power_draw` +- `torch.cuda.clock_rate` +- `torch.cuda.device_memory_used` +- `torch.cuda.caching_allocator_alloc` +- `torch.cuda.caching_allocator_delete` +- `torch.cuda.caching_allocator_enable` +- `torch.cuda.get_per_process_memory_fraction` +- `torch.cuda.gds` +- `torch.cuda.tunable` + +In the same container, the original CUDA telemetry implementations depend on `pynvml` and +do not provide real values without NVIDIA NVML support. They are not patched in +this pass to avoid returning misleading MUSA telemetry. Allocator-control and +tunable/GDS APIs are also left unpatched because torch_musa does not expose +equivalent behavior in this tested build. 
diff --git a/src/torchada/_accelerator_compat.py b/src/torchada/_accelerator_compat.py new file mode 100644 index 0000000..2f9e7f3 --- /dev/null +++ b/src/torchada/_accelerator_compat.py @@ -0,0 +1,168 @@ +""" +Compatibility wrapper for the evolving ``torch.accelerator`` API. + +This module provides MUSA-backed fallbacks for torch.accelerator APIs that are +missing or incomplete in supported PyTorch versions. +""" + +import sys +from types import ModuleType + +import torch + + +class _AcceleratorModuleWrapper(ModuleType): + """ + Module wrapper that prefers torchada overrides, then torch.accelerator, then + torch.musa fallbacks for APIs not present in older PyTorch builds. + """ + + _REMAP_ATTRS = { + "set_device_index": "set_device", + "set_device_idx": "set_device", + "current_device_index": "current_device", + "current_device_idx": "current_device", + } + _SPECIAL_ATTRS = { + "StreamContext": "core.stream.StreamContext", + } + _MUSA_OVERRIDES = ( + "empty_cache", + "empty_host_cache", + "memory_stats", + "memory_allocated", + "max_memory_allocated", + "memory_reserved", + "max_memory_reserved", + "reset_accumulated_memory_stats", + "reset_peak_memory_stats", + "get_memory_info", + ) + + def __init__(self, original_accel, musa_module): + super().__init__("torch.accelerator") + self._original_accel = original_accel + self._musa_module = musa_module + self._overrides = {} + + for name in self._MUSA_OVERRIDES: + if hasattr(original_accel, name) and hasattr(musa_module, name): + self._set_override(name, getattr(musa_module, name)) + + def _set_override(self, name, value): + """Install an override that takes precedence over wrapped modules.""" + self._overrides[name] = value + object.__setattr__(self, name, value) + + def __getattr__(self, name): + if name in self._overrides: + return self._overrides[name] + + try: + value = getattr(self._original_accel, name) + except AttributeError: + if hasattr(self._musa_module, name): + value = getattr(self._musa_module, name) + 
elif name in self._SPECIAL_ATTRS: + value = self._musa_module + for part in self._SPECIAL_ATTRS[name].split("."): + value = getattr(value, part) + elif name in self._REMAP_ATTRS: + value = getattr(self._musa_module, self._REMAP_ATTRS[name]) + else: + raise AttributeError(f"module 'torch.accelerator' has no attribute '{name}'") + + object.__setattr__(self, name, value) + return value + + def __dir__(self): + attrs = set(dir(self._original_accel)) + attrs.update(dir(self._musa_module)) + attrs.update(self._REMAP_ATTRS.keys()) + attrs.update(self._SPECIAL_ATTRS.keys()) + attrs.update(self._overrides.keys()) + return list(attrs) + + +_original_torch_accelerator = None + + +def _make_patched_accelerator_synchronize(musa_module): + """Build a synchronize replacement that delegates to ``torch.musa``.""" + from ._device_compat import _translate_device + + def patched_synchronize(device=None): + if device is not None and not isinstance(device, (torch.device, str, int)): + raise TypeError( + f"synchronize() expected device to be torch.device, str, int, or None, " + f"but got {type(device).__name__}" + ) + device = _translate_device(device) + musa_module.synchronize(device) + + return patched_synchronize + + +def _make_accelerator_context_managers(accel_module): + """Build ``device_index`` and ``stream`` context managers bound to wrapper.""" + + class device_index: + """Temporarily set the current accelerator device index.""" + + def __init__(self, idx): + self.idx = idx + self.prev_idx = None + + def __enter__(self): + self.prev_idx = accel_module.current_device_index() + accel_module.set_device_index(self.idx) + return self + + def __exit__(self, *args): + if self.prev_idx is not None: + accel_module.set_device_index(self.prev_idx) + + class stream: + """Temporarily set the current accelerator stream.""" + + def __init__(self, stream_obj): + self.stream = stream_obj + self.prev_stream = None + + def __enter__(self): + self.prev_stream = accel_module.current_stream() + 
accel_module.set_stream(self.stream) + return self + + def __exit__(self, *args): + if self.prev_stream is not None: + accel_module.set_stream(self.prev_stream) + + return device_index, stream + + +def patch_torch_accelerator(torch_module=torch) -> None: + """Wrap ``torch.accelerator`` with MUSA fallbacks and torchada overrides.""" + global _original_torch_accelerator + + import torch.accelerator as accel + + if _original_torch_accelerator is None: + _original_torch_accelerator = accel + + wrapper = _AcceleratorModuleWrapper(_original_torch_accelerator, torch_module.musa) + wrapper._set_override("synchronize", _make_patched_accelerator_synchronize(torch_module.musa)) + + device_index_cm, stream_cm = _make_accelerator_context_managers(wrapper) + if not hasattr(_original_torch_accelerator, "device_index"): + wrapper._set_override("device_index", device_index_cm) + if not hasattr(_original_torch_accelerator, "stream"): + wrapper._set_override("stream", stream_cm) + + sys.modules["torch.accelerator"] = wrapper + torch_module.accelerator = wrapper + + +def get_original_torch_accelerator(): + """Return the saved original ``torch.accelerator`` module, if patched.""" + return _original_torch_accelerator diff --git a/src/torchada/_cpp_ops.py b/src/torchada/_cpp_ops.py index 4b27778..7a8bd85 100644 --- a/src/torchada/_cpp_ops.py +++ b/src/torchada/_cpp_ops.py @@ -9,20 +9,49 @@ C++ extensions are automatically loaded on MUSA platform when torchada is imported. Usage: - import torchada # C++ extensions are loaded automatically on MUSA + import torchada # C++ extensions are loaded automatically on MUSA. - # Or explicitly load + # Or explicitly load. 
from torchada._cpp_ops import load_cpp_ops load_cpp_ops() """ import os +import os.path as osp import subprocess -from typing import Optional +from dataclasses import dataclass +from typing import List, Optional _cpp_ops_module: Optional[object] = None _musa_arch_cached: Optional[str] = None +EXTENSION_NAME = "torchada_cpp_ops" +DEFAULT_MUSA_ARCH = "mp_31" + + +@dataclass(frozen=True) +class _ExtensionSources: + """Source files grouped by compiler/toolchain requirements.""" + + csrc_dir: str + cpp_sources: List[str] + musa_sources: List[str] + + @property + def all_sources(self) -> List[str]: + """All source paths in load order.""" + return self.cpp_sources + self.musa_sources + + @property + def has_sources(self) -> bool: + """Whether any extension source files were discovered.""" + return bool(self.cpp_sources or self.musa_sources) + + @property + def needs_musa_loader(self) -> bool: + """Whether MUSA extension loading is required.""" + return bool(self.musa_sources) + def _detect_musa_arch() -> str: """ @@ -41,7 +70,7 @@ def _detect_musa_arch() -> str: if _musa_arch_cached is not None: return _musa_arch_cached - arch = "mp_31" # Default fallback + arch = DEFAULT_MUSA_ARCH try: result = subprocess.run( ["musaInfo"], @@ -51,11 +80,11 @@ def _detect_musa_arch() -> str: ) for line in result.stdout.splitlines(): if "compute capability:" in line.lower(): - # Parse "compute capability: 2.1" + # Parse lines like "compute capability: 2.1". parts = line.split(":") if len(parts) >= 2: version = parts[1].strip() - # Convert "2.1" -> "mp_21", "3.1" -> "mp_31" + # Convert "2.1" to "mp_21", "3.1" to "mp_31", etc. 
version_parts = version.split(".") if len(version_parts) >= 2: major = version_parts[0].strip() @@ -69,6 +98,62 @@ def _detect_musa_arch() -> str: return arch +def _get_csrc_dir() -> str: + """Return the packaged C++ source directory.""" + return osp.join(osp.dirname(__file__), "csrc") + + +def _discover_extension_sources(csrc_dir: Optional[str] = None) -> _ExtensionSources: + """Discover C++ and MUSA source files for the extension build.""" + csrc_dir = csrc_dir or _get_csrc_dir() + cpp_sources = [] + musa_sources = [] + + for fname in sorted(os.listdir(csrc_dir)): + fpath = osp.join(csrc_dir, fname) + if fname.endswith(".cpp"): + cpp_sources.append(fpath) + elif fname.endswith((".cu", ".mu")): + musa_sources.append(fpath) + + return _ExtensionSources( + csrc_dir=csrc_dir, + cpp_sources=cpp_sources, + musa_sources=musa_sources, + ) + + +def _get_musa_arch_flag() -> str: + """Return the MUSA offload architecture flag for extension compilation.""" + mtgpu_target = os.environ.get("MTGPU_TARGET", "") + if not mtgpu_target: + mtgpu_target = _detect_musa_arch() + return f"--offload-arch={mtgpu_target}" + + +def _load_extension(sources: _ExtensionSources, verbose: bool): + """Load the extension with the toolchain required by the discovered sources.""" + if sources.needs_musa_loader: + from .utils.cpp_extension import load + + return load( + name=EXTENSION_NAME, + sources=sources.all_sources, + extra_include_paths=[sources.csrc_dir], + extra_cuda_cflags=[_get_musa_arch_flag()], + verbose=verbose, + ) + + from torch.utils.cpp_extension import load + + return load( + name=EXTENSION_NAME, + sources=sources.all_sources, + extra_include_paths=[sources.csrc_dir], + verbose=verbose, + ) + + def load_cpp_ops(force_reload: bool = False) -> Optional[object]: """ Load the C++ operator overrides extension. 
@@ -86,68 +171,21 @@ def load_cpp_ops(force_reload: bool = False) -> Optional[object]: if _cpp_ops_module is not None and not force_reload: return _cpp_ops_module - # Check if on MUSA platform from ._platform import is_musa_platform if not is_musa_platform(): return None try: - import os.path as osp - - csrc_dir = osp.join(osp.dirname(__file__), "csrc") - - # Collect all source files - cpp_sources = [] - musa_sources = [] - - for fname in os.listdir(csrc_dir): - fpath = osp.join(csrc_dir, fname) - if fname.endswith(".cpp"): - cpp_sources.append(fpath) - elif fname.endswith((".cu", ".mu")): - musa_sources.append(fpath) - - if not cpp_sources and not musa_sources: + sources = _discover_extension_sources() + if not sources.has_sources: import warnings warnings.warn("torchada C++ ops: no source files found") return None verbose = os.environ.get("TORCHADA_CPP_OPS_VERBOSE") == "1" - all_sources = cpp_sources + musa_sources - - # Use MUSA extension loader if we have MUSA sources, otherwise use torch's - if musa_sources: - # Use torchada's load which handles MUSA properly - from .utils.cpp_extension import load - - # Get MUSA architecture flags - # Use MTGPU_TARGET env var if set, otherwise auto-detect from GPU - extra_cuda_cflags = [] - mtgpu_target = os.environ.get("MTGPU_TARGET", "") - if not mtgpu_target: - mtgpu_target = _detect_musa_arch() - extra_cuda_cflags.append(f"--offload-arch={mtgpu_target}") - - _cpp_ops_module = load( - name="torchada_cpp_ops", - sources=all_sources, - extra_include_paths=[csrc_dir], - extra_cuda_cflags=extra_cuda_cflags, - verbose=verbose, - ) - else: - # Pure C++ extension - use torch's loader directly - from torch.utils.cpp_extension import load - - _cpp_ops_module = load( - name="torchada_cpp_ops", - sources=all_sources, - extra_include_paths=[csrc_dir], - verbose=verbose, - ) - + _cpp_ops_module = _load_extension(sources, verbose) _cpp_ops_module._mark_loaded() return _cpp_ops_module diff --git a/src/torchada/_ctypes_compat.py 
b/src/torchada/_ctypes_compat.py new file mode 100644 index 0000000..ee4df1f --- /dev/null +++ b/src/torchada/_ctypes_compat.py @@ -0,0 +1,81 @@ +""" +ctypes compatibility for CUDA-named symbols in MUSA runtime libraries. + +This module keeps ctypes callers using CUDA-family names while dispatching to +MUSA runtime-family libraries. +""" + +from ._runtime import ( + detect_musa_library_type, + is_musa_runtime_library_path, + translate_runtime_symbol_name, +) + + +class _CDLLWrapper: + """ + Wrapper for ``ctypes.CDLL`` that translates CUDA-family symbol names. + + Loading MUSA libraries such as ``libmusart.so`` or ``libmccl.so`` still lets + callers access symbols with CUDA/NCCL names. + """ + + def __init__(self, cdll_instance, lib_path: str): + object.__setattr__(self, "_cdll", cdll_instance) + object.__setattr__(self, "_lib_path", lib_path) + object.__setattr__(self, "_lib_type", self._detect_lib_type(lib_path)) + + def _detect_lib_type(self, lib_path: str) -> str: + """Detect the MUSA runtime-family library type from its path.""" + return detect_musa_library_type(lib_path) + + def _translate_name(self, name: str) -> str: + """Translate CUDA-family symbol names for the wrapped library.""" + lib_type = object.__getattribute__(self, "_lib_type") + return translate_runtime_symbol_name(name, lib_type) + + def __getattr__(self, name: str): + cdll = object.__getattribute__(self, "_cdll") + value = getattr(cdll, self._translate_name(name)) + object.__setattr__(self, name, value) + return value + + def __setattr__(self, name: str, value): + cdll = object.__getattribute__(self, "_cdll") + setattr(cdll, self._translate_name(name), value) + + def __getitem__(self, name: str): + cdll = object.__getattribute__(self, "_cdll") + return cdll[self._translate_name(name)] + + +_original_ctypes_CDLL = None + + +def patch_ctypes_cdll() -> None: + """Patch ``ctypes.CDLL`` to wrap MUSA runtime-family libraries.""" + import ctypes + + global _original_ctypes_CDLL + + if 
_original_ctypes_CDLL is not None: + return + + _original_ctypes_CDLL = ctypes.CDLL + + class PatchedCDLL: + """Patched CDLL constructor that wraps MUSA runtime libraries.""" + + def __new__(cls, name, *args, **kwargs): + cdll_instance = _original_ctypes_CDLL(name, *args, **kwargs) + name_str = str(name) if name else "" + if is_musa_runtime_library_path(name_str): + return _CDLLWrapper(cdll_instance, name_str) + return cdll_instance + + ctypes.CDLL = PatchedCDLL + + +def get_original_ctypes_cdll(): + """Return the saved original ``ctypes.CDLL`` constructor, if patched.""" + return _original_ctypes_CDLL diff --git a/src/torchada/_cuda_compat.py b/src/torchada/_cuda_compat.py new file mode 100644 index 0000000..8ae1a1f --- /dev/null +++ b/src/torchada/_cuda_compat.py @@ -0,0 +1,353 @@ +""" +CUDA module compatibility helpers for torchada patching. + +The patch registry in ``_patch.py`` is responsible for when patching happens. +This module owns the CUDA-shaped shims that get installed onto ``torch.musa`` +so the registry stays readable as compatibility coverage grows. +""" + +import sys +from collections import OrderedDict +from types import ModuleType +from typing import Any, Callable, Optional + +from ._runtime import cuda_to_musa_name + +DeviceTranslator = Callable[[Any], Any] + +_MUSA_SYNC_DEBUG_MODE = 0 +_SYNC_DEBUG_MODE_VALUES = { + "default": 0, + "warn": 1, + "error": 2, +} + + +class _CudartWrapper: + """ + Wrapper for CUDA runtime that translates calls to MUSA runtime. + + This allows code like ``torch.cuda.cudart().cudaHostRegister(...)`` to work + on MUSA by translating to ``torch_musa.musart().musaHostRegister(...)``. + Resolved attributes are cached on the wrapper. 
+ """ + + def __init__(self, musart_module): + self._musart = musart_module + + def __getattr__(self, name): + translated_name = cuda_to_musa_name(name) + candidates = [translated_name] + if translated_name != name: + candidates.append(name) + + for candidate in candidates: + try: + value = getattr(self._musart, candidate) + except AttributeError: + continue + object.__setattr__(self, name, value) + return value + + raise AttributeError(f"CUDA runtime has no attribute '{name}'") + + +class _CudaModuleWrapper(ModuleType): + """ + Module wrapper that redirects torch.cuda to torch.musa while preserving + CUDA-only detection APIs that downstream packages rely on. + + ``torch.cuda.is_available()`` intentionally keeps the original CUDA + behavior. Everything else resolves through torch.musa, with a few explicit + remaps for CUDA/MUSA naming differences. + """ + + _NO_REDIRECT = {"is_available"} + _SPECIAL_ATTRS = { + "StreamContext": "core.stream.StreamContext", + } + _REMAP_ATTRS = { + "_device_count_nvml": "device_count", + "nccl": "mccl", + } + _NO_CACHE = set() + + def __init__(self, original_cuda, musa_module): + super().__init__("torch.cuda") + self._original_cuda = original_cuda + self._musa_module = musa_module + self._cudart_wrapper = None + + def cudart(self): + """Return a CUDA runtime wrapper that delegates to MUSA runtime APIs.""" + if self._cudart_wrapper is None: + if hasattr(self._musa_module, "musart"): + self._cudart_wrapper = _CudartWrapper(self._musa_module.musart()) + else: + return self._original_cuda.cudart() + return self._cudart_wrapper + + def __getattr__(self, name): + if name in self._NO_REDIRECT: + value = getattr(self._original_cuda, name) + elif name in self._SPECIAL_ATTRS: + value = self._musa_module + for part in self._SPECIAL_ATTRS[name].split("."): + value = getattr(value, part) + elif name in self._REMAP_ATTRS: + value = getattr(self._musa_module, self._REMAP_ATTRS[name]) + else: + value = getattr(self._musa_module, name) + + if name 
not in self._NO_CACHE: + object.__setattr__(self, name, value) + return value + + def __dir__(self): + attrs = set(dir(self._musa_module)) + attrs.update(self._NO_REDIRECT) + attrs.update(self._SPECIAL_ATTRS.keys()) + attrs.update(self._REMAP_ATTRS.keys()) + attrs.add("cudart") + return list(attrs) + + +def _musa_get_gencode_flags() -> str: + """ + Return CUDA-style gencode flags for MUSA. + + CUDA's implementation returns NVCC flags. Those flags should not be passed + to the MUSA toolchain, so compatibility behavior matches a non-CUDA build + and returns an empty string while preserving the API surface. + """ + return "" + + +def _musa_get_sync_debug_mode() -> int: + """Return the process-local CUDA sync debug mode shim value.""" + return _MUSA_SYNC_DEBUG_MODE + + +def _musa_set_sync_debug_mode(debug_mode) -> None: + """ + Set a process-local CUDA sync debug mode shim value. + + torch_musa does not expose CUDA's C++ sync-debug hooks. Keeping the value in + Python preserves the public setter/getter contract without pretending to + alter MUSA runtime synchronization behavior. 
+ """ + global _MUSA_SYNC_DEBUG_MODE + + if isinstance(debug_mode, str): + if debug_mode not in _SYNC_DEBUG_MODE_VALUES: + raise RuntimeError( + "invalid value of debug_mode, expected one of `default`, `warn`, `error`" + ) + _MUSA_SYNC_DEBUG_MODE = _SYNC_DEBUG_MODE_VALUES[debug_mode] + return None + + if isinstance(debug_mode, int) and debug_mode in _SYNC_DEBUG_MODE_VALUES.values(): + _MUSA_SYNC_DEBUG_MODE = int(debug_mode) + return None + + raise RuntimeError("invalid value of debug_mode, expected one of `default`, `warn`, `error`") + + +def _host_memory_stats() -> OrderedDict: + """Return empty host allocator stats when MUSA exposes no host counters.""" + return OrderedDict() + + +def _host_memory_stats_as_nested_dict() -> dict: + """Return empty nested host allocator stats when MUSA exposes no counters.""" + return {} + + +def _reset_host_memory_stats() -> None: + """No-op reset for unavailable MUSA host allocator counters.""" + return None + + +def _make_memory_cached(torch_module, translate_device: DeviceTranslator): + def memory_cached(device=None) -> int: + """Deprecated CUDA alias for memory_reserved(), mapped to MUSA.""" + return torch_module.musa.memory_reserved(translate_device(device)) + + return memory_cached + + +def _make_max_memory_cached(torch_module, translate_device: DeviceTranslator): + def max_memory_cached(device=None) -> int: + """Deprecated CUDA alias for max_memory_reserved(), mapped to MUSA.""" + return torch_module.musa.max_memory_reserved(translate_device(device)) + + return max_memory_cached + + +def _make_get_stream_from_external(torch_module, translate_device: DeviceTranslator): + def get_stream_from_external(data_ptr: int, device=None): + """Wrap an externally allocated MUSA stream using CUDA-compatible API naming.""" + return torch_module.musa.ExternalStream(data_ptr, device=translate_device(device)) + + return get_stream_from_external + + +def _build_musa_sparse_module(musa_module) -> ModuleType: + """Create a 
torch.cuda.sparse-compatible module backed by MUSA tensor classes.""" + sparse_module = ModuleType("torch.cuda.sparse") + for name in [ + "ByteTensor", + "CharTensor", + "DoubleTensor", + "FloatTensor", + "HalfTensor", + "IntTensor", + "LongTensor", + "ShortTensor", + "BFloat16Tensor", + ]: + if hasattr(musa_module, name): + setattr(sparse_module, name, getattr(musa_module, name)) + return sparse_module + + +def install_cuda_memory_compat( + torch_module, + cpp_ops_module: Optional[Any], + translate_device: DeviceTranslator, +) -> None: + """Install torch.cuda.memory-compatible aliases onto torch.musa.memory.""" + if not hasattr(torch_module.musa, "memory"): + return + + musa_memory_module = torch_module.musa.memory + if musa_memory_module is None: + return + + sys.modules["torch.cuda.memory"] = musa_memory_module + + if hasattr(musa_memory_module, "MUSAPluggableAllocator"): + musa_memory_module.CUDAPluggableAllocator = musa_memory_module.MUSAPluggableAllocator + torch_module.musa.CUDAPluggableAllocator = musa_memory_module.MUSAPluggableAllocator + + memory_aliases = { + "memory_cached": _make_memory_cached(torch_module, translate_device), + "max_memory_cached": _make_max_memory_cached(torch_module, translate_device), + "host_memory_stats": _host_memory_stats, + "host_memory_stats_as_nested_dict": _host_memory_stats_as_nested_dict, + "reset_accumulated_host_memory_stats": _reset_host_memory_stats, + "reset_peak_host_memory_stats": _reset_host_memory_stats, + } + for name, func in memory_aliases.items(): + if not hasattr(torch_module.musa, name): + setattr(torch_module.musa, name, func) + if not hasattr(musa_memory_module, name): + setattr(musa_memory_module, name, func) + + if cpp_ops_module is None: + return + + for func_name in [ + "_cuda_beginAllocateCurrentThreadToPool", + "_cuda_endAllocateToPool", + "_cuda_releasePool", + ]: + func = getattr(cpp_ops_module, func_name, None) + if func is not None: + setattr(musa_memory_module, func_name, func) + + +def 
install_cuda_module_aliases(torch_module) -> None: + """Register import aliases for CUDA submodules that map to MUSA modules.""" + musa_module = torch_module.musa + + if hasattr(musa_module, "amp"): + sys.modules["torch.cuda.amp"] = musa_module.amp + + if hasattr(musa_module, "graphs"): + sys.modules["torch.cuda.graphs"] = musa_module.graphs + + if hasattr(musa_module, "MUSAGraph") and not hasattr(musa_module, "CUDAGraph"): + musa_module.CUDAGraph = musa_module.MUSAGraph + + if hasattr(musa_module, "mccl"): + sys.modules["torch.cuda.nccl"] = musa_module.mccl + if not hasattr(musa_module, "nccl"): + musa_module.nccl = musa_module.mccl + + try: + import torch_musa.core.stream as musa_stream_module + + sys.modules["torch.cuda.streams"] = musa_stream_module + if not hasattr(musa_module, "streams"): + musa_module.streams = musa_stream_module + except ImportError: + pass + + if not hasattr(musa_module, "sparse"): + musa_module.sparse = _build_musa_sparse_module(musa_module) + sys.modules["torch.cuda.sparse"] = musa_module.sparse + + if hasattr(musa_module, "profiler"): + sys.modules["torch.cuda.profiler"] = musa_module.profiler + + try: + from .cuda import nvtx as nvtx_stub + + sys.modules["torch.cuda.nvtx"] = nvtx_stub + musa_module.nvtx = nvtx_stub + except ImportError: + pass + + if hasattr(musa_module, "random"): + sys.modules["torch.cuda.random"] = musa_module.random + else: + try: + from .cuda import random as random_stub + + sys.modules["torch.cuda.random"] = random_stub + musa_module.random = random_stub + except ImportError: + pass + + +def install_cuda_public_api_shims(torch_module, translate_device: DeviceTranslator) -> None: + """Install top-level CUDA API shims that torch_musa does not expose.""" + musa_module = torch_module.musa + + try: + from torch_musa.core._lazy_init import _lazy_call + + if not hasattr(musa_module, "_lazy_call"): + musa_module._lazy_call = _lazy_call + except ImportError: + pass + + if not hasattr(musa_module, "_is_compiled"): + 
musa_module._is_compiled = lambda: True + + if not hasattr(musa_module, "has_half"): + musa_module.has_half = True + if not hasattr(musa_module, "has_magma"): + musa_module.has_magma = False + if hasattr(musa_module, "_lazy_init") and not hasattr(musa_module, "init"): + musa_module.init = musa_module._lazy_init + if not hasattr(musa_module, "get_stream_from_external"): + musa_module.get_stream_from_external = _make_get_stream_from_external( + torch_module, translate_device + ) + + try: + from torch_musa.core._lazy_init import default_generators + + if not hasattr(musa_module, "default_generators"): + musa_module.default_generators = default_generators + except ImportError: + pass + + if not hasattr(musa_module, "get_gencode_flags"): + musa_module.get_gencode_flags = _musa_get_gencode_flags + if not hasattr(musa_module, "get_sync_debug_mode"): + musa_module.get_sync_debug_mode = _musa_get_sync_debug_mode + if not hasattr(musa_module, "set_sync_debug_mode"): + musa_module.set_sync_debug_mode = _musa_set_sync_debug_mode diff --git a/src/torchada/_device_compat.py b/src/torchada/_device_compat.py new file mode 100644 index 0000000..3572c00 --- /dev/null +++ b/src/torchada/_device_compat.py @@ -0,0 +1,301 @@ +""" +Device and tensor-constructor compatibility helpers. + +This module owns the low-level CUDA-device-to-MUSA-device translation used by +the patch registry. Keeping the state here avoids mixing global patch state +with the orchestration code in ``_patch.py``. +""" + +import functools +from typing import Any, Callable, Optional + +import torch + +from ._platform import is_musa_platform + +_device_str_cache = {} +_is_musa_platform_cached: Optional[bool] = None + +_original_torch_device = None +_original_torch_generator = None +_original_c_generator = None + + +def _translate_device(device: Any) -> Any: + """ + Translate ``cuda`` device references to ``musa`` on MUSA platforms. 
+ + Strings are cached because this sits on hot paths such as ``Tensor.to`` and + tensor factory calls. + """ + global _is_musa_platform_cached + + if _is_musa_platform_cached is None: + _is_musa_platform_cached = is_musa_platform() + + if not _is_musa_platform_cached or device is None: + return device + + if isinstance(device, str): + if device in _device_str_cache: + return _device_str_cache[device] + if device == "cuda" or device.startswith("cuda:"): + result = device.replace("cuda", "musa") + else: + result = device + _device_str_cache[device] = result + return result + + if isinstance(device, torch.device): + if device.type == "cuda": + return torch.device("musa", device.index) + return device + + return device + + +def _wrap_to_method(original_to: Callable) -> Callable: + """Wrap ``Tensor.to`` to translate CUDA device arguments.""" + + @functools.wraps(original_to) + def wrapped_to(self, *args, **kwargs): + if args: + first_arg = args[0] + if isinstance(first_arg, (str, torch.device)): + args = (_translate_device(first_arg),) + args[1:] + elif isinstance(first_arg, torch.dtype) and len(args) >= 2: + args = (first_arg, _translate_device(args[1])) + args[2:] + + if "device" in kwargs: + kwargs["device"] = _translate_device(kwargs["device"]) + + return original_to(self, *args, **kwargs) + + return wrapped_to + + +def _musa_device_spec(device: Any) -> Any: + """Return a device spec suitable for ``.to`` when ``.musa`` is unavailable.""" + if device is None: + return "musa" + if isinstance(device, int): + return f"musa:{device}" + return device + + +def _wrap_tensor_cuda(original_cuda: Callable) -> Callable: + """Wrap ``Tensor.cuda`` to use MUSA on MUSA platforms.""" + _is_musa = is_musa_platform() + + @functools.wraps(original_cuda) + def wrapped_cuda(self, device=None, non_blocking=False, memory_format=torch.preserve_format): + if _is_musa: + device = _translate_device(device) + if hasattr(self, "musa"): + kwargs = {"device": device, "non_blocking": non_blocking} 
+ if memory_format is not torch.preserve_format: + kwargs["memory_format"] = memory_format + return self.musa(**kwargs) + target_device = _musa_device_spec(device) + return self.to( + target_device, + non_blocking=non_blocking, + memory_format=memory_format, + ) + + kwargs = {"device": device, "non_blocking": non_blocking} + if memory_format is not torch.preserve_format: + kwargs["memory_format"] = memory_format + return original_cuda(self, **kwargs) + + return wrapped_cuda + + +def _wrap_module_cuda(original_cuda: Callable) -> Callable: + """Wrap ``nn.Module.cuda`` to use MUSA on MUSA platforms.""" + _is_musa = is_musa_platform() + + @functools.wraps(original_cuda) + def wrapped_cuda(self, device=None): + if _is_musa: + device = _translate_device(device) + if hasattr(self, "musa"): + return self.musa(device=device) + target_device = _musa_device_spec(device) + return self.to(target_device) + return original_cuda(self, device=device) + + return wrapped_cuda + + +class _DeviceFactoryMeta(type): + """Metaclass that keeps ``isinstance(x, torch.device)`` working.""" + + def __instancecheck__(cls, instance): + if _original_torch_device is not None: + return isinstance(instance, _original_torch_device) + return False + + def __subclasscheck__(cls, subclass): + if _original_torch_device is not None: + return issubclass(subclass, _original_torch_device) + return False + + +class DeviceFactoryWrapper(metaclass=_DeviceFactoryMeta): + """ + Drop-in ``torch.device`` factory that translates CUDA devices to MUSA. 
+ """ + + _original = None + + def __new__(cls, device=None, index=None, *, type=None): + original = cls._original + if original is None: + raise RuntimeError("DeviceFactoryWrapper not initialized") + + if type is not None: + device = type + + if isinstance(device, original): + if device.type == "cuda": + index = device.index if index is None else index + device = "musa" + else: + return device + + if isinstance(device, str): + device = _translate_device(device) + + if index is not None: + return original(device, index) + if device is not None: + return original(device) + return original() + + +def patch_torch_device(torch_module=torch) -> None: + """Patch ``torch.device`` with ``DeviceFactoryWrapper``.""" + global _original_torch_device + + if _original_torch_device is not None: + return + + _original_torch_device = torch_module.device + DeviceFactoryWrapper._original = _original_torch_device + torch_module.device = DeviceFactoryWrapper + + +class _GeneratorMeta(type): + """Metaclass that preserves ``isinstance(x, torch.Generator)`` behavior.""" + + def __instancecheck__(cls, instance): + if _original_c_generator is not None: + return isinstance(instance, _original_c_generator) + return False + + def __subclasscheck__(cls, subclass): + if _original_c_generator is not None and subclass is _original_c_generator: + return True + return super().__subclasscheck__(subclass) + + +class GeneratorWrapper(metaclass=_GeneratorMeta): + """Wrapper for ``torch.Generator`` that translates CUDA devices to MUSA.""" + + _original = None + + def __new__(cls, device=None): + original = cls._original + if original is None: + raise RuntimeError("GeneratorWrapper not initialized") + if device is not None: + device = _translate_device(device) + return original(device=device) + + +def patch_torch_generator(torch_module=torch) -> None: + """Patch ``torch.Generator`` with ``GeneratorWrapper``.""" + global _original_torch_generator, _original_c_generator + + if _original_torch_generator is 
not None: + return + + _original_torch_generator = torch_module.Generator + _original_c_generator = torch_module._C.Generator + + GeneratorWrapper._original = _original_torch_generator + GeneratorWrapper.__doc__ = _original_torch_generator.__doc__ + + torch_module.Generator = GeneratorWrapper + + +def _wrap_factory_function(original_fn: Callable) -> Callable: + """Wrap tensor factory functions to translate ``device=`` arguments.""" + + @functools.wraps(original_fn) + def wrapped_fn(*args, **kwargs): + if "device" in kwargs: + kwargs["device"] = _translate_device(kwargs["device"]) + return original_fn(*args, **kwargs) + + return wrapped_fn + + +_FACTORY_FUNCTIONS = [ + "tensor", + "as_tensor", + "asarray", + "empty", + "zeros", + "ones", + "full", + "rand", + "randn", + "randint", + "randperm", + "normal", + "arange", + "range", + "linspace", + "logspace", + "eye", + "empty_strided", + "empty_permuted", + "from_file", + "empty_like", + "zeros_like", + "ones_like", + "full_like", + "rand_like", + "randn_like", + "randint_like", + "sparse_coo_tensor", + "sparse_csr_tensor", + "sparse_csc_tensor", + "sparse_bsr_tensor", + "sparse_bsc_tensor", + "sparse_compressed_tensor", + "tril_indices", + "triu_indices", + "bartlett_window", + "blackman_window", + "hamming_window", + "hann_window", + "kaiser_window", +] + + +def get_original_torch_device(): + """Return the saved original ``torch.device`` factory, if patched.""" + return _original_torch_device + + +def get_original_torch_generator(): + """Return the saved original ``torch.Generator`` factory, if patched.""" + return _original_torch_generator + + +def get_original_c_generator(): + """Return the saved original C generator type, if patched.""" + return _original_c_generator diff --git a/src/torchada/_mapping.py b/src/torchada/_mapping.py index 1d22a5e..cbf8b69 100644 --- a/src/torchada/_mapping.py +++ b/src/torchada/_mapping.py @@ -1,15 +1,15 @@ """ -CUDA to MUSA mapping rules for source code porting. 
+CUDA-to-MUSA source-porting mappings. -This module contains the comprehensive mapping dictionary for converting -CUDA-specific symbols to their MUSA equivalents during extension builds. +The extension builder applies these rules before handing CUDA-oriented source +trees to the MUSA toolchain. """ -# Extension file suffix mappings -# Convert .cu/.cuh to .mu/.muh so torch_musa's musa_compile rule works correctly -# The musa_compile rule in torch_musa only adds -x musa for .mu/.muh files -# Without this conversion, .cu files would be treated as CUDA files by mcc -# and the --offload-arch=mp_XX flag would fail with clang's CUDA support +# File suffixes that require MUSA compilation mode. +# +# torch_musa's musa_compile rule only adds ``-x musa`` for ``.mu`` and ``.muh`` +# files. Converting CUDA suffixes first keeps ``mcc`` from treating those files +# as NVIDIA CUDA sources. EXT_REPLACED_MAPPING = { "cuh": "muh", "cu": "mu", @@ -18,11 +18,9 @@ "cxx": "cxx", } -# Comprehensive CUDA to MUSA symbol mapping +# CUDA-to-MUSA symbol mappings grouped by source family. _MAPPING_RULE = { - # ========================================================================= - # ATen mappings - # ========================================================================= + # ATen mappings. "#include ": '#include "torch_musa/share/generated_cuda_compatible/include/ATen/musa/Atomic.muh"', "#include ": '#include "torch_musa/csrc/aten/musa/MUSAContext.h"', "#include ": '#include "torch_musa/csrc/aten/musa/MUSADtype.muh"', @@ -31,12 +29,10 @@ "#include ": '#include "torch_musa/csrc/aten/musa/UnpackRaw.muh"', "#include ": '#include "torch_musa/csrc/aten/musa/Exceptions.h"', "at::cuda": "at::musa", - # File extension mappings for include statements (.cuh -> .muh) + # Include suffix mappings for CUDA headers. 
'.cuh"': '.muh"', ".cuh>": ".muh>", - # ========================================================================= - # C10 mappings - # ========================================================================= + # C10 mappings. "#include ": '#include "torch_musa/csrc/core/MUSAException.h"', "#include ": '#include "torch_musa/csrc/core/MUSAGuard.h"', "#include ": '#include "torch_musa/csrc/core/MUSAStream.h"', @@ -45,7 +41,7 @@ "C10_CUDA_CHECK": "C10_MUSA_CHECK", "C10_CUDA_ERROR_HANDLED": "C10_MUSA_ERROR_HANDLED", "C10_CUDA_IGNORE_ERROR": "C10_MUSA_IGNORE_ERROR", - # Header file mappings (must come before generic c10/cuda mapping) + # Header mappings must come before the generic c10/cuda path mapping. "c10/cuda/CUDAException.h": "c10/musa/MUSAException.h", "c10/cuda/CUDAStream.h": "c10/musa/MUSAStream.h", "c10/cuda/CUDAGuard.h": "c10/musa/MUSAGuard.h", @@ -54,27 +50,22 @@ "c10/cuda/CUDACachingAllocator.h": "c10/musa/MUSACachingAllocator.h", "": '"torch_musa/csrc/core/MUSAStream.h"', "c10/cuda": "c10/musa", - # ========================================================================= - # CUDA standard library - # ========================================================================= + # CUDA standard library mappings. "cuda/std": "musa/std", "": "", " muBLAS - # ========================================================================= + # cuBLAS to muBLAS mappings. "cublas": "mublas", "CUBLAS": "MUBLAS", "cublasHandle_t": "mublasHandle_t", @@ -93,7 +84,7 @@ "cublasDgemmBatched": "mublasDgemmBatched", "cublasSgemmStridedBatched": "mublasSgemmStridedBatched", "cublasDgemmStridedBatched": "mublasDgemmStridedBatched", - # cuBLASLt + # cuBLASLt mappings. 
"CUBLASLT_MATMUL_DESC_A_SCALE_POINTER": "MUBLASLT_MATMUL_DESC_A_SCALE_POINTER", "CUBLASLT_MATMUL_DESC_B_SCALE_POINTER": "MUBLASLT_MATMUL_DESC_B_SCALE_POINTER", "CUBLASLT_MATMUL_DESC_FAST_ACCUM": "MUBLASLT_MATMUL_DESC_FAST_ACCUM", @@ -131,9 +122,7 @@ "cublasLtMatrixLayoutOpaque_t": "mublasLtMatrixLayoutOpaque_t", "cublasLtMatrixLayoutSetAttribute": "mublasLtMatrixLayoutSetAttribute", "cublasLtMatrixLayout_t": "mublasLtMatrixLayout_t", - # ========================================================================= - # cuRAND -> muRAND - # ========================================================================= + # cuRAND to muRAND mappings. "curand": "murand", "CURAND": "MURAND", "curandState": "murandState", @@ -143,17 +132,13 @@ "curand_uniform4": "murand_uniform4", "curand_normal": "murand_normal", "curand_normal4": "murand_normal4", - # ========================================================================= - # cuDNN -> muDNN - # ========================================================================= + # cuDNN to muDNN mappings. "cudnn": "mudnn", "CUDNN": "MUDNN", "cudnnHandle_t": "mudnnHandle_t", "cudnnCreate": "mudnnCreate", "cudnnDestroy": "mudnnDestroy", - # ========================================================================= - # CUDA Runtime API - # ========================================================================= + # CUDA runtime API mappings. "cudaMalloc": "musaMalloc", "cudaFree": "musaFree", "cudaMemcpy": "musaMemcpy", @@ -167,7 +152,7 @@ "cudaGetDeviceCount": "musaGetDeviceCount", "cudaGetDeviceProperties": "musaGetDeviceProperties", "cudaDeviceGetAttribute": "musaDeviceGetAttribute", - # Stream/Event + # Stream and event mappings. 
"cudaStream_t": "musaStream_t", "cudaEvent_t": "musaEvent_t", "cudaStreamCreate": "musaStreamCreate", @@ -178,37 +163,35 @@ "cudaEventSynchronize": "musaEventSynchronize", "cudaEventElapsedTime": "musaEventElapsedTime", "cudaStreamWaitEvent": "musaStreamWaitEvent", - # Error handling + # Error handling mappings. "cudaError_t": "musaError_t", "cudaSuccess": "musaSuccess", "cudaGetLastError": "musaGetLastError", "cudaGetErrorString": "musaGetErrorString", "cudaPeekAtLastError": "musaPeekAtLastError", - # Memory types + # Memory type mappings. "cudaMemcpyHostToDevice": "musaMemcpyHostToDevice", "cudaMemcpyDeviceToHost": "musaMemcpyDeviceToHost", "cudaMemcpyDeviceToDevice": "musaMemcpyDeviceToDevice", "cudaMemcpyHostToHost": "musaMemcpyHostToHost", - # Constants - memory allocation + # Memory allocation constants. "cudaFuncAttributeMaxDynamicSharedMemorySize": "musaFuncAttributeMaxDynamicSharedMemorySize", - # Unsupported APIs - map to no-ops or placeholders + # Unsupported runtime APIs mapped to no-ops. "cudaGridDependencySynchronize()": "((void)0)", "cudaTriggerProgrammaticLaunchCompletion()": "((void)0)", - # ========================================================================= - # Data types - # ========================================================================= - # BFloat16 types (Note: __half and half are the same in MUSA, no mapping needed) + # Data type mappings. + # BFloat16 types; ``__half`` and ``half`` use the same spelling in MUSA. "__nv_bfloat16": "__mt_bfloat16", "__nv_bfloat162": "__mt_bfloat162", "__nv_half": "__half", "nv_bfloat16": "__mt_bfloat16", "nv_bfloat162": "__mt_bfloat162", "nv_half": "__half", - # FP8 data types - constants + # FP8 data type constants. "__NV_E4M3": "__MT_E4M3", "__NV_E5M2": "__MT_E5M2", "__NV_SATFINITE": "__MT_SATFINITE", - # FP8 data types - types (alphabetical) + # FP8 data types. 
"__nv_fp8_e4m3": "__mt_fp8_e4m3", "__nv_fp8_e5m2": "__mt_fp8_e5m2", "__nv_fp8_interpretation_t": "__mt_fp8_interpretation_t", @@ -219,19 +202,17 @@ "__nv_fp8x4_e4m3": "__mt_fp8x4_e4m3", "__nv_fp8x4_e5m2": "__mt_fp8x4_e5m2", "__nv_fp8x4_storage_t": "__mt_fp8x4_storage_t", - # FP8 data types - conversion functions (alphabetical) + # FP8 conversion functions. "__nv_cvt_bfloat16raw_to_fp8": "__musa_cvt_bfloat16raw_to_fp8", "__nv_cvt_float2_to_fp8x2": "__musa_cvt_float2_to_fp8x2", "__nv_cvt_float_to_fp8": "__musa_cvt_float_to_fp8", "__nv_cvt_fp8_to_halfraw": "__musa_cvt_fp8_to_halfraw", "__nv_cvt_fp8x2_to_halfraw2": "__musa_cvt_fp8x2_to_halfraw2", - # FP8 data types - includes and enums + # FP8 includes and enums. "#include ": "#include ", "CUDA_R_8F_E4M3": "MUSA_R_8F_E4M3", "CUDA_R_8F_E5M2": "MUSA_R_8F_E5M2", - # ========================================================================= - # Cutlass -> Mutlass - # ========================================================================= + # Cutlass to Mutlass mappings. '#include "cutlass/array.h"': "#include ", "#include ": "#include ", "#include ": "#include ", @@ -243,14 +224,9 @@ "CUTLASS": "MUTLASS", "cutlass/": "mutlass/", "cutlass::": "mutlass::", - # ========================================================================= - # Thrust - # ========================================================================= - # CUB - MUSA provides cub directly, no conversion needed + # Thrust mappings; CUB is provided directly by MUSA. "thrust::cuda": "thrust::musa", - # ========================================================================= - # NCCL -> MCCL - # ========================================================================= + # NCCL to MCCL mappings. 
"nccl": "mccl", "NCCL": "MCCL", "ncclComm_t": "mcclComm_t", @@ -258,25 +234,19 @@ "ncclRedOp_t": "mcclRedOp_t", "ncclResult_t": "mcclResult_t", "ncclSuccess": "mcclSuccess", - # ========================================================================= - # cuSPARSE -> muSPARSE - # ========================================================================= + # cuSPARSE to muSPARSE mappings. "cusparse": "musparse", "CUSPARSE": "MUSPARSE", "cusparseHandle_t": "musparseHandle_t", "cusparseCreate": "musparseCreate", "cusparseDestroy": "musparseDestroy", - # ========================================================================= - # cuSOLVER -> muSOLVER - # ========================================================================= + # cuSOLVER to muSOLVER mappings. "cusolver": "musolver", "CUSOLVER": "MUSOLVER", "cusolverDnHandle_t": "musolverDnHandle_t", "cusolverDnCreate": "musolverDnCreate", "cusolverDnDestroy": "musolverDnDestroy", - # ========================================================================= - # cuFFT -> muFFT - # ========================================================================= + # cuFFT to muFFT mappings. "cufft": "mufft", "CUFFT": "MUFFT", "cufftHandle": "mufftHandle", @@ -286,9 +256,7 @@ "cufftExecC2C": "mufftExecC2C", "cufftExecR2C": "mufftExecR2C", "cufftExecC2R": "mufftExecC2R", - # ========================================================================= - # CUDA device attributes - # ========================================================================= + # CUDA device attribute mappings. 
"cudaDevAttrMaxThreadsPerBlock": "musaDevAttrMaxThreadsPerBlock", "cudaDevAttrMaxBlockDimX": "musaDevAttrMaxBlockDimX", "cudaDevAttrMaxBlockDimY": "musaDevAttrMaxBlockDimY", @@ -299,9 +267,7 @@ "cudaDevAttrMaxSharedMemoryPerBlock": "musaDevAttrMaxSharedMemoryPerBlock", "cudaDevAttrWarpSize": "musaDevAttrWarpSize", "cudaDevAttrMultiProcessorCount": "musaDevAttrMultiProcessorCount", - # ========================================================================= - # PyTorch CUDA utilities - # ========================================================================= + # PyTorch CUDA utility mappings. "getCurrentCUDAStream": "getCurrentMUSAStream", "getDefaultCUDAStream": "getDefaultMUSAStream", "CUDAStream": "MUSAStream", @@ -309,17 +275,13 @@ "OptionalCUDAGuard": "OptionalMUSAGuard", "CUDAStreamGuard": "MUSAStreamGuard", "CUDAEvent": "MUSAEvent", - # ========================================================================= - # CUDA header includes - # ========================================================================= + # CUDA header include mappings. "cuda_runtime.h": "musa_runtime.h", "cuda_runtime_api.h": "musa_runtime_api.h", "cuda.h": "musa.h", "cuda_fp16.h": "musa_fp16.h", "cuda_bf16.h": "musa_bf16.h", - # ========================================================================= - # Additional CUDA runtime functions - # ========================================================================= + # Additional CUDA runtime function mappings. "cudaHostAlloc": "musaHostAlloc", "cudaHostFree": "musaHostFree", "cudaMallocHost": "musaMallocHost", @@ -334,7 +296,7 @@ "cudaMemGetInfo": "musaMemGetInfo", "cudaMemPrefetchAsync": "musaMemPrefetchAsync", "cudaPointerGetAttributes": "musaPointerGetAttributes", - # Stream flags and types + # Stream flags and types. 
"cudaStreamDefault": "musaStreamDefault", "cudaStreamNonBlocking": "musaStreamNonBlocking", "cudaStreamCreateWithFlags": "musaStreamCreateWithFlags", @@ -342,13 +304,13 @@ "cudaStreamQuery": "musaStreamQuery", "cudaStreamGetPriority": "musaStreamGetPriority", "cudaStreamGetFlags": "musaStreamGetFlags", - # Event flags + # Event flags. "cudaEventDefault": "musaEventDefault", "cudaEventBlockingSync": "musaEventBlockingSync", "cudaEventDisableTiming": "musaEventDisableTiming", "cudaEventCreateWithFlags": "musaEventCreateWithFlags", "cudaEventQuery": "musaEventQuery", - # Memory flags + # Memory flags. "cudaHostAllocDefault": "musaHostAllocDefault", "cudaHostAllocPortable": "musaHostAllocPortable", "cudaHostAllocMapped": "musaHostAllocMapped", @@ -356,9 +318,7 @@ "cudaMemoryTypeHost": "musaMemoryTypeHost", "cudaMemoryTypeDevice": "musaMemoryTypeDevice", "cudaMemoryTypeManaged": "musaMemoryTypeManaged", - # ========================================================================= - # Device management - # ========================================================================= + # Device management mappings. "cudaDeviceReset": "musaDeviceReset", "cudaDeviceSetCacheConfig": "musaDeviceSetCacheConfig", "cudaDeviceGetCacheConfig": "musaDeviceGetCacheConfig", @@ -369,27 +329,23 @@ "cudaDeviceCanAccessPeer": "musaDeviceCanAccessPeer", "cudaDeviceEnablePeerAccess": "musaDeviceEnablePeerAccess", "cudaDeviceDisablePeerAccess": "musaDeviceDisablePeerAccess", - # Occupancy + # Occupancy mappings. "cudaOccupancyMaxActiveBlocksPerMultiprocessor": "musaOccupancyMaxActiveBlocksPerMultiprocessor", "cudaOccupancyMaxPotentialBlockSize": "musaOccupancyMaxPotentialBlockSize", - # Device properties + # Device property mappings. 
"cudaDeviceProp": "musaDeviceProp", "cudaFuncAttributes": "musaFuncAttributes", "cudaFuncGetAttributes": "musaFuncGetAttributes", "cudaFuncSetAttribute": "musaFuncSetAttribute", "cudaFuncSetCacheConfig": "musaFuncSetCacheConfig", - # ========================================================================= - # CUDA texture/surface - # ========================================================================= + # CUDA texture and surface mappings. "cudaTextureObject_t": "musaTextureObject_t", "cudaSurfaceObject_t": "musaSurfaceObject_t", "cudaCreateTextureObject": "musaCreateTextureObject", "cudaDestroyTextureObject": "musaDestroyTextureObject", "cudaCreateSurfaceObject": "musaCreateSurfaceObject", "cudaDestroySurfaceObject": "musaDestroySurfaceObject", - # ========================================================================= - # Additional cuBLAS functions - # ========================================================================= + # Additional cuBLAS function mappings. "cublasSetMathMode": "mublasSetMathMode", "cublasGetMathMode": "mublasGetMathMode", "CUBLAS_DEFAULT_MATH": "MUBLAS_DEFAULT_MATH", @@ -398,9 +354,7 @@ "cublasLtDestroy": "mublasLtDestroy", "cublasLtHandle_t": "mublasLtHandle_t", "cublasLtMatmul": "mublasLtMatmul", - # ========================================================================= - # Additional cuDNN functions - # ========================================================================= + # Additional cuDNN function mappings. 
"cudnnStatus_t": "mudnnStatus_t", "cudnnSetStream": "mudnnSetStream", "cudnnGetStream": "mudnnGetStream", @@ -415,9 +369,7 @@ "cudnnDestroyTensorDescriptor": "mudnnDestroyTensorDescriptor", "cudnnSetTensor4dDescriptor": "mudnnSetTensor4dDescriptor", "cudnnSetTensorNdDescriptor": "mudnnSetTensorNdDescriptor", - # ========================================================================= - # Additional NCCL functions - # ========================================================================= + # Additional NCCL function mappings. "ncclCommInitRank": "mcclCommInitRank", "ncclCommInitAll": "mcclCommInitAll", "ncclCommDestroy": "mcclCommDestroy", @@ -435,9 +387,7 @@ "ncclRecv": "mcclRecv", "ncclGetUniqueId": "mcclGetUniqueId", "ncclUniqueId": "mcclUniqueId", - # ========================================================================= - # CUDA intrinsics and math functions (no mapping needed) - # ========================================================================= + # CUDA intrinsics and math functions intentionally have no mappings. # Math intrinsics: __shfl_sync, __shfl_xor_sync, __shfl_up_sync, __shfl_down_sync, # __ballot_sync, __any_sync, __all_sync, __syncthreads, __syncwarp, # __threadfence, __threadfence_block, __threadfence_system @@ -446,16 +396,12 @@ # Math functions: __float2half, __half2float, __float2half_rn, __float22half2_rn, # __half22float2, __hadd, __hsub, __hmul, __hdiv, __hfma, # __hadd2, __hsub2, __hmul2, __hfma2 - # ========================================================================= - # Common macros - # ========================================================================= + # Common macro mappings. 
"CUDA_KERNEL_LOOP": "MUSA_KERNEL_LOOP", "CUDA_1D_KERNEL_LOOP": "MUSA_1D_KERNEL_LOOP", "CUDA_2D_KERNEL_LOOP": "MUSA_2D_KERNEL_LOOP", "CUDA_NUM_THREADS": "MUSA_NUM_THREADS", - # ========================================================================= - # PyTorch C++ API - # ========================================================================= + # PyTorch C++ API mappings. "torch::cuda::getCurrentCUDAStream": "torch::musa::getCurrentMUSAStream", "torch::cuda::getDefaultCUDAStream": "torch::musa::getDefaultMUSAStream", "torch::cuda::getStreamFromPool": "torch::musa::getStreamFromPool", @@ -463,10 +409,8 @@ "cudaDeviceIndex": "musaDeviceIndex", "CUDADeviceIndex": "MUSADeviceIndex", "getCurrentCUDABlasHandle": "getCurrentMUSABlasHandle", - # ========================================================================= - # CUDA driver API -> MUSA driver API - # ========================================================================= - # Types (alphabetical) + # CUDA driver API to MUSA driver API mappings. + # Driver API types. "CUcontext": "MUcontext", "CUdevice": "MUdevice", "CUdeviceptr": "MUdeviceptr", @@ -478,7 +422,7 @@ "CUmodule": "MUmodule", "CUresult": "MUresult", "CUstream": "MUstream", - # Functions (alphabetical) + # Driver API functions. "cuCtxGetCurrent": "muCtxGetCurrent", "cuCtxSetCurrent": "muCtxSetCurrent", "cuDeviceGet": "muDeviceGet", @@ -496,7 +440,7 @@ "cuMemSetAccess": "muMemSetAccess", "cuMemUnmap": "muMemUnmap", "cuPointerGetAttribute": "muPointerGetAttribute", - # Constants - memory allocation + # Driver API memory allocation constants. 
"CU_MEM_ACCESS_FLAGS_PROT_READWRITE": "MU_MEM_ACCESS_FLAGS_PROT_READWRITE", "CU_MEM_ALLOC_GRANULARITY_MINIMUM": "MU_MEM_ALLOC_GRANULARITY_MINIMUM", "CU_MEM_ALLOCATION_COMP_NONE": "MU_MEM_ALLOCATION_COMP_NONE", @@ -506,48 +450,37 @@ "CU_MEM_LOCATION_TYPE_DEVICE": "MU_MEM_LOCATION_TYPE_DEVICE", "CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED": "MU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_MUSA_VMM_SUPPORTED", "CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED": "MU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED", - # Constants - pointer attributes + # Driver API pointer attribute constants. "CU_POINTER_ATTRIBUTE_CONTEXT": "MU_POINTER_ATTRIBUTE_CONTEXT", "CU_POINTER_ATTRIBUTE_DEVICE_POINTER": "MU_POINTER_ATTRIBUTE_DEVICE_POINTER", "CU_POINTER_ATTRIBUTE_HOST_POINTER": "MU_POINTER_ATTRIBUTE_HOST_POINTER", "CU_POINTER_ATTRIBUTE_MEMORY_TYPE": "MU_POINTER_ATTRIBUTE_MEMORY_TYPE", "CU_POINTER_ATTRIBUTE_RANGE_SIZE": "MU_POINTER_ATTRIBUTE_RANGE_SIZE", "CU_POINTER_ATTRIBUTE_RANGE_START_ADDR": "MU_POINTER_ATTRIBUTE_RANGE_START_ADDR", - # Error codes + # Driver API error codes. "CUDA_ERROR_INVALID_VALUE": "MUSA_ERROR_INVALID_VALUE", "CUDA_ERROR_NOT_INITIALIZED": "MUSA_ERROR_NOT_INITIALIZED", "CUDA_ERROR_NOT_PERMITTED": "MUSA_ERROR_NOT_PERMITTED", "CUDA_ERROR_NOT_SUPPORTED": "MUSA_ERROR_NOT_SUPPORTED", "CUDA_ERROR_OUT_OF_MEMORY": "MUSA_ERROR_OUT_OF_MEMORY", "CUDA_SUCCESS": "MUSA_SUCCESS", - # ========================================================================= - # THC headers - # ========================================================================= + # THC header mappings. "#include ": "#include ", - # ========================================================================= - # MCC compiler fixes - # Template keyword required for dependent template calls in mcc - # ========================================================================= + # MCC compiler fixes for dependent template calls. 
".FlagHeads": ".template FlagHeads", ".InclusiveSum": ".template InclusiveSum", ".Reduce": ".template Reduce", ".Sum": ".template Sum", "::cast": "::template cast", "SCHEDULER::execute": "SCHEDULER::template execute", - # ========================================================================= - # CUDA launch attributes - # ========================================================================= + # CUDA launch attribute mappings. "cudaLaunchAttribute": "musaLaunchAttribute", "cudaLaunchAttributeProgrammaticStreamSerialization": "musaLaunchAttributeIgnore", "cudaLaunchConfig_t": "musaLaunchConfig_t", - # ========================================================================= - # FlashInfer specific mappings - # ========================================================================= + # FlashInfer-specific mappings. ".is_cuda()": ".is_privateuseone()", "->philox_cuda_state": "->philox_musa_state", - # ========================================================================= - # CUDA arch guards - # ========================================================================= + # CUDA arch guard mappings. "__CUDA_ARCH__ >= 800": "__MUSA_ARCH__ >= 220", "(__CUDA_ARCH__ < 800)": "(__MUSA_ARCH__ < 220)", "(__CUDA_ARCH__ >= 900)": "(__MUSA_ARCH__ >= 310)", @@ -555,12 +488,9 @@ "#include ": "#include ", "#include ": "#include ", "compute_capacity.first >= 8": "compute_capacity.first >= 3", - # ========================================================================= - # FlashInfer math functions - # Replace math.cuh with MUSA fast math intrinsics - # ========================================================================= + # FlashInfer math function mappings. '#include "math.cuh"': """ -// MUSA fast math intrinsics (replacing flashinfer::math functions) +// MUSA fast math intrinsics replacing flashinfer::math functions. 
__device__ __forceinline__ float fast_rsqrtf(float x) { return __frsqrt_rn(x); } __device__ __forceinline__ float fast_rcp(float x) { return __frcp_rn(x); } """, @@ -568,19 +498,16 @@ "math::rsqrt(smem[0] / float(d) + eps);": "fast_rsqrtf(smem[0] / float(d) + eps);", "math::ptx_rcp(max(sum_low, 1e-8));": "fast_rcp(max(sum_low, 1e-8));", "math::ptx_rcp(denom);": "fast_rcp(denom);", - # PTX log2/exp2 -> standard math functions + # PTX log2/exp2 mappings. "math::ptx_log2": "log2f", "math::ptx_exp2": "exp2f", - # ========================================================================= - # PTX assembly removal for MUSA - # MUSA doesn't support NVIDIA PTX assembly. We use "if(0)" to skip all - # asm volatile blocks (works for both single-line and multi-line cases). - # The compiler will optimize away the dead code. - # ========================================================================= + # PTX assembly removal for MUSA. + # + # MUSA does not support NVIDIA PTX assembly. Prefixing with ``if(0)`` skips + # both single-line and multi-line asm blocks; the compiler removes the dead + # code afterward. "asm volatile": "if(0) asm volatile", - # ========================================================================= - # MUSA compiler workarounds - # ========================================================================= - # __restrict__ in struct members causes copy issues + # MUSA compiler workarounds. + # ``__restrict__`` in struct members causes copy issues. "const void* __restrict__ ptrs[8]": "const void* ptrs[8]", } diff --git a/src/torchada/_patch.py b/src/torchada/_patch.py index 3113e34..043cbfd 100644 --- a/src/torchada/_patch.py +++ b/src/torchada/_patch.py @@ -1,24 +1,24 @@ """ -Automatic patching module for torchada. +Automatic patch orchestration for torchada. This module patches PyTorch to automatically translate 'cuda' device strings to 'musa' when running on Moore Threads hardware. 
Usage: - import torchada # This applies all patches automatically + import torchada # Apply all patches automatically. import torch - # Then use torch.cuda as normal - it will work on MUSA + # Use torch.cuda APIs normally; they resolve to MUSA on MUSA platforms. torch.cuda.is_available() x = torch.randn(3, 3).cuda() from torch.cuda.amp import autocast, GradScaler - # Distributed training with NCCL also works transparently + # Distributed training with NCCL resolves to MCCL on MUSA platforms. import torch.distributed as dist - dist.init_process_group(backend="nccl") # Uses MCCL on MUSA + dist.init_process_group(backend="nccl") # Uses MCCL on MUSA. - # CUDA Graphs work transparently - g = torch.cuda.CUDAGraph() # Uses MUSAGraph on MUSA + # CUDA graph APIs resolve to MUSA graph APIs on MUSA platforms. + g = torch.cuda.CUDAGraph() # Uses MUSAGraph on MUSA. """ import functools @@ -26,17 +26,55 @@ import sys import warnings from types import ModuleType -from typing import Any, Callable, List, Optional +from typing import Callable, List, Optional import torch +from . import _accelerator_compat as _accelerator_compat +from . import _ctypes_compat as _ctypes_compat +from . 
import _device_compat as _device_compat +from ._accelerator_compat import patch_torch_accelerator from ._cpp_ops import get_module +from ._ctypes_compat import patch_ctypes_cdll +from ._cuda_compat import ( + _CudaModuleWrapper, + install_cuda_memory_compat, + install_cuda_module_aliases, + install_cuda_public_api_shims, +) +from ._device_compat import ( + _FACTORY_FUNCTIONS, + _translate_device, + _wrap_factory_function, + _wrap_module_cuda, + _wrap_tensor_cuda, + _wrap_to_method, + patch_torch_device, + patch_torch_generator, +) from ._platform import is_musa_platform _patched = False _original_init_process_group = None -# Registry for patch functions +_DYNAMIC_COMPAT_ATTRS = { + "_original_torch_device": _device_compat.get_original_torch_device, + "_original_torch_generator": _device_compat.get_original_torch_generator, + "_original_c_generator": _device_compat.get_original_c_generator, + "_original_ctypes_CDLL": _ctypes_compat.get_original_ctypes_cdll, + "_original_torch_accelerator": _accelerator_compat.get_original_torch_accelerator, +} + + +def __getattr__(name: str): + """Expose moved compatibility state for existing internal imports/tests.""" + getter = _DYNAMIC_COMPAT_ATTRS.get(name) + if getter is not None: + return getter() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# Patch registry. _patch_registry: List[Callable[[], None]] = [] @@ -51,7 +89,7 @@ def patch_function(func: Callable[[], None]) -> Callable[[], None]: Usage: @patch_function def _patch_something(): - # patching logic + # Patching logic. pass The decorated function will be called by apply_patches() in registration order. @@ -71,15 +109,15 @@ def requires_import(*module_names: str) -> Callable[[Callable], Callable]: @patch_function @requires_import('torch_musa') def _patch_something(): - # This only runs if torch_musa is importable + # This only runs if torch_musa is importable. import torch_musa - # ... patching logic + # Patching logic. 
@patch_function @requires_import('torch._inductor.autotune_process') def _patch_autotune(): import torch._inductor.autotune_process as ap - # ... patching logic + # Patching logic. Args: *module_names: Variable number of module names to check for importability @@ -103,13 +141,6 @@ def wrapper(*args, **kwargs): return decorator -# Cache for translated device strings - avoids repeated string operations -_device_str_cache = {} - -# Cache for is_musa_platform result - computed once on first call -_is_musa_platform_cached = None - - def _has_param(func: Callable, param_name: str) -> bool: """ Check if a function has a specific parameter in its signature. @@ -128,178 +159,6 @@ def _has_param(func: Callable, param_name: str) -> bool: return False -def _translate_device(device: Any) -> Any: - """ - Translate 'cuda' device references to 'musa' on MUSA platform. - - Args: - device: Device specification (string, torch.device, int, or None) - - Returns: - Translated device specification - - Performance: Platform check and string translations are cached. - """ - global _is_musa_platform_cached - - # Cache the platform check result (computed once) - if _is_musa_platform_cached is None: - _is_musa_platform_cached = is_musa_platform() - - if not _is_musa_platform_cached: - return device - - if device is None: - return device - - if isinstance(device, str): - # Check cache first for common strings - if device in _device_str_cache: - return _device_str_cache[device] - - # Handle 'cuda', 'cuda:0', 'cuda:1', etc. 
- if device == "cuda" or device.startswith("cuda:"): - result = device.replace("cuda", "musa") - _device_str_cache[device] = result - return result - # Cache non-cuda strings too to avoid repeated startswith checks - _device_str_cache[device] = device - return device - - if isinstance(device, torch.device): - if device.type == "cuda": - return torch.device("musa", device.index) - return device - - # For integer device IDs, keep as-is (context determines device type) - return device - - -def _wrap_to_method(original_to: Callable) -> Callable: - """Wrap tensor.to() to translate device strings.""" - - @functools.wraps(original_to) - def wrapped_to(self, *args, **kwargs): - # Translate device in positional args - if args and len(args) >= 1: - first_arg = args[0] - # Check if first arg looks like a device - if isinstance(first_arg, (str, torch.device)): - args = (_translate_device(first_arg),) + args[1:] - elif isinstance(first_arg, torch.dtype): - # .to(dtype) case, check for device in kwargs or second arg - if len(args) >= 2: - args = (first_arg, _translate_device(args[1])) + args[2:] - - # Translate device in keyword args - if "device" in kwargs: - kwargs["device"] = _translate_device(kwargs["device"]) - - return original_to(self, *args, **kwargs) - - return wrapped_to - - -def _wrap_tensor_cuda(original_cuda: Callable) -> Callable: - """Wrap tensor.cuda() to use musa on MUSA platform.""" - # Cache platform check at wrapper creation time - _is_musa = is_musa_platform() - - @functools.wraps(original_cuda) - def wrapped_cuda(self, device=None, non_blocking=False): - if _is_musa: - # Use .musa() instead - if hasattr(self, "musa"): - return self.musa(device=device, non_blocking=non_blocking) - else: - # Fallback to .to() - target_device = f"musa:{device}" if device is not None else "musa" - return self.to(target_device, non_blocking=non_blocking) - return original_cuda(self, device=device, non_blocking=non_blocking) - - return wrapped_cuda - - -def 
_wrap_module_cuda(original_cuda: Callable) -> Callable: - """Wrap nn.Module.cuda() to use musa on MUSA platform.""" - # Cache platform check at wrapper creation time - _is_musa = is_musa_platform() - - @functools.wraps(original_cuda) - def wrapped_cuda(self, device=None): - if _is_musa: - if hasattr(self, "musa"): - return self.musa(device=device) - else: - target_device = f"musa:{device}" if device is not None else "musa" - return self.to(target_device) - return original_cuda(self, device=device) - - return wrapped_cuda - - -_original_torch_device = None - - -class _DeviceFactoryMeta(type): - """Metaclass to make isinstance(x, torch.device) work with our factory.""" - - def __instancecheck__(cls, instance): - if _original_torch_device is not None: - return isinstance(instance, _original_torch_device) - return False - - def __subclasscheck__(cls, subclass): - if _original_torch_device is not None: - return issubclass(subclass, _original_torch_device) - return False - - -class DeviceFactoryWrapper(metaclass=_DeviceFactoryMeta): - """ - A wrapper class that acts as torch.device but translates cuda to musa. - - Uses a metaclass to properly handle isinstance() checks. 
- - Supports all calling conventions of torch.device: - torch.device("cuda:0") - torch.device("cuda", 0) - torch.device(type="cuda", index=0) - torch.device(device="cuda:0") - """ - - _original = None - - def __new__(cls, device=None, index=None, *, type=None): - original = cls._original - if original is None: - raise RuntimeError("DeviceFactoryWrapper not initialized") - - # Handle 'type' keyword argument (alias for device in original torch.device) - if type is not None: - device = type - - # Handle the case where device is already a torch.device - if isinstance(device, original): - if device.type == "cuda": - index = device.index if index is None else index - device = "musa" - else: - return device - - # Handle string device - if isinstance(device, str): - device = _translate_device(device) - - # Create the actual device - if index is not None: - return original(device, index) - elif device is not None: - return original(device) - else: - return original() - - @patch_function def _patch_torch_device(): """ @@ -307,52 +166,7 @@ def _patch_torch_device(): This ensures that torch.device("cuda:0") creates a musa device when on MUSA. 
""" - global _original_torch_device - - if _original_torch_device is not None: - return # Already patched - - _original_torch_device = torch.device - DeviceFactoryWrapper._original = _original_torch_device - - # Replace torch.device with our wrapper - torch.device = DeviceFactoryWrapper - - -# Store original torch.Generator for patching -_original_torch_generator = None -# Store the underlying C Generator class for isinstance checks -_original_c_generator = None - - -class _GeneratorMeta(type): - """Metaclass that properly implements __instancecheck__ for isinstance() to work.""" - - def __instancecheck__(cls, instance): - if _original_c_generator is not None: - return isinstance(instance, _original_c_generator) - return False - - def __subclasscheck__(cls, subclass): - if _original_c_generator is not None: - if subclass is _original_c_generator: - return True - return super().__subclasscheck__(subclass) - - -class GeneratorWrapper(metaclass=_GeneratorMeta): - """Wrapper for torch.Generator that translates cuda -> musa.""" - - _original = None - - def __new__(cls, device=None): - original = cls._original - if original is None: - raise RuntimeError("GeneratorWrapper not initialized") - # Translate device if needed - if device is not None: - device = _translate_device(device) - return original(device=device) + patch_torch_device(torch) @patch_function @@ -366,26 +180,10 @@ def _patch_torch_generator(): Uses a metaclass to properly implement __instancecheck__ so that isinstance(gen, torch.Generator) works correctly. 
""" - global _original_torch_generator, _original_c_generator - - if _original_torch_generator is not None: - return # Already patched - - _original_torch_generator = torch.Generator - # Get the underlying C Generator class for isinstance checks - # torch_musa may have already wrapped torch.Generator, but instances are still - # of type torch._C.Generator - _original_c_generator = torch._C.Generator - - GeneratorWrapper._original = _original_torch_generator + patch_torch_generator(torch) - # Copy over doc but keep __module__ as torchada._patch so pickle can find the class - GeneratorWrapper.__doc__ = _original_torch_generator.__doc__ - torch.Generator = GeneratorWrapper - - -# Store original graph class for patching +# Saved graph class used by the graph context-manager patch. _original_graph_class = None @@ -400,9 +198,9 @@ def _patch_graph_context_manager(): global _original_graph_class if _original_graph_class is not None: - return # Already patched + return - # Get the graph class from torch.cuda (which is torch.musa after patching) + # Read from torch.cuda after module redirection so this is torch.musa on MUSA. if not hasattr(torch.cuda, "graph"): return @@ -411,7 +209,7 @@ def _patch_graph_context_manager(): class GraphWrapper: """Wrapper for torch.cuda.graph that accepts cuda_graph= keyword argument.""" - # Preserve class attributes + # Preserve class attributes used by callers. default_capture_stream = None def __init__( @@ -421,14 +219,13 @@ def __init__( stream=None, capture_error_mode: str = "global", *, - musa_graph=None, # Also accept musa_graph for compatibility + musa_graph=None, # Also accept musa_graph for compatibility. ): - # Allow either cuda_graph= or musa_graph= or positional argument + # Accept CUDA and MUSA keyword spellings. 
graph_obj = cuda_graph if cuda_graph is not None else musa_graph if graph_obj is None: raise TypeError("graph() missing required argument: 'cuda_graph'") - # Create the original graph instance self._wrapped = _original_graph_class( graph_obj, pool=pool, @@ -442,229 +239,18 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): return self._wrapped.__exit__(exc_type, exc_value, traceback) - # Copy over class attributes and docstring + # Preserve metadata for introspection. GraphWrapper.__doc__ = _original_graph_class.__doc__ GraphWrapper.__module__ = _original_graph_class.__module__ - # Replace torch.cuda.graph with our wrapper torch.cuda.graph = GraphWrapper - # Also update torch.musa.graph if it exists + # Keep the backend module consistent when it exposes graph directly. if hasattr(torch, "musa") and hasattr(torch.musa, "graph"): torch.musa.graph = GraphWrapper -def _wrap_factory_function(original_fn: Callable) -> Callable: - """Wrap tensor factory functions (empty, zeros, ones, etc.) 
to translate device.""" - - @functools.wraps(original_fn) - def wrapped_fn(*args, **kwargs): - if "device" in kwargs: - kwargs["device"] = _translate_device(kwargs["device"]) - return original_fn(*args, **kwargs) - - return wrapped_fn - - -# List of torch factory functions that accept a device argument -_FACTORY_FUNCTIONS = [ - # Basic tensor creation - "tensor", - "as_tensor", - "asarray", - # Uninitialized/initialized tensors - "empty", - "zeros", - "ones", - "full", - # Random tensors - "rand", - "randn", - "randint", - "randperm", - "normal", - # Sequences - "arange", - "range", - "linspace", - "logspace", - # Special tensors - "eye", - "empty_strided", - "empty_permuted", - # From file - "from_file", - # Like variants - "empty_like", - "zeros_like", - "ones_like", - "full_like", - "rand_like", - "randn_like", - "randint_like", - # Sparse tensors - "sparse_coo_tensor", - "sparse_csr_tensor", - "sparse_csc_tensor", - "sparse_bsr_tensor", - "sparse_bsc_tensor", - "sparse_compressed_tensor", - # Index tensors - "tril_indices", - "triu_indices", - # Window functions - "bartlett_window", - "blackman_window", - "hamming_window", - "hann_window", - "kaiser_window", -] - - -class _CudartWrapper: - """ - Wrapper for CUDA runtime that translates calls to MUSA runtime. - - This allows code like `torch.cuda.cudart().cudaHostRegister(...)` to work - on MUSA by translating to `torch_musa.musart().musaHostRegister(...)`. - - Performance optimization: Resolved attributes are cached in __dict__ to avoid - repeated __getattr__ calls. 
- """ - - # Mapping from CUDA runtime function names to MUSA equivalents - _CUDA_TO_MUSA = { - "cudaHostRegister": "musaHostRegister", - "cudaHostUnregister": "musaHostUnregister", - "cudaMemGetInfo": "musaMemGetInfo", - "cudaGetErrorString": "musaGetErrorString", - "cudaStreamCreate": "musaStreamCreate", - "cudaStreamDestroy": "musaStreamDestroy", - } - - def __init__(self, musart_module): - self._musart = musart_module - - def __getattr__(self, name): - # Translate CUDA runtime function names to MUSA equivalents - if name in self._CUDA_TO_MUSA: - musa_name = self._CUDA_TO_MUSA[name] - value = getattr(self._musart, musa_name) - # Cache in __dict__ for faster subsequent access - object.__setattr__(self, name, value) - return value - - # Try direct access (for any functions with same name) - if hasattr(self._musart, name): - value = getattr(self._musart, name) - # Cache in __dict__ for faster subsequent access - object.__setattr__(self, name, value) - return value - - raise AttributeError(f"CUDA runtime has no attribute '{name}'") - - -class _CudaModuleWrapper(ModuleType): - """ - A wrapper module that redirects torch.cuda to torch.musa, - but keeps certain attributes (like is_available) pointing to the original. - - This allows downstream projects to detect MUSA platform using: - torch.cuda.is_available() # Returns False on MUSA (original behavior) - While still using torch.cuda.* APIs that redirect to torch.musa. - - Performance optimization: Resolved attributes are cached in __dict__ to avoid - repeated __getattr__ calls. This reduces overhead from ~800ns to ~50ns for - cached attributes. 
- """ - - # Attributes that should NOT be redirected to torch.musa - _NO_REDIRECT = {"is_available"} - - # Special attribute mappings for attributes not at top level of torch_musa - # Maps attribute name -> dot-separated path within torch_musa - _SPECIAL_ATTRS = { - "StreamContext": "core.stream.StreamContext", - } - - # Attribute name remappings (CUDA name -> MUSA name) - # For CUDA-specific APIs that have different names in MUSA - _REMAP_ATTRS = { - "_device_count_nvml": "device_count", # NVML is NVIDIA-specific - } - - # Attributes that should NOT be cached (functions that may return different values) - # Most functions are safe to cache since they're module-level functions - _NO_CACHE = { - # These are typically not called in hot paths anyway - } - - def __init__(self, original_cuda, musa_module): - super().__init__("torch.cuda") - self._original_cuda = original_cuda - self._musa_module = musa_module - self._cudart_wrapper = None - - def cudart(self): - """ - Return a CUDA runtime wrapper that translates to MUSA runtime. - - This allows code like `torch.cuda.cudart().cudaHostRegister(...)` to work - on MUSA by translating to the equivalent MUSA runtime calls. 
- """ - if self._cudart_wrapper is None: - if hasattr(self._musa_module, "musart"): - musart_module = self._musa_module.musart() - self._cudart_wrapper = _CudartWrapper(musart_module) - else: - # Fallback to original if musart not available - return self._original_cuda.cudart() - return self._cudart_wrapper - - def __getattr__(self, name): - # Keep original is_available behavior - if name in self._NO_REDIRECT: - value = getattr(self._original_cuda, name) - # Cache in __dict__ for faster subsequent access - if name not in self._NO_CACHE: - object.__setattr__(self, name, value) - return value - - # Handle special attributes that need nested lookup - if name in self._SPECIAL_ATTRS: - obj = self._musa_module - for part in self._SPECIAL_ATTRS[name].split("."): - obj = getattr(obj, part) - # Cache the resolved value - if name not in self._NO_CACHE: - object.__setattr__(self, name, obj) - return obj - - # Handle attribute name remapping (CUDA-specific names -> MUSA equivalents) - if name in self._REMAP_ATTRS: - value = getattr(self._musa_module, self._REMAP_ATTRS[name]) - # Cache the resolved value - if name not in self._NO_CACHE: - object.__setattr__(self, name, value) - return value - - # Redirect everything else to torch.musa - value = getattr(self._musa_module, name) - # Cache the resolved value for faster subsequent access - # This is safe because module attributes don't change at runtime - if name not in self._NO_CACHE: - object.__setattr__(self, name, value) - return value - - def __dir__(self): - # Combine attributes from both modules - attrs = set(dir(self._musa_module)) - attrs.update(self._NO_REDIRECT) - attrs.add("cudart") - return list(attrs) - - -# Store original torch.cuda module before patching +# Saved original torch.cuda module. 
_original_torch_cuda = None @@ -681,111 +267,24 @@ def _patch_torch_cuda_module(): """ global _original_torch_cuda - # torch_musa registers itself as torch.musa when imported - # Now patch torch.cuda to point to torch.musa (which is torch_musa) + # torch_musa registers itself as torch.musa when imported. if hasattr(torch, "musa"): - # Save original torch.cuda before patching if _original_torch_cuda is None: _original_torch_cuda = torch.cuda - # Create wrapper module that redirects most things to torch.musa - # but keeps is_available pointing to the original + # Preserve CUDA-only detection APIs while redirecting the rest to MUSA. cuda_wrapper = _CudaModuleWrapper(_original_torch_cuda, torch.musa) - # Replace torch.cuda with our wrapper in sys.modules - # This makes 'from torch.cuda import ...' work + # Keep import statements and attribute access on the same wrapper. sys.modules["torch.cuda"] = cuda_wrapper - - # Also patch torch.cuda attribute directly torch.cuda = cuda_wrapper - # Patch torch.cuda.amp - if hasattr(torch.musa, "amp"): - sys.modules["torch.cuda.amp"] = torch.musa.amp - - # Patch torch.cuda.graphs - MUSAGraph should be accessible as CUDAGraph - if hasattr(torch.musa, "graphs"): - sys.modules["torch.cuda.graphs"] = torch.musa.graphs - - # Add CUDAGraph alias pointing to MUSAGraph - if hasattr(torch.musa, "MUSAGraph") and not hasattr(torch.musa, "CUDAGraph"): - torch.musa.CUDAGraph = torch.musa.MUSAGraph - - # Patch torch.cuda.memory - if hasattr(torch.musa, "memory"): - musa_memory_module = torch.musa.memory - if musa_memory_module is not None: - sys.modules["torch.cuda.memory"] = musa_memory_module - # Add CUDAPluggableAllocator alias pointing to MUSAPluggableAllocator - if hasattr(musa_memory_module, "MUSAPluggableAllocator"): - musa_memory_module.CUDAPluggableAllocator = ( - musa_memory_module.MUSAPluggableAllocator - ) - - # Inject CUDA-compatible memory pool functions from C++ extension - # These functions (_cuda_beginAllocateCurrentThreadToPool, 
etc.) are - # implemented in torchada's C++ extension to provide CUDA API compatibility - # for torch_musa's memory pool allocator. - cpp_ops_module = get_module() - if cpp_ops_module is not None: - for func_name in [ - "_cuda_beginAllocateCurrentThreadToPool", - "_cuda_endAllocateToPool", - "_cuda_releasePool", - ]: - func = getattr(cpp_ops_module, func_name, None) - if func is not None: - setattr(musa_memory_module, func_name, func) - - # Patch torch.cuda.graph context manager to accept cuda_graph= keyword - # MUSA's graph class uses musa_graph= but CUDA code uses cuda_graph= - _patch_graph_context_manager() - - # Patch torch.cuda.nccl -> torch.musa.mccl - if hasattr(torch.musa, "mccl"): - sys.modules["torch.cuda.nccl"] = torch.musa.mccl - - # Patch torch.cuda.profiler - if hasattr(torch.musa, "profiler"): - sys.modules["torch.cuda.profiler"] = torch.musa.profiler - - # Patch torch.cuda.nvtx - use our stub since MUSA doesn't have nvtx - try: - from .cuda import nvtx as nvtx_stub - - sys.modules["torch.cuda.nvtx"] = nvtx_stub - torch.musa.nvtx = nvtx_stub - except ImportError: - pass - - # Patch torch.cuda.random - use torchada.cuda.random module - if not hasattr(torch.musa, "random"): - try: - from .cuda import random as random_stub - - sys.modules["torch.cuda.random"] = random_stub - torch.musa.random = random_stub - except ImportError: - pass - - # Patch missing _lazy_call from torch_musa.core._lazy_init - # torch_musa only maps _lazy_init but not _lazy_call - # This is needed for code that does: from torch.cuda import _lazy_call - # We add it to torch.musa so _CudaModuleWrapper can redirect it - try: - from torch_musa.core._lazy_init import _lazy_call + install_cuda_module_aliases(torch) + install_cuda_memory_compat(torch, get_module(), _translate_device) - # Only add if not already present (forward compatible with torch_musa fix) - if not hasattr(torch.musa, "_lazy_call"): - torch.musa._lazy_call = _lazy_call - except ImportError: - pass - - # Add 
_is_compiled to torch_musa if not present - # This is needed for code that checks torch.cuda._is_compiled() - # (e.g., vLLM's CUDA kernel availability checks) - if not hasattr(torch.musa, "_is_compiled"): - torch.musa._is_compiled = lambda: True + # Accept CUDA graph keyword spelling on top of the MUSA graph class. + _patch_graph_context_manager() + install_cuda_public_api_shims(torch, _translate_device) @patch_function @@ -801,7 +300,6 @@ def _patch_distributed_backend(): import torch.distributed as dist if _original_init_process_group is not None: - # Already patched return _original_init_process_group = dist.init_process_group @@ -818,16 +316,16 @@ def patched_init_process_group( pg_options=None, device_id=None, ): - # Translate 'nccl' to 'mccl' on MUSA platform + # Translate NCCL backend requests to MCCL on MUSA. if is_musa_platform() and backend is not None: if backend.lower() == "nccl": backend = "mccl" - # Translate device_id if it's a cuda device + # Translate CUDA device IDs before delegating. if device_id is not None: device_id = _translate_device(device_id) - # Build kwargs for the original function + # Preserve the original signature while allowing version-specific args. kwargs = { "backend": backend, "init_method": init_method, @@ -845,10 +343,10 @@ def patched_init_process_group( dist.init_process_group = patched_init_process_group - # Also patch new_group to translate 'nccl' to 'mccl' + # Patch new_group with the same backend and device translation. original_new_group = dist.new_group - # Cache the check for device_id support (added in torch 2.6) + # Cache device_id support because it was added in PyTorch 2.6. _new_group_has_device_id = _has_param(original_new_group, "device_id") @functools.wraps(original_new_group) @@ -861,12 +359,12 @@ def patched_new_group( group_desc=None, device_id=None, ): - # Translate 'nccl' to 'mccl' on MUSA platform + # Translate NCCL backend requests to MCCL on MUSA. 
if is_musa_platform() and backend is not None: if isinstance(backend, str) and backend.lower() == "nccl": backend = "mccl" - # Build kwargs for the original function + # Preserve the original signature while allowing version-specific args. kwargs = { "ranks": ranks, "backend": backend, @@ -875,7 +373,7 @@ def patched_new_group( "group_desc": group_desc, } - # Translate device_id if it's a cuda device (only if supported by torch version) + # Translate CUDA device IDs only when the installed PyTorch accepts them. if device_id is not None and _new_group_has_device_id: kwargs["device_id"] = _translate_device(device_id) @@ -898,25 +396,22 @@ def _patch_tensor_is_cuda(): Performance: Uses try/except with direct attribute access for speed. Benchmarks show getattr(self, 'is_musa', False) is faster than self.device.type. """ - # Store the original is_cuda property (it's a getset_descriptor) + # Keep the descriptor so CUDA tensors retain their native fast path. original_is_cuda = torch.Tensor.is_cuda @property def patched_is_cuda(self): """Return True if tensor is on CUDA or MUSA device.""" - # Check original is_cuda first (fast path for actual CUDA tensors) - # Use direct property access - original_is_cuda is a getset_descriptor + # Use direct descriptor access for actual CUDA tensors. result = original_is_cuda.__get__(self) if result: return True - # Check if tensor is on MUSA device - # Use try/except with direct attribute access - faster than getattr with default + # Direct attribute access is faster than getattr with a default here. 
try: return self.is_musa except AttributeError: return False - # Replace is_cuda with our patched version torch.Tensor.is_cuda = patched_is_cuda @@ -931,12 +426,11 @@ def _patch_stream_cuda_stream(): """ from torch_musa.core.stream import Stream as MUSAStream - # Add cuda_stream property that returns musa_stream if not hasattr(MUSAStream, "cuda_stream"): @property def cuda_stream(self): - """Return the underlying stream pointer (same as musa_stream).""" + """Return the underlying stream pointer, matching ``musa_stream``.""" return self.musa_stream MUSAStream.cuda_stream = cuda_stream @@ -955,7 +449,7 @@ def _patch_autocast(): class PatchedAutocast(original_autocast): def __init__(self, device_type, *args, **kwargs): - # Translate 'cuda' to 'musa' + # Translate CUDA autocast contexts to MUSA contexts. if device_type == "cuda": device_type = "musa" super().__init__(device_type, *args, **kwargs) @@ -987,7 +481,7 @@ def _translate_activities(activities): translated = [] for activity in activities: if activity == torch.profiler.ProfilerActivity.CUDA: - # On MUSA, use PrivateUse1 instead of CUDA + # MUSA profiler events use PrivateUse1 rather than CUDA. translated.append(torch.profiler.ProfilerActivity.PrivateUse1) else: translated.append(activity) @@ -1024,16 +518,14 @@ def _patch_musa_warnings(): We suppress them using Python's warnings.filterwarnings(). """ - # Suppress autocast dtype warning from torch/amp/autocast_mode.py - # This happens when autocast is used with unsupported dtypes on MUSA + # Suppress autocast dtype warnings for unsupported MUSA dtypes. warnings.filterwarnings( "ignore", message=r"In musa autocast, but the target dtype is not supported.*", category=UserWarning, ) - # Suppress FlashAttention unsupported dimension warning from torch_musa - # This happens when SDP attention is used with unsupported head dimensions + # Suppress FlashAttention dimension warnings for unsupported MUSA head sizes. 
warnings.filterwarnings( "ignore", message=r"Unsupported qk_head_dim:.*for FlashAttention in MUSA backend.*", @@ -1062,16 +554,16 @@ def _patch_library_impl(): This patch preserves the full original signature including the with_keyset parameter. Example of code that needs this patch: - my_lib.impl(op_name, op_func, "CUDA") # Now works on MUSA! - my_lib.impl(op_name, op_func, "Autograd", with_keyset=True) # Also works! - my_lib.impl(op_name, op_func, "Autograd", with_keyset=True, allow_override=True) # Also works! + my_lib.impl(op_name, op_func, "CUDA") # Works on MUSA. + my_lib.impl(op_name, op_func, "Autograd", with_keyset=True) # Works on MUSA. + my_lib.impl(op_name, op_func, "Autograd", with_keyset=True, allow_override=True) """ if not hasattr(torch, "library") or not hasattr(torch.library, "Library"): return original_impl = torch.library.Library.impl - # Mapping of CUDA dispatch keys to PrivateUse1 equivalents + # CUDA dispatch keys that should register against PrivateUse1 on MUSA. cuda_dispatch_key_map = { "CUDA": "PrivateUse1", "AutogradCUDA": "AutogradPrivateUse1", @@ -1083,7 +575,7 @@ def _patch_library_impl(): } def patched_impl(self, *args, **kwargs): - # Translate CUDA dispatch keys to PrivateUse1 equivalents for MUSA compatibility + # Translate CUDA dispatch keys before registering custom operators. sig = inspect.signature(original_impl) bound = sig.bind(self, *args, **kwargs) bound.apply_defaults() @@ -1116,11 +608,9 @@ def _patch_torch_c_exports(): musac = torch_musa._MUSAC - # List of functions/classes to copy from _MUSAC to torch._C - # These are commonly imported by downstream code + # Common downstream imports that torch_musa exposes under _MUSAC only. 
_MUSAC_EXPORTS = [ "_storage_Use_Count", - # Add more as needed ] for name in _MUSAC_EXPORTS: @@ -1142,18 +632,15 @@ def _patch_backends_cuda(): if not hasattr(torch, "backends") or not hasattr(torch.backends, "cuda"): return - # Patch is_built() to return True when MUSA is available - # This allows code that checks torch.backends.cuda.is_built() to proceed + # Let CUDA build checks pass when torchada is redirecting CUDA APIs to MUSA. original_is_built = torch.backends.cuda.is_built - # Cache the result since it won't change at runtime + # Cache the result because platform state does not change at runtime. _is_built_cache = {} def patched_is_built(): if "result" not in _is_built_cache: - # On MUSA platform, report as "built" since we redirect cuda->musa. - # Use is_musa_platform() instead of torch.musa.is_available() so this - # works even when no GPU card is present (build-only environments). + # Treat MUSA as CUDA-built even in build-only environments. if is_musa_platform(): _is_built_cache["result"] = True else: @@ -1202,7 +689,6 @@ def patched_setattr(self, name, value): matmul_class.__setattr__ = patched_setattr - @patch_function @requires_import("torchada.utils.cpp_extension", "torch.utils.cpp_extension") def _patch_cpp_extension(): @@ -1222,17 +708,16 @@ def _patch_cpp_extension(): from .utils import cpp_extension as torchada_cpp_ext - # Patch the key classes and functions + # Patch the key classes and constants. torch_cpp_ext.CUDAExtension = torchada_cpp_ext.CUDAExtension torch_cpp_ext.BuildExtension = torchada_cpp_ext.BuildExtension torch_cpp_ext.CUDA_HOME = torchada_cpp_ext.CUDA_HOME - # Patch include_paths and library_paths to handle both old and new signatures - # and to correctly translate "cuda" to MUSA on MUSA platform + # Delegate include/library path handling to the torchada compatibility layer. 
torch_cpp_ext.include_paths = torchada_cpp_ext.include_paths torch_cpp_ext.library_paths = torchada_cpp_ext.library_paths - # Also update sys.modules entry + # Keep future imports on the patched module object. sys.modules["torch.utils.cpp_extension"] = torch_cpp_ext @@ -1249,7 +734,7 @@ def _patch_autotune_process(): """ import torch._inductor.autotune_process as autotune_process - # Patch the CUDA_VISIBLE_DEVICES constant to use MUSA_VISIBLE_DEVICES + # Use the MUSA visibility environment variable in autotune subprocesses. if hasattr(autotune_process, "CUDA_VISIBLE_DEVICES"): autotune_process.CUDA_VISIBLE_DEVICES = "MUSA_VISIBLE_DEVICES" @@ -1297,306 +782,22 @@ def _patch_flash_attn(): """ import flash_attn_interface - # Ensure sgl_kernel package exists in sys.modules. - # First try to import the real package; only create a stub if it's truly not installed. + # Prefer the real package; create a stub only when sgl_kernel is absent. if "sgl_kernel" not in sys.modules: try: import sgl_kernel # noqa: F401 except ImportError: sgl_kernel_stub = ModuleType("sgl_kernel") - sgl_kernel_stub.__path__ = [] # Make it a package + sgl_kernel_stub.__path__ = [] # Mark the stub as a package. sgl_kernel_stub.__package__ = "sgl_kernel" sys.modules["sgl_kernel"] = sgl_kernel_stub - # Register flash_attn_interface as sgl_kernel.flash_attn submodule + # Register flash_attn_interface as the sgl_kernel.flash_attn submodule. sgl_kernel = sys.modules["sgl_kernel"] sgl_kernel.flash_attn = flash_attn_interface sys.modules["sgl_kernel.flash_attn"] = flash_attn_interface -class _CDLLWrapper: - """ - Wrapper for ctypes.CDLL that automatically translates CUDA/NCCL function names - to MUSA/MCCL equivalents when accessing library functions. - - This allows code that uses ctypes to load CUDA libraries (libcudart, libnccl) and - access CUDA-named functions to work transparently on MUSA without code changes. 
- - Example: - # Original code uses CUDA function names: - lib = ctypes.CDLL("libmusart.so") - func = lib.cudaIpcOpenMemHandle # Automatically translates to musaIpcOpenMemHandle - - lib = ctypes.CDLL("libmccl.so") - func = lib.ncclAllReduce # Automatically translates to mcclAllReduce - """ - - # Detect library type from filename patterns - _MUSART_PATTERNS = ("libmusart", "musart.so", "libmusa_runtime") - _MCCL_PATTERNS = ("libmccl", "mccl.so") - _MUBLAS_PATTERNS = ("libmublas", "mublas.so") - _MURAND_PATTERNS = ("libmurand", "murand.so") - - def __init__(self, cdll_instance, lib_path: str): - # Store the original CDLL instance - object.__setattr__(self, "_cdll", cdll_instance) - object.__setattr__(self, "_lib_path", lib_path) - object.__setattr__(self, "_lib_type", self._detect_lib_type(lib_path)) - - def _detect_lib_type(self, lib_path: str) -> str: - """Detect the type of library from its path.""" - lib_path_lower = lib_path.lower() - if any(p in lib_path_lower for p in self._MUSART_PATTERNS): - return "musart" - elif any(p in lib_path_lower for p in self._MCCL_PATTERNS): - return "mccl" - elif any(p in lib_path_lower for p in self._MUBLAS_PATTERNS): - return "mublas" - elif any(p in lib_path_lower for p in self._MURAND_PATTERNS): - return "murand" - return "unknown" - - def _translate_name(self, name: str) -> str: - """Translate CUDA/NCCL function name to MUSA/MCCL equivalent.""" - lib_type = object.__getattribute__(self, "_lib_type") - - if lib_type == "musart": - # cudaXxx -> musaXxx - if name.startswith("cuda"): - return "musa" + name[4:] - elif lib_type == "mccl": - # ncclXxx -> mcclXxx - if name.startswith("nccl"): - return "mccl" + name[4:] - elif lib_type == "mublas": - # cublasXxx -> mublasXxx - if name.startswith("cublas"): - return "mublas" + name[6:] - elif lib_type == "murand": - # curandXxx -> murandXxx - if name.startswith("curand"): - return "murand" + name[6:] - - return name - - def __getattr__(self, name: str): - cdll = 
object.__getattribute__(self, "_cdll") - translated_name = self._translate_name(name) - value = getattr(cdll, translated_name) - # Cache in __dict__ for faster subsequent access - object.__setattr__(self, name, value) - return value - - def __setattr__(self, name: str, value): - cdll = object.__getattribute__(self, "_cdll") - translated_name = self._translate_name(name) - setattr(cdll, translated_name, value) - - def __getitem__(self, name: str): - cdll = object.__getattribute__(self, "_cdll") - translated_name = self._translate_name(name) - return cdll[translated_name] - - -# Store original ctypes.CDLL for patching -_original_ctypes_CDLL = None - - -class _AcceleratorModuleWrapper(ModuleType): - """ - Wrapper module that extends torch.accelerator with fallbacks to torch.musa. - - torch.accelerator is the unified accelerator abstraction being built up over - successive PyTorch releases. Many APIs scheduled for PyTorch 2.9+ - (e.g. empty_cache, memory_stats, memory_allocated, Stream, Event, - manual_seed, get_device_name, ...) do not yet exist on torch.accelerator - in torch 2.7 / torch_musa, but do exist on torch.musa. This wrapper lets - user code written against the newer unified API work on current MUSA builds - by falling back to torch.musa for any attribute missing from the original - torch.accelerator module. - - Resolution order for attribute access: - 1. Explicit overrides installed by torchada (e.g. patched synchronize, - device_index / stream context managers, and memory APIs that exist - upstream but are broken on MUSA) - 2. The original torch.accelerator module (so existing APIs keep their - real implementations) - 3. torch.musa as a fallback for APIs that have not yet been added to - torch.accelerator upstream, applying _REMAP_ATTRS for APIs whose - torch.musa equivalent has a different name - - Resolved attributes are cached in __dict__ for fast subsequent access, - matching the pattern used by _CudaModuleWrapper. 
- """ - - # Attribute name remappings (torch.accelerator name -> torch.musa name). - # torch.accelerator uses an *_index / *_idx naming convention introduced in - # newer PyTorch releases, while torch.musa keeps the older torch.cuda style - # without the suffix. When the original torch.accelerator module does not - # expose these names (e.g. older PyTorch builds), the wrapper falls back to - # torch.musa using the remapped name so callers still get a working API. - _REMAP_ATTRS = { - "set_device_index": "set_device", - "set_device_idx": "set_device", - "current_device_index": "current_device", - "current_device_idx": "current_device", - } - - # Special attribute mappings for attributes not at top level of torch_musa. - # Maps attribute name -> dot-separated path within torch_musa. - _SPECIAL_ATTRS = { - "StreamContext": "core.stream.StreamContext", - } - - # Memory APIs that exist on torch.accelerator (PyTorch 2.9+) but internally - # call torch._C._accelerator_* C++ functions which fail on MUSA because the - # MUSA allocator is not a CUDA DeviceAllocator. These are overridden to - # delegate to torch.musa, following the same pattern as synchronize(). - # When an API in this list exists on the original torch.accelerator AND on - # torch.musa, we install an override that prefers torch.musa over the - # upstream implementation. - _MUSA_OVERRIDES = ( - "empty_cache", - "empty_host_cache", - "memory_stats", - "memory_allocated", - "max_memory_allocated", - "memory_reserved", - "max_memory_reserved", - "reset_accumulated_memory_stats", - "reset_peak_memory_stats", - "get_memory_info", - ) - - def __init__(self, original_accel, musa_module): - super().__init__("torch.accelerator") - self._original_accel = original_accel - self._musa_module = musa_module - self._overrides = {} - - # Apply MUSA overrides for memory APIs that exist upstream but are - # broken on MUSA (they route through torch._C._accelerator_* which - # doesn't dispatch to the MUSA allocator). 
- for name in self._MUSA_OVERRIDES: - if hasattr(original_accel, name) and hasattr(musa_module, name): - self._set_override(name, getattr(musa_module, name)) - - def _set_override(self, name, value): - """Install an override that takes precedence over the wrapped modules.""" - self._overrides[name] = value - object.__setattr__(self, name, value) - - def __getattr__(self, name): - if name in self._overrides: - return self._overrides[name] - try: - value = getattr(self._original_accel, name) - except AttributeError: - # Fall back to torch.musa with several strategies in order: - # 1. Same-name lookup (e.g., empty_cache) - # 2. Special nested attributes (e.g., StreamContext -> core.stream.StreamContext) - # 3. Name remapping (e.g., set_device_index -> set_device) - if hasattr(self._musa_module, name): - value = getattr(self._musa_module, name) - elif name in self._SPECIAL_ATTRS: - obj = self._musa_module - for part in self._SPECIAL_ATTRS[name].split("."): - obj = getattr(obj, part) - value = obj - elif name in self._REMAP_ATTRS: - value = getattr(self._musa_module, self._REMAP_ATTRS[name]) - else: - raise AttributeError(f"module 'torch.accelerator' has no attribute '{name}'") - object.__setattr__(self, name, value) - return value - - def __dir__(self): - attrs = set(dir(self._original_accel)) - attrs.update(dir(self._musa_module)) - attrs.update(self._REMAP_ATTRS.keys()) - attrs.update(self._SPECIAL_ATTRS.keys()) - attrs.update(self._overrides.keys()) - return list(attrs) - - -# Store original torch.accelerator module before patching -_original_torch_accelerator = None - - -def _make_patched_accelerator_synchronize(musa_module): - """Build a torch.accelerator.synchronize replacement that delegates to torch.musa.""" - - def patched_synchronize(device=None): - """ - Patched synchronize that redirects to torch.musa.synchronize(). 
- - The MUSA backend does not implement synchronization of all streams on a - device, so the default torch.accelerator.synchronize() raises at runtime. - Redirecting to torch.musa.synchronize() restores the expected behavior. - - Args: - device: torch.device, str, int, or None. If None, synchronizes the - current device. - - Raises: - TypeError: If device is not a valid type (torch.device, str, int, or None). - """ - # Validate the device type to catch invalid inputs early - if device is not None and not isinstance(device, (torch.device, str, int)): - raise TypeError( - f"synchronize() expected device to be torch.device, str, int, or None, " - f"but got {type(device).__name__}" - ) - - # torch.musa.synchronize natively handles all valid device types: - # - None: synchronizes the current device - # - int: synchronizes device at that index - # - str: handles both "musa" (current device) and "musa:N" (specific device) - # - torch.device: handles both torch.device("musa") and torch.device("musa:N") - # Delegate directly instead of manually parsing to preserve upstream semantics. 
- musa_module.synchronize(device) - - return patched_synchronize - - -def _make_accelerator_context_managers(accel_module): - """Build device_index / stream context managers that bind to accel_module.""" - - class device_index: - """Context manager to temporarily set the current device index.""" - - def __init__(self, idx): - self.idx = idx - self.prev_idx = None - - def __enter__(self): - self.prev_idx = accel_module.current_device_index() - accel_module.set_device_index(self.idx) - return self - - def __exit__(self, *args): - if self.prev_idx is not None: - accel_module.set_device_index(self.prev_idx) - - class stream: - """Context manager to temporarily set the current stream.""" - - def __init__(self, stream_obj): - self.stream = stream_obj - self.prev_stream = None - - def __enter__(self): - self.prev_stream = accel_module.current_stream() - accel_module.set_stream(self.stream) - return self - - def __exit__(self, *args): - if self.prev_stream is not None: - accel_module.set_stream(self.prev_stream) - - return device_index, stream - - @patch_function @requires_import("torch_musa", "torch.accelerator") def _patch_torch_accelerator(): @@ -1623,32 +824,8 @@ def _patch_torch_accelerator(): 4. device_index(idx) and stream(s) context managers, which are not yet present on torch.accelerator in torch 2.7. - - TODO(torchada): README.md / README_CN.md claim "the wrapper always prefers - the real torch.accelerator implementation and only falls back to torch.musa - when an attribute is missing". That is no longer accurate after adding the - memory API overrides (point 2 above). Update those documents to describe - the actual resolution order: (1) torchada overrides, (2) real torch.accelerator, - (3) fallback to torch.musa. 
""" - global _original_torch_accelerator - - import torch.accelerator as accel - - if _original_torch_accelerator is None: - _original_torch_accelerator = accel - - wrapper = _AcceleratorModuleWrapper(_original_torch_accelerator, torch.musa) - - wrapper._set_override("synchronize", _make_patched_accelerator_synchronize(torch.musa)) - device_index_cm, stream_cm = _make_accelerator_context_managers(wrapper) - if not hasattr(_original_torch_accelerator, "device_index"): - wrapper._set_override("device_index", device_index_cm) - if not hasattr(_original_torch_accelerator, "stream"): - wrapper._set_override("stream", stream_cm) - - sys.modules["torch.accelerator"] = wrapper - torch.accelerator = wrapper + patch_torch_accelerator(torch) @patch_function @@ -1669,48 +846,10 @@ def _patch_ctypes_cdll(): Example (in sglang): lib = ctypes.CDLL("libmusart.so") - # This will automatically find musaIpcOpenMemHandle: + # This lookup resolves to musaIpcOpenMemHandle: func = lib.cudaIpcOpenMemHandle """ - import ctypes - - global _original_ctypes_CDLL - - # Only patch once - if _original_ctypes_CDLL is not None: - return - - _original_ctypes_CDLL = ctypes.CDLL - - class PatchedCDLL: - """Patched CDLL that wraps MUSA libraries with function name translation.""" - - def __new__(cls, name, *args, **kwargs): - # Create the original CDLL instance - cdll_instance = _original_ctypes_CDLL(name, *args, **kwargs) - - # Check if this is a MUSA library that needs wrapping - name_str = str(name) if name else "" - if any( - pattern in name_str.lower() - for pattern in ( - "libmusart", - "musart.so", - "libmusa_runtime", - "libmccl", - "mccl.so", - "libmublas", - "mublas.so", - "libmurand", - "murand.so", - ) - ): - return _CDLLWrapper(cdll_instance, name_str) - - # For non-MUSA libraries, return the original CDLL instance - return cdll_instance - - ctypes.CDLL = PatchedCDLL + patch_ctypes_cdll() def apply_patches(): @@ -1756,34 +895,27 @@ def apply_patches(): _patched = True return - # Import 
torch_musa to ensure it's initialized + # Import torch_musa so torch.musa is registered before applying patches. try: import torch_musa # noqa: F401 except ImportError: _patched = True return - # Apply all registered patch functions - # These are registered via @patch_function decorator in definition order + # Apply registered patch functions in definition order. for patch_fn in _patch_registry: patch_fn() - # Patch torch.Tensor.to() if hasattr(torch.Tensor, "to"): torch.Tensor.to = _wrap_to_method(torch.Tensor.to) - # Patch torch.Tensor.cuda() if hasattr(torch.Tensor, "cuda"): torch.Tensor.cuda = _wrap_tensor_cuda(torch.Tensor.cuda) - # Patch torch.nn.Module.cuda() if hasattr(torch.nn.Module, "cuda"): torch.nn.Module.cuda = _wrap_module_cuda(torch.nn.Module.cuda) - # Patch tensor factory functions to translate device argument - # We also need to update _device_constructors cache to include - # the original (unwrapped) functions, because PyTorch's __torch_function__ - # dispatch receives the original C function, not our Python wrapper. + # Wrap tensor factories and keep originals for PyTorch device-context dispatch. original_fns = [] for fn_name in _FACTORY_FUNCTIONS: if hasattr(torch, fn_name): @@ -1791,22 +923,17 @@ def apply_patches(): original_fns.append(original_fn) setattr(torch, fn_name, _wrap_factory_function(original_fn)) - # Update _device_constructors to include original functions - # This ensures the device context manager (with torch.device(...):) works - # because __torch_function__ receives the original C function + # PyTorch's __torch_function__ path receives original C functions, not our wrappers. 
try: from torch.utils._device import _device_constructors - # Get the current set of constructors constructors = _device_constructors() - # Add original (unwrapped) functions to the constructors set - # PyTorch's __torch_function__ receives these, not our wrappers for orig_fn in original_fns: constructors.add(orig_fn) except (ImportError, AttributeError): - pass # Older PyTorch versions may not have this + pass # Older PyTorch versions may not expose this helper. _patched = True @@ -1816,7 +943,7 @@ def is_patched() -> bool: return _patched -# Additional exports for advanced usage +# Additional exports for advanced usage. def get_original_init_process_group(): """Get the original torch.distributed.init_process_group function.""" return _original_init_process_group diff --git a/src/torchada/_platform.py b/src/torchada/_platform.py index 45edc86..498fb67 100644 --- a/src/torchada/_platform.py +++ b/src/torchada/_platform.py @@ -1,7 +1,9 @@ """ -Platform detection module for torchada. +Platform detection utilities. -Detects whether the current environment supports CUDA (NVIDIA) or MUSA (Moore Threads). +MUSA is treated as a platform when torch_musa is installed, even when no Moore +Threads GPU is currently available. That keeps build-only environments on the +same compatibility path as runtime GPU environments. """ import os @@ -31,7 +33,7 @@ def detect_platform() -> Platform: Returns: Platform: The detected or configured platform. """ - # Check for forced platform via environment variable + # Honor explicit platform overrides first. forced_platform = os.environ.get("TORCHADA_PLATFORM", "").lower() if forced_platform == "cuda": return Platform.CUDA @@ -40,22 +42,20 @@ def detect_platform() -> Platform: elif forced_platform == "cpu": return Platform.CPU - # Auto-detect platform - # Check MUSA first (Moore Threads) + # Prefer MUSA before CUDA so torch_musa installations take the adapter path. 
if _is_musa_available(): return Platform.MUSA - # Check CUDA (NVIDIA) + # Fall back to native CUDA only when MUSA is not present. if _is_cuda_available(): return Platform.CUDA - # Fallback to CPU return Platform.CPU def _is_musa_available() -> bool: """ - Check if this is a MUSA (Moore Threads) platform. + Return whether the environment should use the MUSA compatibility path. Detects the MUSA platform by checking if torch_musa is installed, rather than requiring a GPU to be present. This allows torchada to @@ -70,13 +70,12 @@ def _is_musa_available() -> bool: try: import torch - # Primary signal: torch.version.musa is set by torch_musa at build time. - # This is the most reliable indicator that we're on a MUSA platform, + # Primary signal: torch.version.musa is set by torch_musa at build time, # regardless of whether a GPU card is present. if hasattr(torch.version, "musa") and torch.version.musa is not None: return True - # Secondary signal: torch_musa is importable + # Secondary signal: torch_musa is importable. 
try: import torch_musa # noqa: F401 @@ -90,7 +89,7 @@ def _is_musa_available() -> bool: def _is_cuda_available() -> bool: - """Check if CUDA (NVIDIA) is available.""" + """Return whether native CUDA is available.""" try: import torch @@ -100,22 +99,22 @@ def _is_cuda_available() -> bool: def is_musa_platform() -> bool: - """Check if we're on MUSA platform.""" + """Return whether the detected platform is MUSA.""" return detect_platform() == Platform.MUSA def is_cuda_platform() -> bool: - """Check if we're on CUDA platform.""" + """Return whether the detected platform is native CUDA.""" return detect_platform() == Platform.CUDA def is_cpu_platform() -> bool: - """Check if we're on CPU-only platform.""" + """Return whether the detected platform is CPU-only.""" return detect_platform() == Platform.CPU def get_device_name() -> str: - """Get the device name string ('cuda', 'musa', or 'cpu').""" + """Return the detected device type string.""" return detect_platform().value @@ -145,7 +144,7 @@ def get_torch_device_module(): def is_gpu_device(device) -> bool: """ - Check if a device is a GPU device (either CUDA or MUSA). + Return whether a device is a CUDA-like GPU device. This is a helper function for code that needs to check if a device is a GPU device. On MUSA platform, device.type == "cuda" comparisons @@ -160,23 +159,20 @@ def is_gpu_device(device) -> bool: True if the device is cuda or musa, False otherwise Example: - # Instead of: - # if tensor.device.type == "cuda": - # Use: - # if torchada.is_gpu_device(tensor.device): - # ... + if torchada.is_gpu_device(tensor.device): + ... """ import torch - # Handle tensors, modules, etc. that have a .device attribute + # Accept tensors, modules, and other objects exposing ``.device``. if hasattr(device, "device"): device = device.device - # Handle torch.device objects + # Accept torch.device objects directly. 
if isinstance(device, torch.device): return device.type in ("cuda", "musa") - # Handle string device specs + # Accept string device specifications. if isinstance(device, str): return ( device == "cuda" @@ -189,5 +185,5 @@ def is_gpu_device(device) -> bool: def is_cuda_like_device(device) -> bool: - """Alias for is_gpu_device() for clarity.""" + """Return whether ``device`` names a CUDA-like GPU device.""" return is_gpu_device(device) diff --git a/src/torchada/_runtime.py b/src/torchada/_runtime.py index 128be47..27d3376 100644 --- a/src/torchada/_runtime.py +++ b/src/torchada/_runtime.py @@ -1,132 +1,116 @@ """ Runtime name conversion utilities for CUDA to MUSA. -This module provides utility functions for converting CUDA function/library -names to their MUSA equivalents at runtime. +This module centralizes CUDA-family runtime symbol translation. The public +helpers are exported from ``torchada`` for manual use, and patching code uses +the same table when adapting ``ctypes.CDLL`` and ``torch.cuda.cudart()``. +""" -Note: torchada automatically patches ctypes.CDLL to translate function names -when loading MUSA libraries (libmusart.so, libmccl.so, etc.). Most users don't -need to use these functions directly - just import torchada and use ctypes -normally with CUDA function names. +from typing import Callable, Dict, Tuple -Example of automatic patching (no code changes needed): - import torchada - import ctypes +PREFIX_MAPPINGS: Dict[str, str] = { + "cuda": "musa", + "nccl": "mccl", + "cublas": "mublas", + "curand": "murand", +} - # Load MUSA runtime library - lib = ctypes.CDLL("libmusart.so") +MUSA_LIBRARY_PATTERNS: Dict[str, Tuple[str, ...]] = { + "musart": ("libmusart", "musart.so", "libmusa_runtime"), + "mccl": ("libmccl", "mccl.so"), + "mublas": ("libmublas", "mublas.so"), + "murand": ("libmurand", "murand.so"), +} - # Access using CUDA function names - automatically translated! 
- func = lib.cudaIpcOpenMemHandle # -> musaIpcOpenMemHandle -These utility functions are exported for manual use if needed: - from torchada import cuda_to_musa_name, nccl_to_mccl_name +def translate_prefix_name(name: str, source_prefix: str, target_prefix: str) -> str: + """ + Translate ``name`` from one prefix convention to another. - musa_name = cuda_to_musa_name("cudaIpcOpenMemHandle") # -> "musaIpcOpenMemHandle" - mccl_name = nccl_to_mccl_name("ncclAllReduce") # -> "mcclAllReduce" -""" + Names that do not start with ``source_prefix`` are returned unchanged. + """ + if name.startswith(source_prefix): + return target_prefix + name[len(source_prefix) :] + return name def cuda_to_musa_name(name: str) -> str: """ Convert a CUDA function/symbol name to its MUSA equivalent. - This handles the common naming convention where CUDA functions start with - "cuda" and MUSA equivalents start with "musa". - - Args: - name: The CUDA function name (e.g., "cudaIpcOpenMemHandle") - - Returns: - The MUSA equivalent name (e.g., "musaIpcOpenMemHandle") - Examples: >>> cuda_to_musa_name("cudaMalloc") 'musaMalloc' - >>> cuda_to_musa_name("cudaIpcOpenMemHandle") - 'musaIpcOpenMemHandle' - >>> cuda_to_musa_name("cudaError_t") - 'musaError_t' >>> cuda_to_musa_name("someOtherFunc") 'someOtherFunc' """ - if name.startswith("cuda"): - return "musa" + name[4:] - return name + return translate_prefix_name(name, "cuda", "musa") def nccl_to_mccl_name(name: str) -> str: """ Convert an NCCL function/symbol name to its MCCL equivalent. - This handles the common naming convention where NCCL functions start with - "nccl" and MCCL equivalents start with "mccl". 
- - Args: - name: The NCCL function name (e.g., "ncclAllReduce") - - Returns: - The MCCL equivalent name (e.g., "mcclAllReduce") - Examples: >>> nccl_to_mccl_name("ncclAllReduce") 'mcclAllReduce' - >>> nccl_to_mccl_name("ncclCommInitRank") - 'mcclCommInitRank' - >>> nccl_to_mccl_name("ncclUniqueId") - 'mcclUniqueId' >>> nccl_to_mccl_name("someOtherFunc") 'someOtherFunc' """ - if name.startswith("nccl"): - return "mccl" + name[4:] - return name + return translate_prefix_name(name, "nccl", "mccl") def cublas_to_mublas_name(name: str) -> str: """ Convert a cuBLAS function/symbol name to its muBLAS equivalent. - This handles the common naming convention where cuBLAS functions start with - "cublas" and muBLAS equivalents start with "mublas". - - Args: - name: The cuBLAS function name (e.g., "cublasCreate") - - Returns: - The muBLAS equivalent name (e.g., "mublasCreate") - Examples: >>> cublas_to_mublas_name("cublasCreate") 'mublasCreate' - >>> cublas_to_mublas_name("cublasSgemm") - 'mublasSgemm' >>> cublas_to_mublas_name("someOtherFunc") 'someOtherFunc' """ - if name.startswith("cublas"): - return "mublas" + name[6:] - return name + return translate_prefix_name(name, "cublas", "mublas") def curand_to_murand_name(name: str) -> str: """ Convert a cuRAND function/symbol name to its muRAND equivalent. 
- Args: - name: The cuRAND function name (e.g., "curandCreate") - - Returns: - The muRAND equivalent name (e.g., "murandCreate") - Examples: >>> curand_to_murand_name("curandCreate") 'murandCreate' - >>> curand_to_murand_name("curand_init") - 'murand_init' >>> curand_to_murand_name("someOtherFunc") 'someOtherFunc' """ - if name.startswith("curand"): - return "murand" + name[6:] - return name + return translate_prefix_name(name, "curand", "murand") + + +_RUNTIME_TRANSLATORS: Dict[str, Callable[[str], str]] = { + "musart": cuda_to_musa_name, + "mccl": nccl_to_mccl_name, + "mublas": cublas_to_mublas_name, + "murand": curand_to_murand_name, +} + + +def detect_musa_library_type(lib_path: str) -> str: + """Detect which MUSA runtime-family library is referenced by ``lib_path``.""" + lib_path_lower = str(lib_path).lower() + for library_type, patterns in MUSA_LIBRARY_PATTERNS.items(): + if any(pattern in lib_path_lower for pattern in patterns): + return library_type + return "unknown" + + +def is_musa_runtime_library_path(lib_path: str) -> bool: + """Return whether ``lib_path`` points at a library needing symbol translation.""" + return detect_musa_library_type(lib_path) != "unknown" + + +def translate_runtime_symbol_name(name: str, library_type: str) -> str: + """Translate ``name`` for a detected MUSA runtime-family library type.""" + translator = _RUNTIME_TRANSLATORS.get(library_type) + if translator is None: + return name + return translator(name) diff --git a/src/torchada/csrc/musa_ops.mu b/src/torchada/csrc/musa_ops.mu index a03f325..a930761 100644 --- a/src/torchada/csrc/musa_ops.mu +++ b/src/torchada/csrc/musa_ops.mu @@ -12,10 +12,7 @@ namespace torchada { -// ============================================================================ -// Example: MUSA kernel for neg (negation) -// This demonstrates how to override aten::neg for PrivateUse1 (MUSA) tensors -// ============================================================================ +// Example MUSA kernel for 
overriding aten::neg on PrivateUse1 tensors. template __global__ void neg_kernel( @@ -31,20 +28,20 @@ __global__ void neg_kernel( at::Tensor neg_musa_impl(const at::Tensor& self) { log_op_call("neg"); - // Ensure contiguous tensor + // Ensure the input is contiguous. auto self_contig = self.contiguous(); - // Allocate output tensor + // Allocate the output tensor. auto output = at::empty_like(self_contig); if (self_contig.numel() == 0) { return output; } - // Get MUSA stream + // Use the current MUSA stream. musaStream_t stream = at::musa::getCurrentMUSAStream(); - // Launch kernel + // Launch the kernel. const int64_t numel = self_contig.numel(); const int threads = 256; const int blocks = (numel + threads - 1) / threads; @@ -58,7 +55,7 @@ at::Tensor neg_musa_impl(const at::Tensor& self) { numel); }); - // Check for launch errors + // Check launch errors. musaError_t err = musaGetLastError(); if (err != musaSuccess) { TORCH_CHECK(false, "MUSA kernel launch failed: ", musaGetErrorString(err)); @@ -69,18 +66,15 @@ at::Tensor neg_musa_impl(const at::Tensor& self) { } // namespace torchada -// ============================================================================ -// Register operator overrides for PrivateUse1 (MUSA) +// Register operator overrides for PrivateUse1 (MUSA). // // Each operator checks TORCHADA_DISABLE_OP_OVERRIDE_=1 at registration -// time. If set, the override is not registered and torch_musa's default -// implementation is used. +// time. If set, torch_musa's default implementation remains active. // // Uncomment m.impl() lines to activate custom implementations. -// ============================================================================ TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - // Example: Register neg override only if not disabled + // Example: register neg override only when enabled. 
// if (torchada::is_override_enabled("neg")) { // m.impl("neg", torchada::neg_musa_impl); // } diff --git a/src/torchada/csrc/ops.cpp b/src/torchada/csrc/ops.cpp index 4512fbc..479db45 100644 --- a/src/torchada/csrc/ops.cpp +++ b/src/torchada/csrc/ops.cpp @@ -4,8 +4,8 @@ // Custom operator implementations can be added here or in separate files. // // To add a new operator override: -// 1. Write the implementation function -// 2. Register it using TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) +// 1. Write the implementation function. +// 2. Register it using TORCH_LIBRARY_IMPL(aten, PrivateUse1, m). // // Note: Operators registered here will override torch_musa's implementations. // Use with caution and ensure correctness. @@ -19,9 +19,7 @@ namespace torchada { -// ============================================================================ -// Memory pool allocation functions (CUDA-compatible API on MUSA) -// ============================================================================ +// CUDA-compatible memory pool allocation functions on MUSA. static void _musa_beginAllocateCurrentThreadToPool( c10::DeviceIndex device, @@ -47,9 +45,7 @@ static void _musa_releasePool( c10::musa::MUSACachingAllocator::releasePool(device, mempool_id); } -// ============================================================================ -// Utility functions exposed to Python -// ============================================================================ +// Utility functions exposed to Python. static bool cpp_ops_loaded = false; @@ -68,9 +64,7 @@ void mark_loaded() { } // namespace torchada -// ============================================================================ -// Python bindings -// ============================================================================ +// Python bindings. 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "torchada C++ operator overrides"; diff --git a/src/torchada/csrc/ops.h b/src/torchada/csrc/ops.h index 28788d0..a7b0501 100644 --- a/src/torchada/csrc/ops.h +++ b/src/torchada/csrc/ops.h @@ -4,18 +4,18 @@ // implementations that override the default PrivateUse1 (MUSA) implementations. // // Usage: -// 1. Include this header in your .cpp or .mu file -// 2. Use TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) to register overrides -// 3. Use is_override_enabled("op_name") to check if override should be registered -// 4. The extension will be built and loaded automatically by torchada +// 1. Include this header in your .cpp or .mu file. +// 2. Use TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) to register overrides. +// 3. Use is_override_enabled("op_name") before registering an override. +// 4. Let torchada build and load the extension automatically. // // Example: // #include "ops.h" // // at::Tensor my_custom_add(const at::Tensor& self, const at::Tensor& other, // const at::Scalar& alpha) { -// log_op_call("add.Tensor"); // Optional: log when called -// // Custom implementation +// log_op_call("add.Tensor"); // Optional debug logging. +// // Custom implementation. // auto result = at::empty_like(self); // result.copy_(self); // result.add_(other, alpha); @@ -23,7 +23,7 @@ // } // // TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { -// // Check env var at registration time - if disabled, don't register +// // Check the environment at registration time. // if (torchada::is_override_enabled("add")) { // m.impl("add.Tensor", my_custom_add); // } @@ -39,12 +39,11 @@ namespace torchada { -// Version information +// Version information. constexpr const char* VERSION = "0.1.0"; -// Check if operator override is enabled via environment variable +// Return whether an operator override is enabled by environment settings. 
inline bool is_override_enabled(const char* op_name) { - // Check TORCHADA_DISABLE_OP_OVERRIDE_ environment variable std::string env_var = "TORCHADA_DISABLE_OP_OVERRIDE_"; env_var += op_name; const char* val = std::getenv(env_var.c_str()); @@ -54,7 +53,7 @@ inline bool is_override_enabled(const char* op_name) { return true; } -// Logging helper for debugging +// Log operator calls when C++ operator debugging is enabled. inline void log_op_call(const char* op_name) { const char* debug = std::getenv("TORCHADA_DEBUG_CPP_OPS"); if (debug != nullptr && std::string(debug) == "1") { diff --git a/src/torchada/cuda/__init__.py b/src/torchada/cuda/__init__.py index 1c35557..16eeeae 100644 --- a/src/torchada/cuda/__init__.py +++ b/src/torchada/cuda/__init__.py @@ -8,16 +8,17 @@ This module is provided for internal use and backwards compatibility. Usage (preferred): - import torchada # Apply patches + import torchada # Apply patches. import torch - # torch.cuda APIs work on MUSA after importing torchada + # torch.cuda APIs work on MUSA after importing torchada. torch.cuda.set_device(0) tensor = tensor.cuda() """ from typing import Optional, Union +from .._device_compat import _translate_device from .._platform import Platform, detect_platform @@ -27,7 +28,7 @@ def _get_backend(): if platform == Platform.MUSA: import torch - import torch_musa + import torch_musa # noqa: F401 - registers torch.musa return torch.musa elif platform == Platform.CUDA: @@ -35,16 +36,35 @@ def _get_backend(): return torch.cuda else: - # Return torch.cuda for API compatibility, even if not available + # Preserve torch.cuda-shaped APIs even when no GPU backend is present. 
import torch return torch.cuda -# Core device functions +def _backend_attr(name: str): + """Return an attribute from the active CUDA-compatible backend.""" + return getattr(_get_backend(), name) + + +def _call_backend(name: str, *args, **kwargs): + """Call a method on the active CUDA-compatible backend.""" + return _backend_attr(name)(*args, **kwargs) + + +def _call_backend_with_fallback(primary: str, fallback: str, *args, **kwargs): + """Call a backend method, falling back to a compatible replacement if absent.""" + backend = _get_backend() + fn = getattr(backend, primary, None) + if fn is None: + fn = getattr(backend, fallback) + return fn(*args, **kwargs) + + +# Core device functions. def is_available() -> bool: """Check if CUDA/MUSA is available.""" - return _get_backend().is_available() + return _call_backend("is_available") def device_count() -> int: @@ -57,109 +77,111 @@ def device_count() -> int: def current_device() -> int: """Return the index of the currently selected device.""" - return _get_backend().current_device() + return _call_backend("current_device") def set_device(device: Union[int, str, "torch.device"]) -> None: """Set the current device.""" - _get_backend().set_device(device) + _call_backend("set_device", _translate_device(device)) def get_device_name(device: Optional[Union[int, str]] = None) -> str: """Get the name of a device.""" - return _get_backend().get_device_name(device) + return _call_backend("get_device_name", _translate_device(device)) def get_device_capability(device: Optional[Union[int, str]] = None) -> tuple: """Get the CUDA/MUSA compute capability of a device.""" - return _get_backend().get_device_capability(device) + return _call_backend("get_device_capability", _translate_device(device)) def get_device_properties(device: Optional[Union[int, str]] = None): """Get the properties of a device.""" - return _get_backend().get_device_properties(device) + return _call_backend("get_device_properties", _translate_device(device)) -# 
Memory management +# Memory management functions. def memory_allocated(device: Optional[Union[int, str]] = None) -> int: """Return the current GPU memory occupied by tensors in bytes.""" - return _get_backend().memory_allocated(device) + return _call_backend("memory_allocated", _translate_device(device)) def max_memory_allocated(device: Optional[Union[int, str]] = None) -> int: """Return the maximum GPU memory occupied by tensors in bytes.""" - return _get_backend().max_memory_allocated(device) + return _call_backend("max_memory_allocated", _translate_device(device)) def memory_reserved(device: Optional[Union[int, str]] = None) -> int: """Return the current GPU memory managed by the caching allocator in bytes.""" - return _get_backend().memory_reserved(device) + return _call_backend("memory_reserved", _translate_device(device)) def max_memory_reserved(device: Optional[Union[int, str]] = None) -> int: """Return the maximum GPU memory managed by the caching allocator in bytes.""" - return _get_backend().max_memory_reserved(device) + return _call_backend("max_memory_reserved", _translate_device(device)) def memory_cached(device: Optional[Union[int, str]] = None) -> int: """Deprecated: Use memory_reserved instead.""" - return _get_backend().memory_cached(device) + return _call_backend_with_fallback( + "memory_cached", "memory_reserved", _translate_device(device) + ) def max_memory_cached(device: Optional[Union[int, str]] = None) -> int: """Deprecated: Use max_memory_reserved instead.""" - return _get_backend().max_memory_cached(device) + return _call_backend_with_fallback( + "max_memory_cached", "max_memory_reserved", _translate_device(device) + ) def empty_cache() -> None: """Release all unoccupied cached memory.""" - _get_backend().empty_cache() + _call_backend("empty_cache") def reset_peak_memory_stats(device: Optional[Union[int, str]] = None) -> None: """Reset the peak memory stats.""" - _get_backend().reset_peak_memory_stats(device) + 
_call_backend("reset_peak_memory_stats", _translate_device(device)) def reset_max_memory_allocated(device: Optional[Union[int, str]] = None) -> None: """Reset the starting point in tracking maximum GPU memory occupied.""" - _get_backend().reset_max_memory_allocated(device) + _call_backend("reset_max_memory_allocated", _translate_device(device)) def reset_max_memory_cached(device: Optional[Union[int, str]] = None) -> None: """Reset the starting point in tracking maximum GPU memory managed.""" - _get_backend().reset_max_memory_cached(device) + _call_backend_with_fallback( + "reset_max_memory_cached", "reset_peak_memory_stats", _translate_device(device) + ) -# Synchronization +# Synchronization functions. def synchronize(device: Optional[Union[int, str]] = None) -> None: """Wait for all kernels in all streams on a device to complete.""" - _get_backend().synchronize(device) + _call_backend("synchronize", _translate_device(device)) -# Stream and Event classes - will be set up dynamically +# Stream and event aliases are resolved from the active backend at import time. def _setup_stream_event_classes(): """Set up Stream and Event classes from the backend.""" backend = _get_backend() - # These will be the actual classes from the backend + # Use backend classes directly so isinstance checks keep backend semantics. 
global Stream, Event, current_stream, default_stream, stream - Stream = backend.Stream if hasattr(backend, "Stream") else None - Event = backend.Event if hasattr(backend, "Event") else None - - if hasattr(backend, "current_stream"): - current_stream = backend.current_stream - if hasattr(backend, "default_stream"): - default_stream = backend.default_stream - if hasattr(backend, "stream"): - stream = backend.stream + Stream = getattr(backend, "Stream", None) + Event = getattr(backend, "Event", None) + current_stream = getattr(backend, "current_stream", None) + default_stream = getattr(backend, "default_stream", None) + stream = getattr(backend, "stream", None) -# Initialize stream/event classes +# Initialize stream and event aliases. try: _setup_stream_event_classes() -except: +except Exception: Stream = None Event = None current_stream = None diff --git a/src/torchada/cuda/amp.py b/src/torchada/cuda/amp.py index 006c46a..88a00e4 100644 --- a/src/torchada/cuda/amp.py +++ b/src/torchada/cuda/amp.py @@ -13,10 +13,11 @@ def _get_amp_backend(): if platform == Platform.MUSA: import torch + import torch_musa # noqa: F401 - registers torch.musa if hasattr(torch.musa, "amp"): return torch.musa.amp - # Fallback to torch.cuda.amp for API compatibility + # Fall back to torch.cuda.amp for API compatibility. return torch.cuda.amp else: import torch @@ -24,7 +25,7 @@ def _get_amp_backend(): return torch.cuda.amp -# Re-export common AMP classes and functions +# Re-export common AMP classes and functions. def autocast(enabled=True, dtype=None, cache_enabled=True): """ Context manager for automatic mixed precision. @@ -38,7 +39,7 @@ def autocast(enabled=True, dtype=None, cache_enabled=True): if hasattr(backend, "autocast"): return backend.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled) else: - # Use torch.autocast with appropriate device type + # Use torch.autocast with the detected device type. 
import torch from .._platform import get_device_name @@ -76,6 +77,13 @@ def __init__( enabled=enabled, ) + def __getattr__(self, name): + """Delegate backend-specific GradScaler APIs not wrapped explicitly.""" + scaler = self.__dict__.get("_scaler") + if scaler is None: + raise AttributeError(name) + return getattr(scaler, name) + def scale(self, outputs): """Scale the outputs.""" return self._scaler.scale(outputs) diff --git a/src/torchada/cuda/nvtx.py b/src/torchada/cuda/nvtx.py index 78594c3..acab99e 100644 --- a/src/torchada/cuda/nvtx.py +++ b/src/torchada/cuda/nvtx.py @@ -5,16 +5,16 @@ functions for profiling. On MUSA platform, these are stubs that do nothing. Usage: - import torchada # Apply patches first + import torchada # Apply patches first. import torch.cuda.nvtx as nvtx nvtx.range_push("my_range") - # ... code to profile ... + # Code to profile. nvtx.range_pop() - # Or use the context manager + # Or use the context manager. with nvtx.range("my_range"): - # ... code to profile ... + # Code to profile. """ from contextlib import contextmanager diff --git a/src/torchada/cuda/random.py b/src/torchada/cuda/random.py index dce3d9e..a82d5f6 100644 --- a/src/torchada/cuda/random.py +++ b/src/torchada/cuda/random.py @@ -7,7 +7,7 @@ platforms, this module is only reachable when torchada patching is active. Usage: - import torchada # Apply patches first + import torchada # Apply patches first. 
import torch.cuda.random as cuda_random cuda_random.manual_seed(1234) @@ -19,6 +19,8 @@ import torch from torch import Tensor +from .._device_compat import _translate_device + __all__ = [ "get_rng_state", "get_rng_state_all", @@ -32,33 +34,40 @@ ] -def get_rng_state(device: Union[int, str, torch.device] = "musa") -> Tensor: +def _get_musa_backend(): + """Return torch.musa, importing torch_musa first so the module is registered.""" + import torch_musa # noqa: F401 + + return torch.musa + + +def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor: r"""Return the random number generator state of the specified GPU as a ByteTensor. Args: device (torch.device or int, optional): The device to return the RNG state of. - Default: ``'musa'`` for the current device. + Default: ``'cuda'`` for CUDA API compatibility. .. warning:: This function eagerly initializes the backend device. """ - return torch.musa.get_rng_state(device) + return _get_musa_backend().get_rng_state(_translate_device(device)) def get_rng_state_all() -> List[Tensor]: r"""Return a list of ByteTensor representing the random number states of all devices.""" - return torch.musa.get_rng_state_all() + return _get_musa_backend().get_rng_state_all() -def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = "musa") -> None: +def set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = "cuda") -> None: r"""Set the random number generator state of the specified GPU. Args: new_state (torch.ByteTensor): The desired state device (torch.device or int, optional): The device to set the RNG state. - Default: ``'musa'`` for the current device. + Default: ``'cuda'`` for CUDA API compatibility. 
""" - return torch.musa.set_rng_state(new_state, device) + return _get_musa_backend().set_rng_state(new_state, _translate_device(device)) def set_rng_state_all(new_states: Iterable[Tensor]) -> None: @@ -67,7 +76,7 @@ def set_rng_state_all(new_states: Iterable[Tensor]) -> None: Args: new_states (Iterable of torch.ByteTensor): The desired state for each device. """ - return torch.musa.set_rng_state_all(new_states) + return _get_musa_backend().set_rng_state_all(new_states) def manual_seed(seed: int) -> None: @@ -83,7 +92,7 @@ def manual_seed(seed: int) -> None: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use :func:`manual_seed_all`. """ - return torch.musa.manual_seed(seed) + return _get_musa_backend().manual_seed(seed) def manual_seed_all(seed: int) -> None: @@ -95,7 +104,7 @@ def manual_seed_all(seed: int) -> None: Args: seed (int): The desired seed. """ - return torch.musa.manual_seed_all(seed) + return _get_musa_backend().manual_seed_all(seed) def seed() -> None: @@ -104,7 +113,7 @@ def seed() -> None: It's safe to call this function if the backend is unavailable; in that case, behavior is backend-defined. """ - return torch.musa.seed() + return _get_musa_backend().seed() def seed_all() -> None: @@ -113,7 +122,7 @@ def seed_all() -> None: It's safe to call this function if the backend is unavailable; in that case, behavior is backend-defined. """ - return torch.musa.seed_all() + return _get_musa_backend().seed_all() def initial_seed() -> int: @@ -122,4 +131,4 @@ def initial_seed() -> int: .. warning:: This function eagerly initializes the backend device. 
""" - return torch.musa.initial_seed() + return _get_musa_backend().initial_seed() diff --git a/tests/test_cuda_patching.py b/tests/test_cuda_patching.py index 6217780..dee050f 100644 --- a/tests/test_cuda_patching.py +++ b/tests/test_cuda_patching.py @@ -449,6 +449,22 @@ def test_mccl_module_available(self): if torchada.is_musa_platform(): assert hasattr(torch.cuda, "mccl") + def test_nccl_module_alias_available(self): + """Test torch.cuda.nccl attribute and import both resolve to torch.musa.mccl.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + assert hasattr(torch.cuda, "nccl") + assert torch.cuda.nccl is torch.musa.mccl + + import torch.cuda.nccl as nccl + + assert nccl is torch.musa.mccl + class TestRNGFunctions: """Test RNG functions are available through torch.cuda.""" @@ -513,6 +529,12 @@ def test_cuda_random_module_available(self): assert hasattr(torch.cuda.random, "seed") assert hasattr(torch.cuda.random, "initial_seed") + def test_cuda_random_module_import_available(self): + import torch + import torch.cuda.random as cuda_random + + assert cuda_random is torch.cuda.random + def test_cuda_random_aliases_musa(self): import torch @@ -541,6 +563,190 @@ def test_cuda_random_functionality(self): assert isinstance(torch.cuda.random.initial_seed(), int) +class TestTorchadaCudaDirectModule: + """Test direct torchada.cuda compatibility wrappers.""" + + def test_tensor_cuda_wrapper_translates_explicit_device(self, monkeypatch): + import torchada._device_compat as device_compat + + calls = {} + + class FakeTensor: + def musa(self, device=None, non_blocking=False): + calls["musa"] = (device, non_blocking) + return "musa-result" + + def original_cuda(self, device=None, non_blocking=False): + raise AssertionError("original CUDA path should not be used on MUSA") + + monkeypatch.setattr(device_compat, "is_musa_platform", lambda: True) + monkeypatch.setattr(device_compat, 
"_is_musa_platform_cached", True) + + wrapped = device_compat._wrap_tensor_cuda(original_cuda) + + assert wrapped(FakeTensor(), device="cuda:1", non_blocking=True) == "musa-result" + assert calls["musa"] == ("musa:1", True) + + def test_cuda_wrapper_to_fallback_uses_translated_device(self, monkeypatch): + import torch + + import torchada._device_compat as device_compat + + calls = {} + + class FakeObject: + def to(self, device, **kwargs): + calls["to"] = (device, kwargs) + return "to-result" + + def original_cuda(self, device=None, non_blocking=False): + raise AssertionError("original CUDA path should not be used on MUSA") + + monkeypatch.setattr(device_compat, "is_musa_platform", lambda: True) + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + + tensor_cuda = device_compat._wrap_tensor_cuda(original_cuda) + assert tensor_cuda(FakeObject(), device="cuda:1", non_blocking=True) == "to-result" + assert calls["to"] == ( + "musa:1", + {"non_blocking": True, "memory_format": torch.preserve_format}, + ) + + calls.clear() + module_cuda = device_compat._wrap_module_cuda(original_cuda) + assert module_cuda(FakeObject(), device=2) == "to-result" + assert calls["to"] == ("musa:2", {}) + + def test_tensor_cuda_wrapper_preserves_memory_format(self, monkeypatch): + import torch + + import torchada._device_compat as device_compat + + calls = {} + + class FakeTensor: + def musa(self, **kwargs): + calls["musa"] = kwargs + return "musa-result" + + def original_cuda(self, device=None, non_blocking=False, memory_format=None): + raise AssertionError("original CUDA path should not be used on MUSA") + + monkeypatch.setattr(device_compat, "is_musa_platform", lambda: True) + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + + wrapped = device_compat._wrap_tensor_cuda(original_cuda) + + assert ( + wrapped(FakeTensor(), device="cuda", memory_format=torch.channels_last) == "musa-result" + ) + assert calls["musa"] == { + "device": "musa", + 
"non_blocking": False, + "memory_format": torch.channels_last, + } + + def test_device_arguments_translate_for_direct_cuda_module(self, monkeypatch): + import torchada._device_compat as device_compat + import torchada.cuda as torchada_cuda + + calls = {} + + class Backend: + def set_device(self, device): + calls["set_device"] = device + + def get_device_name(self, device=None): + calls["get_device_name"] = device + return "MUSA" + + def memory_allocated(self, device=None): + calls["memory_allocated"] = device + return 0 + + def synchronize(self, device=None): + calls["synchronize"] = device + + backend = Backend() + + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + monkeypatch.setattr(torchada_cuda, "_get_backend", lambda: backend) + + torchada_cuda.set_device("cuda:1") + assert calls["set_device"] == "musa:1" + + assert torchada_cuda.get_device_name("cuda") == "MUSA" + assert calls["get_device_name"] == "musa" + + assert torchada_cuda.memory_allocated("cuda:0") == 0 + assert calls["memory_allocated"] == "musa:0" + + torchada_cuda.synchronize("cuda") + assert calls["synchronize"] == "musa" + + def test_memory_cached_falls_back_to_reserved_names(self, monkeypatch): + import torchada._device_compat as device_compat + import torchada.cuda as torchada_cuda + + calls = {} + + class Backend: + def memory_reserved(self, device=None): + calls["memory_reserved"] = device + return 123 + + def max_memory_reserved(self, device=None): + calls["max_memory_reserved"] = device + return 456 + + def reset_peak_memory_stats(self, device=None): + calls["reset_peak_memory_stats"] = device + + backend = Backend() + + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + monkeypatch.setattr(torchada_cuda, "_get_backend", lambda: backend) + + assert torchada_cuda.memory_cached("cuda") == 123 + assert calls["memory_reserved"] == "musa" + + assert torchada_cuda.max_memory_cached("cuda:1") == 456 + assert calls["max_memory_reserved"] == "musa:1" + + 
torchada_cuda.reset_max_memory_cached("cuda") + assert calls["reset_peak_memory_stats"] == "musa" + + +class TestTorchadaCudaRandomStub: + """Test direct torchada.cuda.random stub behavior.""" + + def test_rng_state_device_argument_is_translated(self, monkeypatch): + import torchada._device_compat as device_compat + import torchada.cuda.random as cuda_random + + calls = {} + expected_state = object() + + class Backend: + def get_rng_state(self, device): + calls["get_rng_state"] = device + return expected_state + + def set_rng_state(self, state, device): + calls["set_rng_state"] = (state, device) + + backend = Backend() + + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + monkeypatch.setattr(cuda_random, "_get_musa_backend", lambda: backend) + + assert cuda_random.get_rng_state("cuda:0") is expected_state + assert calls["get_rng_state"] == "musa:0" + + cuda_random.set_rng_state(expected_state) + assert calls["set_rng_state"] == (expected_state, "musa") + + class TestMemoryFunctions: """Test additional memory functions.""" @@ -597,6 +803,137 @@ def test_mem_get_info(self): raise +class TestCudaPublicApiAliases: + """Test CUDA public API aliases that are absent from torch.musa.""" + + def test_deprecated_memory_cached_aliases(self): + """Deprecated CUDA memory_cached names should map to MUSA reserved memory APIs.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + assert torch.cuda.memory_cached() == torch.cuda.memory_reserved() + assert torch.cuda.max_memory_cached() == torch.cuda.max_memory_reserved() + + from torch.cuda.memory import max_memory_cached, memory_cached + + assert memory_cached() == torch.cuda.memory_reserved() + assert max_memory_cached() == torch.cuda.max_memory_reserved() + + def test_host_memory_stats_noops(self): + """CUDA host allocator stat APIs should be importable and harmless on MUSA.""" + import torch + + import torchada + + if not 
torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + assert dict(torch.cuda.host_memory_stats()) == {} + assert torch.cuda.host_memory_stats_as_nested_dict() == {} + torch.cuda.reset_accumulated_host_memory_stats() + torch.cuda.reset_peak_host_memory_stats() + + from torch.cuda.memory import ( + host_memory_stats, + host_memory_stats_as_nested_dict, + reset_accumulated_host_memory_stats, + reset_peak_host_memory_stats, + ) + + assert dict(host_memory_stats()) == {} + assert host_memory_stats_as_nested_dict() == {} + reset_accumulated_host_memory_stats() + reset_peak_host_memory_stats() + + def test_static_cuda_build_flags(self): + """CUDA static build flags should remain available after redirection.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + assert torch.cuda.has_half is True + assert torch.cuda.has_magma is False + + def test_cuda_pluggable_allocator_top_level_alias(self): + """Top-level CUDAPluggableAllocator should alias MUSAPluggableAllocator.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + if not hasattr(torch.musa.memory, "MUSAPluggableAllocator"): + pytest.skip("MUSAPluggableAllocator not available") + + assert torch.cuda.CUDAPluggableAllocator is torch.musa.memory.MUSAPluggableAllocator + + def test_streams_module_alias(self): + """torch.cuda.streams imports should resolve to MUSA stream classes.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + import torch.cuda.streams as streams + + assert torch.cuda.streams is streams + assert streams.Stream is torch.musa.Stream + assert streams.Event is torch.musa.Event + assert streams.ExternalStream is torch.musa.ExternalStream + + def test_sparse_module_alias(self): + """torch.cuda.sparse tensor aliases should resolve to MUSA tensor 
aliases.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + import torch.cuda.sparse as sparse + + assert torch.cuda.sparse is sparse + assert sparse.FloatTensor is torch.musa.FloatTensor + assert sparse.HalfTensor is torch.musa.HalfTensor + assert sparse.BFloat16Tensor is torch.musa.BFloat16Tensor + + def test_init_and_default_generators_available(self): + """torch.cuda.init and default_generators should be available on MUSA.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + torch.cuda.init() + assert isinstance(torch.cuda.default_generators, tuple) + assert len(torch.cuda.default_generators) == torch.cuda.device_count() + + def test_get_stream_from_external_wraps_musa_stream(self): + """torch.cuda.get_stream_from_external should return a MUSA ExternalStream.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + stream = torch.musa.current_stream() + wrapped = torch.cuda.get_stream_from_external(stream.musa_stream, device=0) + assert isinstance(wrapped, torch.musa.ExternalStream) + + class TestStreamAndEvent: """Test Stream and Event classes.""" @@ -1315,6 +1652,62 @@ def test_torch_c_storage_use_count(self): assert callable(use_count) +class TestCudaBuildAndDebugIntrospection: + """Test top-level torch.cuda build/debug helpers missing from torch.musa.""" + + def test_get_gencode_flags_available_on_musa(self): + """torch.cuda.get_gencode_flags() should exist after CUDA->MUSA redirection.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + flags = torch.cuda.get_gencode_flags() + assert isinstance(flags, str) + # NVCC gencode flags are CUDA-specific, so the MUSA shim returns the + # same safe value as a PyTorch build with no CUDA architectures. 
+ assert flags == "" + + from torch.cuda import get_gencode_flags + + assert get_gencode_flags() == flags + + def test_sync_debug_mode_available_on_musa(self): + """torch.cuda sync debug mode getters/setters should not require CUDA C hooks.""" + import torch + + import torchada + + if not torchada.is_musa_platform(): + pytest.skip("Only applicable on MUSA platform") + + original = torch.cuda.get_sync_debug_mode() + try: + torch.cuda.set_sync_debug_mode("warn") + assert torch.cuda.get_sync_debug_mode() == 1 + + torch.cuda.set_sync_debug_mode("error") + assert torch.cuda.get_sync_debug_mode() == 2 + + torch.cuda.set_sync_debug_mode("default") + assert torch.cuda.get_sync_debug_mode() == 0 + + torch.cuda.set_sync_debug_mode(1) + assert torch.cuda.get_sync_debug_mode() == 1 + + with pytest.raises(RuntimeError, match="invalid value of debug_mode"): + torch.cuda.set_sync_debug_mode("invalid") + + from torch.cuda import get_sync_debug_mode, set_sync_debug_mode + + set_sync_debug_mode(2) + assert get_sync_debug_mode() == 2 + finally: + torch.cuda.set_sync_debug_mode(original) + + class TestProfilerActivity: """Test torch.profiler.ProfilerActivity.CUDA patching.""" @@ -2185,6 +2578,31 @@ def test_cpp_ops_source_files_exist(self): assert osp.isfile(ops_h), f"ops.h not found: {ops_h}" assert osp.isfile(ops_cpp), f"ops.cpp not found: {ops_cpp}" + def test_cpp_ops_source_discovery_groups_files(self, tmp_path): + """C++ ops source discovery should separate host and MUSA sources.""" + from torchada._cpp_ops import _discover_extension_sources + + for name in ["z_kernel.mu", "a_ops.cpp", "b_kernel.cu", "ignore.txt"]: + (tmp_path / name).write_text("") + + sources = _discover_extension_sources(str(tmp_path)) + + assert [p.rsplit("/", 1)[-1] for p in sources.cpp_sources] == ["a_ops.cpp"] + assert [p.rsplit("/", 1)[-1] for p in sources.musa_sources] == [ + "b_kernel.cu", + "z_kernel.mu", + ] + assert sources.has_sources + assert sources.needs_musa_loader + + def 
test_cpp_ops_musa_arch_flag_prefers_env(self, monkeypatch): + """MTGPU_TARGET should override runtime architecture detection.""" + from torchada._cpp_ops import _get_musa_arch_flag + + monkeypatch.setenv("MTGPU_TARGET", "mp_test") + + assert _get_musa_arch_flag() == "--offload-arch=mp_test" + def test_cpp_ops_header_content(self): """Test that the C++ header has expected content.""" import os.path as osp @@ -2548,6 +2966,24 @@ def test_override_takes_precedence_over_everything(self): wrapper._set_override("synchronize", "patched_impl") assert wrapper.synchronize == "patched_impl" + def test_synchronize_override_translates_cuda_device_strings(self, monkeypatch): + """The synchronize override must accept CUDA spellings before calling MUSA.""" + import torchada._device_compat as device_compat + from torchada._accelerator_compat import _make_patched_accelerator_synchronize + + calls = [] + + class FakeMusa: + def synchronize(self, device=None): + calls.append(device) + + monkeypatch.setattr(device_compat, "_is_musa_platform_cached", True) + synchronize = _make_patched_accelerator_synchronize(FakeMusa()) + + synchronize("cuda:1") + + assert calls == ["musa:1"] + def test_missing_everywhere_raises_attribute_error(self): """Attribute missing from both modules must raise AttributeError.""" wrapper, _, _ = self._make_wrapper() diff --git a/tests/test_mappings.py b/tests/test_mappings.py index b4a4414..6f7075c 100644 --- a/tests/test_mappings.py +++ b/tests/test_mappings.py @@ -775,6 +775,47 @@ def test_curand_to_murand_name(self): # Non-curand names should be unchanged assert curand_to_murand_name("someOtherFunc") == "someOtherFunc" + def test_runtime_library_translation_table(self): + """Runtime library detection and symbol translation should share one table.""" + from torchada._runtime import ( + detect_musa_library_type, + is_musa_runtime_library_path, + translate_runtime_symbol_name, + ) + + assert detect_musa_library_type("/usr/local/musa/lib/libmusart.so") == "musart" 
+ assert detect_musa_library_type("/usr/local/musa/lib/libmccl.so") == "mccl" + assert detect_musa_library_type("/usr/local/musa/lib/libmublas.so") == "mublas" + assert detect_musa_library_type("/usr/local/musa/lib/libmurand.so") == "murand" + assert detect_musa_library_type("libc.so.6") == "unknown" + + assert is_musa_runtime_library_path("/usr/local/musa/lib/libmusa_runtime.so") + assert not is_musa_runtime_library_path("libc.so.6") + + assert translate_runtime_symbol_name("cudaMalloc", "musart") == "musaMalloc" + assert translate_runtime_symbol_name("ncclAllReduce", "mccl") == "mcclAllReduce" + assert translate_runtime_symbol_name("cublasCreate", "mublas") == "mublasCreate" + assert translate_runtime_symbol_name("curandCreate", "murand") == "murandCreate" + assert translate_runtime_symbol_name("cudaMalloc", "unknown") == "cudaMalloc" + + def test_cudart_wrapper_uses_runtime_translation(self): + """torch.cuda.cudart() wrapper should use generic cuda->musa translation.""" + from torchada._cuda_compat import _CudartWrapper + + class FakeMusart: + musaMalloc = object() + unchanged = object() + + fake_musart = FakeMusart() + wrapper = _CudartWrapper(fake_musart) + + assert wrapper.cudaMalloc is fake_musart.musaMalloc + assert wrapper.__dict__["cudaMalloc"] is fake_musart.musaMalloc + assert wrapper.unchanged is fake_musart.unchanged + + with pytest.raises(AttributeError): + _ = wrapper.cudaMissing + class TestCDLLWrapper: """Test ctypes.CDLL wrapper for automatic function name translation.