Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
name: CI

on:
push:
branches: [master, main]
pull_request:

jobs:
shaders:
name: Compile SPIR-V shaders
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Install glslang
run: |
sudo apt-get update
sudo apt-get install -y glslang-tools

# Compile every .comp to a scratch dir and fail if any shader does not
# compile. This guards against a regression to the default
# `glslangValidator -V` target (SPIR-V 1.0), which cannot build the
# subgroup-using quantized kernels; they need --target-env vulkan1.1
# (SPIR-V 1.3).
- name: Compile all shaders
working-directory: csrc/shaders
run: |
set -euo pipefail
out="$(mktemp -d)"
fail=0
for comp in *.comp; do
echo "Compiling $comp"
if ! glslangValidator --target-env vulkan1.1 -V "$comp" -o "$out/${comp%.comp}.spv"; then
echo "::error file=csrc/shaders/$comp::shader failed to compile"
fail=1
fi
done
[ "$fail" -eq 0 ] || { echo "One or more shaders failed to compile"; exit 1; }
echo "Compiled $(ls "$out"/*.spv | wc -l) shaders."

python-lint:
name: Python lint + syntax
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- uses: actions/setup-python@v5
with:
python-version: "3.11"

- name: Install ruff
run: pip install ruff

# py_compile catches syntax errors without needing the compiled _C
# extension (which requires a Vulkan GPU + Kompute + custom PyTorch and
# cannot run on a stock CI runner).
- name: Syntax check
run: |
python -m py_compile setup.py persistent_pipeline.py \
torch_vulkan/__init__.py tests/*.py

- name: Ruff
run: ruff check .

# NOTE: The C++ extension build (CMake + Torch + Kompute) and the runtime test
# suite require a Vulkan-capable GPU and are not run here -- GitHub-hosted
# runners have no GPU. Build/test verification is done locally on the target
# hardware (AMD Radeon 890M / RADV). This workflow verifies what can be checked
# without a GPU: shader compilation and Python static checks.
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,16 @@ attention + KV cache, kv_store.

```bash
# Requires a Vulkan 1.2+ driver (RADV recommended) and glslangValidator for shader build.
cd build && cmake .. && make -j$(nproc)

# 1. (Re)compile the GLSL shaders to SPIR-V (pre-built .spv files are committed,
# so this is only needed if you edit a shader). Targets Vulkan 1.1 / SPIR-V 1.3
# because the quantized kernels use subgroup ops.
bash csrc/shaders/compile.sh

# 2. Build the C++ extension.
cmake -S . -B build && cmake --build build -j

# 3. Install the Python package (build_ext also invokes CMake).
pip install -e .

python -c "import torch, torch_vulkan; print(torch_vulkan.device_name())"
Expand Down
18 changes: 17 additions & 1 deletion csrc/shaders/compile.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
#!/bin/bash
# Compile all GLSL compute shaders in this directory to SPIR-V.
# Requires: glslangValidator (from the Vulkan SDK, `pacman -S glslang`, or
# `apt-get install glslang-tools`).
#
# We target Vulkan 1.1, which produces SPIR-V 1.3. Several of the quantized
# matmul kernels (matmul_q*k*, matmul_gpuq*) use GL_KHR_shader_subgroup
# reductions, and those subgroup ops require SPIR-V >= 1.3. The default
# `glslangValidator -V` emits SPIR-V 1.0 and fails on them, so the explicit
# --target-env is required to compile the full shader set.
#
# Exit status: 0 when every shader compiled, 1 when any shader failed or no
# .comp files were found.
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
cd "$SCRIPT_DIR"

# With nullglob an empty directory yields zero loop iterations instead of the
# literal pattern "*.comp" being handed to the compiler.
shopt -s nullglob

found=0
compiled=0
fail=0
for comp in *.comp; do
    found=$((found + 1))
    spv="${comp%.comp}.spv"
    echo "Compiling $comp -> $spv"
    # Every shader is attempted even after a failure so that one run reports
    # all broken shaders; the `if` keeps `set -e` from aborting on the first.
    if glslangValidator --target-env vulkan1.1 -V "$comp" -o "$spv"; then
        compiled=$((compiled + 1))
    else
        echo "  ERROR: failed to compile $comp" >&2
        fail=1
    fi
done

if [ "$found" -eq 0 ]; then
    echo "No .comp shaders found in $SCRIPT_DIR" >&2
    exit 1
fi

if [ "$fail" -ne 0 ]; then
    echo "One or more shaders failed to compile." >&2
    exit 1
fi

# Report what this run actually compiled; counting *.spv on disk would also
# include stale outputs from shaders that no longer exist.
echo "Done. $compiled shaders compiled."
6 changes: 6 additions & 0 deletions csrc/torch_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, VulkanGuardImpl);

// Python bindings
void set_shader_dir(const std::string& path) {
    // Forward the directory to the Kompute-based VulkanContext only; that is
    // the context used by all the registered ops.
    //
    // VulkanEngine::instance() is intentionally left alone. The raw engine
    // (used only by the not-yet-wired mm_raw path) is constructed lazily and
    // takes its directory from the TORCH_VULKAN_SHADER_DIR env var, which
    // __init__.py defaults to this same directory. Touching it here would
    // eagerly create a second Vulkan device at import time.
    auto&& context = VulkanContext::instance();
    context.set_shader_dir(path);
}

Expand Down
12 changes: 11 additions & 1 deletion csrc/vulkan_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <fstream>
#include <stdexcept>
#include <cstring>
#include <cstdlib>
#include <algorithm>

namespace torch_vulkan {
Expand All @@ -12,7 +13,16 @@ VulkanEngine& VulkanEngine::instance() {
}

VulkanEngine::VulkanEngine() {
    // Resolve the shader directory from TORCH_VULKAN_SHADER_DIR so the path
    // is never hardcoded to a developer's machine. torch_vulkan/__init__.py
    // sets a default for this variable (the bundled csrc/shaders directory)
    // at import time; the Python-side _set_shader_dir binding configures
    // VulkanContext only and does not reach this engine.
    //
    // When the variable is unset, shaderDir_ keeps whatever default it was
    // declared with.
    if (const char* env = std::getenv("TORCH_VULKAN_SHADER_DIR")) {
        shaderDir_ = env;
        // Normalise to a trailing '/' -- presumably so shader file names can
        // be appended to the directory directly.
        if (!shaderDir_.empty() && shaderDir_.back() != '/') {
            shaderDir_ += '/';
        }
    }
    initVulkan();
}

Expand Down
2 changes: 0 additions & 2 deletions persistent_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
"""

import torch
import time
import numpy as np


class PersistentLayerPipeline:
Expand Down
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[tool.ruff]
# torch-vulkan is a thin Python shim over a C++/Vulkan extension; keep linting
# focused on real problems (pycodestyle errors, pyflakes, import sorting)
# rather than style churn.
line-length = 100
target-version = "py310"

[tool.ruff.lint]
select = ["E", "F", "I"]
# E402 (module-level import not at top of file) is intentionally allowed: the
# package must import `torch` first, then load the `_C` extension that registers
# the PrivateUse1 backend, and finally alias the module into `torch.vulkan` --
# all of which happen mid-module by design.
ignore = ["E402"]
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

import os
import subprocess
import sys

from setuptools import setup, Extension
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext


Expand Down
21 changes: 13 additions & 8 deletions tests/bench_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
Focuses on the hot ops: mm, add, gelu that repeat across layers.
"""

import sys
import os
import sys
import time

sys.path.insert(0, "/home/raz/builds/pytorch-gfx1150")
# Allow pointing at a custom-built PyTorch without hardcoding a developer's
# path. Set TORCH_VULKAN_PYTORCH_PATH if needed.
_custom_torch = os.environ.get("TORCH_VULKAN_PYTORCH_PATH")
if _custom_torch:
sys.path.insert(0, _custom_torch)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch

import torch_vulkan


Expand All @@ -23,12 +28,12 @@ def bench_mm_repeated(M, K, N, iters=20):
b = b_cpu.to("vulkan")

# Warmup (first call = cache miss)
c = torch.mm(a, b)
torch.mm(a, b)

# Benchmark (subsequent calls should hit cache)
t0 = time.perf_counter()
for _ in range(iters):
c = torch.mm(a, b)
torch.mm(a, b)
elapsed = (time.perf_counter() - t0) / iters
print(f" mm [{M}x{K}] @ [{K}x{N}]: {elapsed*1000:.2f} ms/call")
return elapsed
Expand All @@ -41,11 +46,11 @@ def bench_add_repeated(N, iters=20):
a = a_cpu.to("vulkan")
b = b_cpu.to("vulkan")

c = torch.add(a, b) # warmup
torch.add(a, b) # warmup

t0 = time.perf_counter()
for _ in range(iters):
c = torch.add(a, b)
torch.add(a, b)
elapsed = (time.perf_counter() - t0) / iters
print(f" add [{N}]: {elapsed*1000:.2f} ms/call")
return elapsed
Expand All @@ -56,11 +61,11 @@ def bench_gelu_repeated(N, iters=20):
a_cpu = torch.randn(N)
a = a_cpu.to("vulkan")

c = torch.nn.functional.gelu(a) # warmup
torch.nn.functional.gelu(a) # warmup

t0 = time.perf_counter()
for _ in range(iters):
c = torch.nn.functional.gelu(a)
torch.nn.functional.gelu(a)
elapsed = (time.perf_counter() - t0) / iters
print(f" gelu [{N}]: {elapsed*1000:.2f} ms/call")
return elapsed
Expand Down
10 changes: 7 additions & 3 deletions tests/test_algo_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
repeated dispatches with the same tensor buffers hit the cache.
"""

import sys
import os
import sys
import time

# Use the custom-built PyTorch
sys.path.insert(0, "/home/raz/builds/pytorch-gfx1150")
# Allow pointing at a custom-built PyTorch (e.g. a local APU build) without
# hardcoding a developer's path. Set TORCH_VULKAN_PYTORCH_PATH if needed.
_custom_torch = os.environ.get("TORCH_VULKAN_PYTORCH_PATH")
if _custom_torch:
sys.path.insert(0, _custom_torch)
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch

import torch_vulkan


Expand Down
22 changes: 18 additions & 4 deletions tests/test_mm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Smoke test: matrix multiply through the Vulkan backend."""

import torch

import torch_vulkan


Expand Down Expand Up @@ -35,17 +36,30 @@ def test_mm_small():
print("mm 2x2: PASS")


def test_cpu_fallback():
# relu isn't implemented in Vulkan yet — should fall back to CPU
a = torch.randn(16, device="vulkan")
def test_relu():
    # relu is wired to the Vulkan backend (relu.spv). Run it on a small fixed
    # input covering negative, zero and positive values, then check the shape
    # and the values against the CPU reference.
    host_input = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    expected = torch.relu(host_input)

    device_input = host_input.to("vulkan")
    device_output = torch.relu(device_input)

    assert device_output.shape == device_input.shape
    torch.testing.assert_close(device_output.cpu(), expected, rtol=1e-3, atol=1e-3)
    print("relu: PASS")


def test_unimplemented_op_falls_back_to_cpu():
    # An op with no Vulkan impl (here: sign) should route through the boxed
    # CPU fallback registered for PrivateUse1 and still return a valid result.
    #
    # The input is built on the CPU and moved with .to("vulkan") because
    # torch.randn(..., device="vulkan") is documented as only partially wired;
    # using it here could fail for reasons unrelated to the fallback.
    a_cpu = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
    a = a_cpu.to("vulkan")
    b = torch.sign(a)
    assert b.shape == a.shape
    # A shape check alone would pass on garbage output; sign() is exact, so
    # compare the values against the CPU reference with zero tolerance.
    torch.testing.assert_close(b.cpu(), torch.sign(a_cpu), rtol=0, atol=0)
    print("CPU fallback: PASS")


if __name__ == "__main__":
test_vulkan_available()
test_mm_small()
test_mm_square()
test_cpu_fallback()
test_relu()
test_unimplemented_op_falls_back_to_cpu()
print("\nAll tests passed.")
34 changes: 23 additions & 11 deletions torch_vulkan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
print(torch_vulkan.is_available())
print(torch_vulkan.device_name())

# Create tensors on Vulkan device
a = torch.randn(64, 64, device="vulkan")
b = torch.randn(64, 64, device="vulkan")
c = torch.mm(a, b) # runs matmul.spv on GPU

# Or move existing tensors
x = torch.randn(128, 128)
x_vk = x.vulkan()
# Create on CPU, then move to the Vulkan device (the supported path):
a = torch.randn(64, 64).to("vulkan")
b = torch.randn(64, 64).to("vulkan")
c = torch.mm(a, b) # runs matmul_tiled.spv on GPU
c.cpu() # bring the result back

# torch.empty(..., device="vulkan") and torch.tensor(data, device="vulkan")
# also work. The torch.randn(..., device="vulkan") constructor and the
# .vulkan() method are only partially wired -- see the README Limitations
# section. Prefer .to("vulkan").
"""

import os

import torch

# Load the C++ extension — this registers PrivateUse1 as "vulkan"
Expand All @@ -33,13 +36,21 @@
for_tensor=True, for_module=True, for_storage=False
)
except Exception as e:
import warnings; warnings.warn(f"generate_methods_for_privateuse1_backend failed, using .vulkan() shim: {e}")
import warnings
warnings.warn(
f"generate_methods_for_privateuse1_backend failed, using .vulkan() shim: {e}"
)
torch.Tensor.vulkan = lambda self, *a, **k: self.to("vulkan")

# Point the shader loader at our bundled .spv files
# Point the shader loader at our bundled .spv files.
_shader_dir = os.path.join(os.path.dirname(__file__), "..", "csrc", "shaders")
if os.path.isdir(_shader_dir):
_C._set_shader_dir(os.path.abspath(_shader_dir))
_abs_shader_dir = os.path.abspath(_shader_dir)
_C._set_shader_dir(_abs_shader_dir)
# The raw VulkanEngine (used by the mm_raw path) is constructed lazily and
# reads its shader directory from this env var, so it resolves to the same
# bundled shaders instead of a hardcoded default.
os.environ.setdefault("TORCH_VULKAN_SHADER_DIR", _abs_shader_dir)


def is_available() -> bool:
Expand Down Expand Up @@ -81,4 +92,5 @@ def clear_algorithm_cache():
# Register as torch.vulkan so .to("vulkan") works
# PyTorch's PrivateUse1 dispatch does "import torch.<backend_name>"
import sys

sys.modules['torch.vulkan'] = sys.modules[__name__]
Loading