Summary
I am seeing a hard segfault (no Python exception) during tvm.compile(...) for a CUDA target. The crash consistently occurs inside the TIR pass InjectPTXLDG32, in this call chain (outermost frame first):
- tvm::tir::transform::InjectPTXLDG32(bool)
- tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
- tvm::tir::BufferStore::BufferStore(...)
The input IRModule is produced by converting a PyTorch torch.export program using tvm.relax.frontend.torch.from_exported_program. The PyTorch model is intentionally small (Linear(4,4)) and returns a tuple of tensors: (torch.tril(x), torch.triu(x)).
This looks like a bug in the InjectPTXLDG32 rewrite logic, or an unsafe assumption in the pass leading to a null/invalid BufferStore construction.
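For reference, both outputs are elementwise selects on the Linear result. tril keeps entries with column index <= row index and zeroes the rest. triu keeps entries with column index >= row index. In lowered TIR, each output therefore becomes a store whose value is a conditional select on an index comparison. Whether that store pattern is what trips PTXRewriter is an unverified assumption.

Below is a minimal pure-Python sketch of the expected semantics, which could be used to check outputs once compilation succeeds. The helpers tril_ref/triu_ref are hypothetical names for illustration, not PyTorch or TVM APIs.

```python
# Reference semantics for MyModel.forward's two outputs on a 2-D matrix
# given as a list of rows. tril_ref/triu_ref are hypothetical helpers that
# mirror torch.tril / torch.triu with the default diagonal=0.
from typing import List

Matrix = List[List[float]]


def tril_ref(y: Matrix) -> Matrix:
    # Keep y[i][j] on or below the main diagonal (j <= i), zero elsewhere.
    return [[v if j <= i else 0.0 for j, v in enumerate(row)] for i, row in enumerate(y)]


def triu_ref(y: Matrix) -> Matrix:
    # Keep y[i][j] on or above the main diagonal (j >= i), zero elsewhere.
    return [[v if j >= i else 0.0 for j, v in enumerate(row)] for i, row in enumerate(y)]


if __name__ == "__main__":
    # The repro feeds a (1, 4) input, so the Linear output is a single row:
    # tril keeps only column 0, while triu keeps the whole row.
    y = [[1.0, 2.0, 3.0, 4.0]]
    print(tril_ref(y))  # [[1.0, 0.0, 0.0, 0.0]]
    print(triu_ref(y))  # [[1.0, 2.0, 3.0, 4.0]]
```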
Environment
From the repro output:
- TVM version: 0.22.0
- TVM commit: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
- LLVM: 17.0.6
- Python: 3.10.16 (from stack paths)
- NumPy: 2.2.6
- PyTorch: 2.9.0+cu128
- CUDA GPU: NVIDIA RTX A6000 (sm_86)
Target string used:
cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
Minimal Repro Script
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    try:
        print("TVM git commit:", tvm.support.libinfo().get("GIT_COMMIT_HASH", "unknown"))
    except Exception:
        print("TVM git commit: unknown")
    try:
        print("TVM LLVM version:", tvm.support.libinfo().get("LLVM_VERSION", "unknown"))
    except Exception:
        print("TVM LLVM version: unknown")
    print("NumPy version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("CUDA available (torch):", torch.cuda.is_available())
    if torch.cuda.is_available():
        try:
            print("CUDA device:", torch.cuda.get_device_name(0))
        except Exception:
            pass
    print("==========================\n")


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        x = self.linear(x)
        return torch.tril(x), torch.triu(x)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program
    return from_exported_program(ep)


def main():
    print_env_info()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this repro, but torch.cuda.is_available() is False")

    target_str = "cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32"
    target = tvm.target.Target(target_str)
    relax_pipeline = "default"
    tir_pipeline = "default"

    model = MyModel()
    x = torch.zeros((1, 4), dtype=torch.float32)

    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)

    disabled_pass = [
        "DeadCodeElimination",
        "CanonicalizeBindings",
        "Simplify",
        "UnrollLoop",
        "VectorizeLoop",
        "StorageRewrite",
        "RemoveNoOp",
        "LoopPartition",
    ]
    pass_config = {
        "relax.FuseOps.max_depth": 2,
        "relax.lift_transform_params.consume_params": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.instrument_bound_checkers": 1,
        "tir.merge_static_smem": 1,
        "tir.noalias": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }
    pc_kwargs = {
        "opt_level": 3,
        "disabled_pass": disabled_pass,
        "config": pass_config,
    }

    print("[repro] target:", target)
    print("[repro] relax_pipeline:", relax_pipeline)
    print("[repro] tir_pipeline:", tir_pipeline)
    print("[repro] opt_level:", pc_kwargs["opt_level"])
    print("[repro] disabled_pass:", disabled_pass)
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))
    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc_kwargs):
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )
    print("[repro] compile finished (no crash).")


if __name__ == "__main__":
    main()
```
Actual Behavior
Segfault during compilation:
```text
!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
```
This is a hard crash (core dumped), not a recoverable error.
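Because the crash kills the interpreter, a test harness can only observe it from outside the process. TVM prints its banner, but the shell's "Segmentation fault (core dumped)" line shows the process still dies from the signal.

Below is a minimal sketch, assuming a POSIX system, that runs a command in a child interpreter and classifies how it ended. On POSIX, subprocess reports death by signal N as a return code of -N. The function name run_and_classify and the file name repro.py are hypothetical.

```python
import signal
import subprocess
import sys
from typing import Sequence


def run_and_classify(cmd: Sequence[str], timeout: float = 600.0) -> str:
    """Run cmd in a child process and classify its exit (POSIX only).

    Returns "ok" for exit status 0, "segfault" if the child was killed by
    SIGSEGV, "signal:<NAME>" for other fatal signals, and "error:<code>" for
    a nonzero exit status (an uncaught Python exception exits with 1).
    Raises subprocess.TimeoutExpired if the child runs longer than timeout.
    """
    proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    rc = proc.returncode
    if rc == 0:
        return "ok"
    if rc < 0:
        # subprocess reports "killed by signal N" as returncode -N on POSIX.
        signum = -rc
        if signum == signal.SIGSEGV:
            return "segfault"
        try:
            return "signal:" + signal.Signals(signum).name
        except ValueError:
            return f"signal:{signum}"
    return f"error:{rc}"


if __name__ == "__main__":
    # repro.py is a hypothetical file name for the repro script in this issue.
    print(run_and_classify([sys.executable, "repro.py"]))
```

A clean compile yields "ok". A regular Python exception, which is what the expected behavior below asks for, yields "error:1" rather than "segfault".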
Expected Behavior
tvm.compile(...) should either:
- successfully compile the module, or
- raise a normal Python exception / diagnostic if some pass config is invalid,
but it should not segfault.
Triage
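One way to narrow down which of the PassContext options above are needed to trigger the crash is sketched here. It was not run for this report. The idea is to greedily drop config keys one at a time and keep a removal only if the compile still crashes, a simplified form of delta debugging.

The crash check has to run in a fresh process, since a segfault cannot be caught in-process. The still_crashes callback is therefore hypothetical: for example, it could write the candidate config into a copy of the repro script and check whether the child dies with SIGSEGV.

```python
from typing import Any, Callable, Dict


def minimize_config(
    config: Dict[str, Any],
    still_crashes: Callable[[Dict[str, Any]], bool],
) -> Dict[str, Any]:
    """Greedy one-key-at-a-time reduction of a crashing config.

    Repeatedly tries removing each key and keeps the removal when the crash
    persists. Stops once a full pass removes nothing, so the result is
    1-minimal: dropping any single remaining key makes the crash go away.
    The input dict is not modified.
    """
    if not still_crashes(config):
        raise ValueError("the starting config does not reproduce the crash")
    current = dict(config)
    changed = True
    while changed:
        changed = False
        # sorted() takes a snapshot of the keys, so reassigning current is safe.
        for key in sorted(current):
            candidate = {k: v for k, v in current.items() if k != key}
            if still_crashes(candidate):
                current = candidate
                changed = True
    return current


if __name__ == "__main__":
    # Toy predicate for illustration only: pretend keys "b" and "d" are both
    # required for the crash. A real predicate would compile in a subprocess.
    toy = {"a": 1, "b": 1, "c": 1, "d": 1}
    print(minimize_config(toy, lambda cfg: "b" in cfg and "d" in cfg))  # {'b': 1, 'd': 1}
```

The same routine can be applied to the disabled_pass list by passing dict.fromkeys(disabled_pass) and reading back the remaining keys.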