 
 import torch
 
-# Import triton for kernel implementation
-import triton
-import triton.language as tl
-
 import torch_tensorrt
 
-# ============================================================================
-# Triton Kernel for Eager Execution
-# ============================================================================
+# CUDA kernel source (NVRTC) used by the torch custom op
+cu_code = """
+#include <cuda_fp16.h>
+
+// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
+__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
+                                               __half* __restrict__ output,
+                                               const int size) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (idx < size) {
+        const float x = __half2float(input[idx]);
+        const float result = 1.0f / (1.0f + expf(-x));
+        output[idx] = __float2half(result);
+    }
+}
+"""
 
+# Prepare NVRTC program, kernel, and stream once (simple eager path)
+from cuda.core.experimental import (
+    Device as _CudaDevice,
+    LaunchConfig as _LaunchConfig,
+    Program as _CudaProgram,
+    ProgramOptions as _CudaProgramOptions,
+    launch as _cuda_launch,
+)
 
-@triton.jit
-def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
-    # Program ID determines the block of data each thread will process
-    pid = tl.program_id(0)
-    # Compute the range of elements that this thread block will work on
-    block_start = pid * BLOCK_SIZE
-    # Range of indices this thread will handle
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    # Mask for boundary checking
-    mask = offsets < n_elements
-    # Load elements from the X tensor
-    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
-    # Convert to float32 for computation
-    x_f32 = x.to(tl.float32)
-    # Compute sigmoid: 1 / (1 + exp(-x))
-    output = tl.sigmoid(x_f32)
-    # Convert back to original dtype
-    output_casted = output.to(x.dtype)
-    # Store the result in Y
-    tl.store(y_ptr + offsets, output_casted, mask=mask)
+_cuda_device = _CudaDevice()
+_cuda_device.set_current()
+_cuda_stream = _cuda_device.create_stream()
+_program_options = _CudaProgramOptions(
+    std="c++17", arch=f"sm_{_cuda_device.arch}", include_path=["/usr/local/cuda/include"]
+)
+_program = _CudaProgram(cu_code, code_type="c++", options=_program_options)
+_module = _program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
+_kernel = _module.get_kernel("pointwise_sigmoid_kernel_nvrtc")
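+# The compiled module is reused below: the eager custom op launches _kernel,
+# and the TensorRT AOT plugin implementation returns this module's PTX.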
+
+# Eager torch custom_op implemented using the CUDA kernel above (no Triton)
 
 
 # ============================================================================
@@ -52,20 +61,37 @@ def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr)
 
 @torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=())  # type: ignore[misc]
 def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor:
-    # Ensure the tensor is on the GPU
     assert X.is_cuda, "Tensor must be on CUDA device."
 
-    # Create output tensor
     Y = torch.empty_like(X)
+    N = int(X.numel())
+
+    block = 256
 
-    # Define block size
-    BLOCK_SIZE = 256
+    grid_x = max(1, (N + block - 1) // block)
+    config = _LaunchConfig(grid=grid_x, block=block)
 
-    # Grid of programs
-    grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),)
+    # Use PyTorch's current stream by wrapping it for cuda.core
+    class _PyTorchStreamWrapper:
+        def __init__(self, pt_stream):
+            self.pt_stream = pt_stream
 
-    # Launch the kernel
-    pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE)
+        def __cuda_stream__(self):
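+            # __cuda_stream__ is the protocol cuda.core uses to adopt an existing
+            # stream: it is expected to return a (version, raw stream handle) tuple.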
+            stream_id = self.pt_stream.cuda_stream
+            return (0, stream_id)
+
+    pt_stream = torch.cuda.current_stream()
+    s = _cuda_device.create_stream(_PyTorchStreamWrapper(pt_stream))
+
+    # Launch kernel with raw pointers as in cuda.core example
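+    # Arguments must match the __global__ signature: input pointer, output pointer, size.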
+    _cuda_launch(
+        s,
+        config,
+        _kernel,
+        X.data_ptr(),
+        Y.data_ptr(),
+        N,
+    )
 
     return Y
 
@@ -97,66 +123,6 @@ def sigmoid_autotune(
     return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")]
 
 
-# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
-# def sigmoid_aot_triton_impl(
-#     input: trtp.TensorDesc,
-#     outputs: Tuple[trtp.TensorDesc],
-#     tactic: int,
-# ) -> Tuple[
-#     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
-# ]:
-#     print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!")
-
-#     # Reuse the same Triton kernel we use for eager execution
-#     src = triton.compiler.ASTSource(
-#         fn=pointwise_sigmoid_kernel,
-#         signature={
-#             "x_ptr": "*fp16",
-#             "y_ptr": "*fp16",
-#             "n_elements": "i32",
-#             "BLOCK_SIZE": "constexpr",
-#         },
-#         constexprs={"BLOCK_SIZE": 256},
-#     )
-
-#     compiled_kernel = triton.compile(src)
-
-#     N = input.shape_expr.numel()
-#     launch_params = trtp.KernelLaunchParams()
-#     launch_params.grid_x = trtp.cdiv(N, 256)
-#     launch_params.block_x = compiled_kernel.metadata.num_warps * 32
-#     launch_params.shared_mem = compiled_kernel.metadata.shared
-
-#     extra_args = trtp.SymIntExprs(1)
-#     extra_args[0] = trtp.SymInt32(N)
-
-#     print(compiled_kernel.asm["ptx"])
-
-#     return (
-#         compiled_kernel.metadata.name,
-#         compiled_kernel.asm["ptx"],
-#         launch_params,
-#         extra_args,
-#     )
-
-
-cu_code = """
-#include <cuda_fp16.h>
-
-// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x))
-__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input,
-                                               __half* __restrict__ output,
-                                               const int size) {
-    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
-
-    if (idx < size) {
-        const float x = __half2float(input[idx]);
-        const float result = 1.0f / (1.0f + expf(-x));
-        output[idx] = __float2half(result);
-    }
-}
-"""
-
 
 @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid")
 def sigmoid_aot_nvrtc_impl(
@@ -166,22 +132,19 @@ def sigmoid_aot_nvrtc_impl(
 ) -> Tuple[
     Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs
 ]:
-    print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!")
 
-    dev = Device()
-    dev.set_current()
-    program_options = ProgramOptions(
-        std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"]
-    )
-    program = Program(cu_code, code_type="c++", options=program_options)
-    mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",))
-    compiled_kernel = mod.code.decode("utf-8")
+    compiled_kernel = _module.code.decode("utf-8")
     print(compiled_kernel)
 
     N = input.shape_expr.numel()
     launch_params = trtp.KernelLaunchParams()
-    launch_params.grid_x = trtp.cdiv((N + 256 - 1), 256)
-    launch_params.block_x = 256
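+    # Mirror the eager launch configuration: 256-thread blocks, ceil-division grid.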
+    block = 256
+    launch_params.grid_x = trtp.cdiv(N, block)
+    launch_params.block_x = block
     launch_params.shared_mem = 0
 
     extra_args = trtp.SymIntExprs(1)
@@ -217,31 +180,38 @@ class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        x = torch.mul(input, 2)
-        y = torch.div(x, 2)
-        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y)
-        a = torch.add(z, 1)
-        return a
+
+        z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(input)
+        return z
 
 
 if __name__ == "__main__":
     model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval()
     input = torch.randn(1, 1024, device="cuda", dtype=torch.float16)
 
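+    # Sanity check: compare the eager custom op against torch.sigmoid before compiling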
+    print(torch.sigmoid(input))
+
+    print(model(input))
+
     with torch_tensorrt.logging.debug():
         trt_inputs = [input]
         model_trt = torch_tensorrt.compile(
             model,
             inputs=trt_inputs,
+            enabled_precisions={torch.float16},
             min_block_size=1,
         )
         print("Model compiled successfully!")
         print("Running inference with compiled model...")
+        print("Compiled model output:")
+        print(model_trt(input))
+        print("Original model output:")
+        print(model(input))
         with torch.no_grad():
             for i in range(10):
                 res = model_trt(input)
                 assert torch.allclose(
                     res, model(input), rtol=1e-2, atol=1e-2
                 ), "Results do not match!"
 
-    print("Inference successful!")
+    # print("Inference successful!")