
Commit 1dc6631

example: using nvrtc kernel for aot plugin
1 parent c9859b6 commit 1dc6631

1 file changed: +247 -0
"""
Minimal reproducible example demonstrating the TensorRT fp16 custom_op() issue.

This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op()
fails to compile operations that use fp16 (half-precision) tensors.

The issue occurs because the JIT plugin generator doesn't properly declare format
support for fp16 data types in the generated TensorRT plugin.
"""

from typing import List, Tuple, Union

import torch

# Import triton for the eager-mode kernel implementation
import triton
import triton.language as tl

import torch_tensorrt

# ============================================================================
# Triton Kernel for Eager Execution
# ============================================================================


@triton.jit
def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # The program ID determines which block of data this program instance processes
    pid = tl.program_id(0)
    # Compute the range of elements this program instance will work on
    block_start = pid * BLOCK_SIZE
    # Range of indices this program instance will handle
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask for boundary checking
    mask = offsets < n_elements
    # Load elements from the X tensor
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    # Convert to float32 for computation
    x_f32 = x.to(tl.float32)
    # Compute sigmoid: 1 / (1 + exp(-x))
    output = tl.sigmoid(x_f32)
    # Convert back to the original dtype
    output_casted = output.to(x.dtype)
    # Store the result in Y
    tl.store(y_ptr + offsets, output_casted, mask=mask)


# ============================================================================
# Custom Op Registration
# ============================================================================


@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
    # Ensure the tensor is on the GPU
    assert X.is_cuda, "Tensor must be on CUDA device."

    # Create the output tensor
    Y = torch.empty_like(X)

    # Define the block size
    BLOCK_SIZE = 256

    # Grid of programs: one program instance per BLOCK_SIZE elements
    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)

    # Launch the kernel
    pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE)

    return Y


@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid")
def _(input: torch.Tensor) -> torch.Tensor:
    """Fake implementation for TorchDynamo tracing of the base operation."""
    return torch.empty_like(input)
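
# Quick eager-mode sanity check (a sketch, not part of the repro; assumes a
# CUDA device with fp16 support):
#
#   x = torch.randn(1024, device="cuda", dtype=torch.float16)
#   y = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(x)
#   torch.testing.assert_close(y, torch.sigmoid(x), rtol=1e-2, atol=1e-2)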


# ============================================================================
# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16
# ============================================================================

import tensorrt.plugin as trtp
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions


@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]:
    return (input.like(),)


@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_autotune(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
) -> List[trtp.AutoTuneCombination]:
    return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")]
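
# Note: "FP16, FP16" declares half precision for the (input, output) pair in
# LINEAR format; per the module docstring, this is the format support the JIT
# plugin generator fails to declare on its own.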


# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
# def sigmoid_aot_triton_impl(
#     input: trtp.TensorDesc,
#     outputs: Tuple[trtp.TensorDesc],
#     tactic: int,
# ) -> Tuple[
#     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
# ]:
#     print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!")
#
#     # Reuse the same Triton kernel we use for eager execution
#     src = triton.compiler.ASTSource(
#         fn=pointwise_sigmoid_kernel,
#         signature={
#             "x_ptr": "*fp16",
#             "y_ptr": "*fp16",
#             "n_elements": "i32",
#             "BLOCK_SIZE": "constexpr",
#         },
#         constexprs={"BLOCK_SIZE": 256},
#     )
#
#     compiled_kernel = triton.compile(src)
#
#     N = input.shape_expr.numel()
#     launch_params = trtp.KernelLaunchParams()
#     launch_params.grid_x = trtp.cdiv(N, 256)
#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
#     launch_params.shared_mem = compiled_kernel.metadata.shared
#
#     extra_args = trtp.SymIntExprs(1)
#     extra_args[0] = trtp.SymInt32(N)
#
#     print(compiled_kernel.asm["ptx"])
#
#     return (
#         compiled_kernel.metadata.name,
#         compiled_kernel.asm["ptx"],
#         launch_params,
#         extra_args,
#     )
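
# The Triton-based AOT implementation above is kept commented out for
# reference; this commit uses an NVRTC-compiled CUDA kernel instead.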


cu_code = """
#include <cuda_fp16.h>

// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
                                               __half* __restrict__ output,
                                               const int size) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < size) {
        const float x = __half2float(input[idx]);
        const float result = 1.0f / (1.0f + expf(-x));
        output[idx] = __float2half(result);
    }
}
"""


@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
def sigmoid_aot_nvrtc_impl(
    input: trtp.TensorDesc,
    outputs: Tuple[trtp.TensorDesc],
    tactic: int,
) -> Tuple[
    Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
]:
    print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!")

    # Compile the CUDA source above to PTX for the current device with NVRTC
    dev = Device()
    dev.set_current()
    program_options = ProgramOptions(
        std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"]
    )
    program = Program(cu_code, code_type="c++", options=program_options)
    mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
    compiled_kernel = mod.code.decode("utf-8")
    print(compiled_kernel)

    # One thread per element, 256 threads per block; trtp.cdiv already rounds
    # up, so no need to add 255 before dividing
    N = input.shape_expr.numel()
    launch_params = trtp.KernelLaunchParams()
    launch_params.grid_x = trtp.cdiv(N, 256)
    launch_params.block_x = 256
    launch_params.shared_mem = 0

    # Pass the element count to the kernel as a symbolic extra argument
    extra_args = trtp.SymIntExprs(1)
    extra_args[0] = trtp.SymInt32(N)

    # Return the kernel name, its PTX, the launch configuration, and the
    # symbolic extra kernel arguments
    return (
        "pointwise_sigmoid_kernel_nvrtc",
        compiled_kernel,
        launch_params,
        extra_args,
    )


torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
    "pointwise_sigmoid_ops::pointwise_sigmoid",
    supports_dynamic_shapes=True,
    requires_output_allocator=False,
)
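
# generate_plugin_converter (above) registers a Dynamo converter for the
# custom op, so Torch-TensorRT lowers calls to it into the TensorRT plugin
# defined earlier in this file.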


# ============================================================================
# Test Model
# ============================================================================


class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
    """
    Test model that uses the TRT wrapper with custom_op() registration.

    When compiled with torch_tensorrt.compile() using fp16 inputs, this will
    fail with: "could not find any supported formats consistent with input/output
    data types"
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = torch.mul(input, 2)
        y = torch.div(x, 2)
        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y)
        a = torch.add(z, 1)
        return a


if __name__ == "__main__":
    model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
    input = torch.randn(1, 1024, device="cuda", dtype=torch.float16)

    with torch_tensorrt.logging.debug():
        trt_inputs = [input]
        model_trt = torch_tensorrt.compile(
            model,
            inputs=trt_inputs,
            min_block_size=1,
        )
        print("Model compiled successfully!")
        print("Running inference with compiled model...")
        with torch.no_grad():
            for i in range(10):
                res = model_trt(input)
                assert torch.allclose(
                    res, model(input), rtol=1e-2, atol=1e-2
                ), "Results do not match!"

        print("Inference successful!")
