Commit 67abbd5

Commit message: update
Parent: 1dc6631

1 file changed (+82, -112 lines)


examples/dynamo/nvrtc_aot_plugin.py

Lines changed: 82 additions & 112 deletions
@@ -12,37 +12,46 @@
 
 import torch
 
-# Import triton for kernel implementation
-import triton
-import triton.language as tl
-
 import torch_tensorrt
 
-# ============================================================================
-# Triton Kernel for Eager Execution
-# ============================================================================
+# CUDA kernel source (NVRTC) used by the torch custom op
+cu_code = """
+#include <cuda_fp16.h>
+
+// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
+__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
+                                               __half* __restrict__ output,
+                                               const int size) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (idx < size) {
+        const float x = __half2float(input[idx]);
+        const float result = 1.0f / (1.0f + expf(-x));
+        output[idx] = __float2half(result);
+    }
+}
+"""
 
+# Prepare NVRTC program, kernel, and stream once (simple eager path)
+from cuda.core.experimental import (
+    Device as _CudaDevice,
+    LaunchConfig as _LaunchConfig,
+    Program as _CudaProgram,
+    ProgramOptions as _CudaProgramOptions,
+    launch as _cuda_launch,
+)
 
-@triton.jit
-def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    # Program ID determines the block of data each thread will process
-    pid = tl.program_id(0)
-    # Compute the range of elements that this thread block will work on
-    block_start = pid * BLOCK_SIZE
-    # Range of indices this thread will handle
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    # Mask for boundary checking
-    mask = offsets < n_elements
-    # Load elements from the X tensor
-    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-    # Convert to float32 for computation
-    x_f32 = x.to(tl.float32)
-    # Compute sigmoid: 1 / (1 + exp(-x))
-    output = tl.sigmoid(x_f32)
-    # Convert back to original dtype
-    output_casted = output.to(x.dtype)
-    # Store the result in Y
-    tl.store(y_ptr + offsets, output_casted, mask=mask)
+_cuda_device = _CudaDevice()
+_cuda_device.set_current()
+_cuda_stream = _cuda_device.create_stream()
+_program_options = _CudaProgramOptions(
+    std="c++17", arch=f"sm_{_cuda_device.arch}", include_path=["/usr/local/cuda/include"]
+)
+_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
+_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
+_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
+
+# Eager torch custom_op implemented using the CUDA kernel above (no Triton)
 
 
 # ============================================================================
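
Note on the hunk above: the commit compiles the NVRTC kernel once at import time, so the same compiled module backs both the eager custom op (via get_kernel) and the AOT TensorRT plugin (via _module.code). A quick smoke test of that shared artifact, assuming a CUDA device and the cuda-core package, might look like this (illustrative only, not part of the commit):

    # Hypothetical check: confirm the import-time NVRTC build produced PTX
    ptx_text = _module.code.decode("utf-8")  # same PTX the AOT impl later hands to TensorRT
    assert "pointwise_sigmoid_kernel_nvrtc" in ptx_text  # the mangled entry point contains the identifier
    print(f"NVRTC emitted {len(ptx_text)} characters of PTX")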
@@ -52,20 +61,37 @@ def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr)
 
 @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
 def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
-    # Ensure the tensor is on the GPU
     assert X.is_cuda, "Tensor must be on CUDA device."
 
-    # Create output tensor
     Y = torch.empty_like(X)
+    N = int(X.numel())
+
+    block = 256
 
-    # Define block size
-    BLOCK_SIZE = 256
+    grid_x = max(1, (N + block - 1) // block)
+    config = _LaunchConfig(grid=(grid_x), block=(block))
 
-    # Grid of programs
-    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+    # Use PyTorch's current stream by wrapping it for cuda.core
+    class _PyTorchStreamWrapper:
+        def __init__(self, pt_stream):
+            self.pt_stream = pt_stream
 
-    # Launch the kernel
-    pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE)
+        def __cuda_stream__(self):
+            stream_id = self.pt_stream.cuda_stream
+            return (0, stream_id)
+
+    pt_stream = torch.cuda.current_stream()
+    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))
+
+    # Launch kernel with raw pointers as in cuda.core example
+    _cuda_launch(
+        s,
+        config,
+        _kernel,
+        X.data_ptr(),
+        Y.data_ptr(),
+        N,
+    )
 
     return Y
 
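Note on the hunk above: the rewritten eager op launches the NVRTC kernel on PyTorch's current stream through the __cuda_stream__ protocol wrapper. It can be sanity-checked on its own before any TensorRT compilation; a minimal sketch, assuming a CUDA device and the same tolerances used in __main__ below:

    # Illustrative eager-path check against torch.sigmoid (not part of the commit)
    x = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
    y = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(x)
    assert torch.allclose(y, torch.sigmoid(x), rtol=1e-2, atol=1e-2)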

@@ -97,66 +123,6 @@ def sigmoid_autotune(
     return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")]
 
 
-# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
-# def sigmoid_aot_triton_impl(
-#     input: trtp.TensorDesc,
-#     outputs: Tuple[trtp.TensorDesc],
-#     tactic: int,
-# ) -> Tuple[
-#     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
-# ]:
-#     print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!")
-
-#     # Reuse the same Triton kernel we use for eager execution
-#     src = triton.compiler.ASTSource(
-#         fn=pointwise_sigmoid_kernel,
-#         signature={
-#             "x_ptr": "*fp16",
-#             "y_ptr": "*fp16",
-#             "n_elements": "i32",
-#             "BLOCK_SIZE": "constexpr",
-#         },
-#         constexprs={"BLOCK_SIZE": 256},
-#     )
-
-#     compiled_kernel = triton.compile(src)
-
-#     N = input.shape_expr.numel()
-#     launch_params = trtp.KernelLaunchParams()
-#     launch_params.grid_x = trtp.cdiv(N, 256)
-#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-#     launch_params.shared_mem = compiled_kernel.metadata.shared
-
-#     extra_args = trtp.SymIntExprs(1)
-#     extra_args[0] = trtp.SymInt32(N)
-
-#     print(compiled_kernel.asm["ptx"])
-
-#     return (
-#         compiled_kernel.metadata.name,
-#         compiled_kernel.asm["ptx"],
-#         launch_params,
-#         extra_args,
-#     )
-
-
-cu_code = """
-#include <cuda_fp16.h>
-
-// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
-__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
-                                               __half* __restrict__ output,
-                                               const int size) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-
-    if (idx < size) {
-        const float x = __half2float(input[idx]);
-        const float result = 1.0f / (1.0f + expf(-x));
-        output[idx] = __float2half(result);
-    }
-}
-"""
-
 
 @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
 def sigmoid_aot_nvrtc_impl(
@@ -166,22 +132,19 @@ def sigmoid_aot_nvrtc_impl(
 ) -> Tuple[
     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
 ]:
-    print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!")
 
-    dev = Device()
-    dev.set_current()
-    program_options = ProgramOptions(
-        std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"]
-    )
-    program = Program(cu_code, code_type="c++", options=program_options)
-    mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
-    compiled_kernel = mod.code.decode("utf-8")
+    compiled_kernel = _module.code.decode("utf-8")
+    print(type(compiled_kernel))
     print(compiled_kernel)
 
+    # import pdb; pdb.set_trace()
+
+
     N = input.shape_expr.numel()
     launch_params = trtp.KernelLaunchParams()
-    launch_params.grid_x = trtp.cdiv((N + 256 - 1), 256)
-    launch_params.block_x = 256
+    block = 256
+    launch_params.grid_x = trtp.cdiv(N, block)
+    launch_params.block_x = block
     launch_params.shared_mem = 0
 
     extra_args = trtp.SymIntExprs(1)
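
Note on the hunk above: besides reusing the module compiled at import time, the launch-parameter change drops a doubled ceiling adjustment. Assuming trtp.cdiv is ceiling division, the old expression cdiv(N + 256 - 1, 256) rounded up twice; a worked example in plain Python:

    # Grid sizing for N = 1024, block = 256 (illustrative arithmetic only)
    N, block = 1024, 256
    new_grid = -(-N // block)                # cdiv(1024, 256) -> 4 blocks
    old_grid = -(-(N + block - 1) // block)  # cdiv(1279, 256) -> 5 blocks (extra block is masked by idx < size)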
@@ -217,31 +180,38 @@ class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x = torch.mul(input, 2)
-        y = torch.div(x, 2)
-        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y)
-        a = torch.add(z, 1)
-        return a
+
+        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
+        return z
 
 
 if __name__ == "__main__":
     model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
     input = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
 
+    print(torch.sigmoid(input))
+
+    print(model(input))
+
     with torch_tensorrt.logging.debug():
         trt_inputs = [input]
         model_trt = torch_tensorrt.compile(
             model,
             inputs=trt_inputs,
+            enabled_precisions={torch.float16},
             min_block_size=1,
         )
         print("Model compiled successfully!")
         print("Running inference with compiled model...")
+        print("Compiled model output:")
+        print(model_trt(input))
+        print("Original model output:")
+        print(model(input))
         with torch.no_grad():
             for i in range(10):
                 res = model_trt(input)
                 assert torch.allclose(
                     res, model(input), rtol=1e-2, atol=1e-2
                 ), "Results do not match!"
 
-        print("Inference successful!")
+        # print("Inference successful!")
