|
| 1 | +""" |
| 2 | +Minimal reproducible example demonstrating TensorRT fp16 custom_op() issue. |
| 3 | +
|
| 4 | +This module shows the bug where torch_tensorrt.dynamo.conversion.plugins.custom_op() |
| 5 | +fails to compile operations that use fp16 (half-precision) tensors. |
| 6 | +
|
| 7 | +The issue occurs because the JIT plugin generator doesn't properly declare format |
| 8 | +support for fp16 data types in the generated TensorRT plugin. |
| 9 | +""" |
| 10 | + |
| 11 | +from typing import List, Tuple, Union |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +# Import triton for kernel implementation |
| 16 | +import triton |
| 17 | +import triton.language as tl |
| 18 | + |
| 19 | +import torch_tensorrt |
| 20 | + |
| 21 | +# ============================================================================ |
| 22 | +# Triton Kernel for Eager Execution |
| 23 | +# ============================================================================ |
| 24 | + |
| 25 | + |
| 26 | +@triton.jit |
| 27 | +def pointwise_sigmoid_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr): |
| 28 | + # Program ID determines the block of data each thread will process |
| 29 | + pid = tl.program_id(0) |
| 30 | + # Compute the range of elements that this thread block will work on |
| 31 | + block_start = pid * BLOCK_SIZE |
| 32 | + # Range of indices this thread will handle |
| 33 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 34 | + # Mask for boundary checking |
| 35 | + mask = offsets < n_elements |
| 36 | + # Load elements from the X tensor |
| 37 | + x = tl.load(x_ptr + offsets, mask=mask, other=0.0) |
| 38 | + # Convert to float32 for computation |
| 39 | + x_f32 = x.to(tl.float32) |
| 40 | + # Compute sigmoid: 1 / (1 + exp(-x)) |
| 41 | + output = tl.sigmoid(x_f32) |
| 42 | + # Convert back to original dtype |
| 43 | + output_casted = output.to(x.dtype) |
| 44 | + # Store the result in Y |
| 45 | + tl.store(y_ptr + offsets, output_casted, mask=mask) |
| 46 | + |
| 47 | + |
| 48 | +# ============================================================================ |
| 49 | +# Custom Op Registration |
| 50 | +# ============================================================================ |
| 51 | + |
| 52 | + |
| 53 | +@torch.library.custom_op("pointwise_sigmoid_ops::pointwise_sigmoid", mutates_args=()) # type: ignore[misc] |
| 54 | +def pointwise_sigmoid(X: torch.Tensor) -> torch.Tensor: |
| 55 | + # Ensure the tensor is on the GPU |
| 56 | + assert X.is_cuda, "Tensor must be on CUDA device." |
| 57 | + |
| 58 | + # Create output tensor |
| 59 | + Y = torch.empty_like(X) |
| 60 | + |
| 61 | + # Define block size |
| 62 | + BLOCK_SIZE = 256 |
| 63 | + |
| 64 | + # Grid of programs |
| 65 | + grid = lambda meta: (triton.cdiv(X.numel(), meta["BLOCK_SIZE"]),) |
| 66 | + |
| 67 | + # Launch the kernel |
| 68 | + pointwise_sigmoid_kernel[grid](X, Y, X.numel(), BLOCK_SIZE=BLOCK_SIZE) |
| 69 | + |
| 70 | + return Y |
| 71 | + |
| 72 | + |
| 73 | +@torch.library.register_fake("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 74 | +def _(input: torch.Tensor) -> torch.Tensor: |
| 75 | + """Fake implementation for TorchDynamo tracing of base operation.""" |
| 76 | + return torch.empty_like(input) |
| 77 | + |
| 78 | + |
| 79 | +# ============================================================================ |
| 80 | +# TensorRT Wrapper with custom_op() - THIS FAILS WITH FP16 |
| 81 | +# ============================================================================ |
| 82 | + |
| 83 | +import tensorrt.plugin as trtp |
| 84 | +from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions |
| 85 | + |
| 86 | + |
| 87 | +@trtp.register("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 88 | +def sigmoid_plugin_desc(input: trtp.TensorDesc) -> Tuple[trtp.TensorDesc]: |
| 89 | + return (input.like(),) |
| 90 | + |
| 91 | + |
| 92 | +@trtp.autotune("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 93 | +def sigmoid_autotune( |
| 94 | + input: trtp.TensorDesc, |
| 95 | + outputs: Tuple[trtp.TensorDesc], |
| 96 | +) -> List[trtp.AutoTuneCombination]: |
| 97 | + return [trtp.AutoTuneCombination("FP16, FP16", "LINEAR")] |
| 98 | + |
| 99 | + |
| 100 | +# @trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 101 | +# def sigmoid_aot_triton_impl( |
| 102 | +# input: trtp.TensorDesc, |
| 103 | +# outputs: Tuple[trtp.TensorDesc], |
| 104 | +# tactic: int, |
| 105 | +# ) -> Tuple[ |
| 106 | +# Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs |
| 107 | +# ]: |
| 108 | +# print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (Triton)!!!") |
| 109 | + |
| 110 | +# # Reuse the same Triton kernel we use for eager execution |
| 111 | +# src = triton.compiler.ASTSource( |
| 112 | +# fn=pointwise_sigmoid_kernel, |
| 113 | +# signature={ |
| 114 | +# "x_ptr": "*fp16", |
| 115 | +# "y_ptr": "*fp16", |
| 116 | +# "n_elements": "i32", |
| 117 | +# "BLOCK_SIZE": "constexpr", |
| 118 | +# }, |
| 119 | +# constexprs={"BLOCK_SIZE": 256}, |
| 120 | +# ) |
| 121 | + |
| 122 | +# compiled_kernel = triton.compile(src) |
| 123 | + |
| 124 | +# N = input.shape_expr.numel() |
| 125 | +# launch_params = trtp.KernelLaunchParams() |
| 126 | +# launch_params.grid_x = trtp.cdiv(N, 256) |
| 127 | +# launch_params.block_x = compiled_kernel.metadata.num_warps * 32 |
| 128 | +# launch_params.shared_mem = compiled_kernel.metadata.shared |
| 129 | + |
| 130 | +# extra_args = trtp.SymIntExprs(1) |
| 131 | +# extra_args[0] = trtp.SymInt32(N) |
| 132 | + |
| 133 | +# print(compiled_kernel.asm["ptx"]) |
| 134 | + |
| 135 | +# return ( |
| 136 | +# compiled_kernel.metadata.name, |
| 137 | +# compiled_kernel.asm["ptx"], |
| 138 | +# launch_params, |
| 139 | +# extra_args, |
| 140 | +# ) |
| 141 | + |
| 142 | + |
| 143 | +cu_code = """ |
| 144 | +#include <cuda_fp16.h> |
| 145 | +
|
| 146 | +// Simple pointwise Sigmoid kernel: f(x) = 1 / (1 + exp(-x)) |
| 147 | +__global__ void pointwise_sigmoid_kernel_nvrtc(const __half* __restrict__ input, |
| 148 | + __half* __restrict__ output, |
| 149 | + const int size) { |
| 150 | + const int idx = blockIdx.x * blockDim.x + threadIdx.x; |
| 151 | +
|
| 152 | + if (idx < size) { |
| 153 | + const float x = __half2float(input[idx]); |
| 154 | + const float result = 1.0f / (1.0f + expf(-x)); |
| 155 | + output[idx] = __float2half(result); |
| 156 | + } |
| 157 | +} |
| 158 | +""" |
| 159 | + |
| 160 | + |
| 161 | +@trtp.aot_impl("pointwise_sigmoid_ops::pointwise_sigmoid") |
| 162 | +def sigmoid_aot_nvrtc_impl( |
| 163 | + input: trtp.TensorDesc, |
| 164 | + outputs: Tuple[trtp.TensorDesc], |
| 165 | + tactic: int, |
| 166 | +) -> Tuple[ |
| 167 | + Union[str, bytes], Union[str, bytes], trtp.KernelLaunchParams, trtp.SymExprs |
| 168 | +]: |
| 169 | + print("WE ARE NOW GENERATING THE PTX FOR THE PLUGIN (NVRTC)!!!") |
| 170 | + |
| 171 | + dev = Device() |
| 172 | + dev.set_current() |
| 173 | + program_options = ProgramOptions( |
| 174 | + std="c++17", arch=f"sm_{dev.arch}", include_path=["/usr/local/cuda/include"] |
| 175 | + ) |
| 176 | + program = Program(cu_code, code_type="c++", options=program_options) |
| 177 | + mod = program.compile("ptx", name_expressions=("pointwise_sigmoid_kernel_nvrtc",)) |
| 178 | + compiled_kernel = mod.code.decode("utf-8") |
| 179 | + print(compiled_kernel) |
| 180 | + |
| 181 | + N = input.shape_expr.numel() |
| 182 | + launch_params = trtp.KernelLaunchParams() |
| 183 | + launch_params.grid_x = trtp.cdiv((N + 256 - 1), 256) |
| 184 | + launch_params.block_x = 256 |
| 185 | + launch_params.shared_mem = 0 |
| 186 | + |
| 187 | + extra_args = trtp.SymIntExprs(1) |
| 188 | + extra_args[0] = trtp.SymInt32(N) |
| 189 | + |
| 190 | + return ( |
| 191 | + "pointwise_sigmoid_kernel_nvrtc", |
| 192 | + compiled_kernel, |
| 193 | + launch_params, |
| 194 | + extra_args, |
| 195 | + ) |
| 196 | + |
| 197 | + |
| 198 | +torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter( |
| 199 | + "pointwise_sigmoid_ops::pointwise_sigmoid", |
| 200 | + supports_dynamic_shapes=True, |
| 201 | + requires_output_allocator=False, |
| 202 | +) |
| 203 | + |
| 204 | + |
| 205 | +# ============================================================================ |
| 206 | +# Test Model |
| 207 | +# ============================================================================ |
| 208 | + |
| 209 | + |
| 210 | +class PointwiseSigmoidModel_WithTRTWrapper(torch.nn.Module): |
| 211 | + """ |
| 212 | + Test model that uses the TRT wrapper with custom_op() registration. |
| 213 | +
|
| 214 | + When compiled with torch_tensorrt.compile() using fp16 inputs, this will |
| 215 | + fail with: "could not find any supported formats consistent with input/output |
| 216 | + data types" |
| 217 | + """ |
| 218 | + |
| 219 | + def forward(self, input: torch.Tensor) -> torch.Tensor: |
| 220 | + x = torch.mul(input, 2) |
| 221 | + y = torch.div(x, 2) |
| 222 | + z = torch.ops.pointwise_sigmoid_ops.pointwise_sigmoid(y) |
| 223 | + a = torch.add(z, 1) |
| 224 | + return a |
| 225 | + |
| 226 | + |
| 227 | +if __name__ == "__main__": |
| 228 | + model = PointwiseSigmoidModel_WithTRTWrapper().to("cuda").eval() |
| 229 | + input = torch.randn(1, 1024, device="cuda", dtype=torch.float16) |
| 230 | + |
| 231 | + with torch_tensorrt.logging.debug(): |
| 232 | + trt_inputs = [input] |
| 233 | + model_trt = torch_tensorrt.compile( |
| 234 | + model, |
| 235 | + inputs=trt_inputs, |
| 236 | + min_block_size=1, |
| 237 | + ) |
| 238 | + print("Model compiled successfully!") |
| 239 | + print("Running inference with compiled model...") |
| 240 | + with torch.no_grad(): |
| 241 | + for i in range(10): |
| 242 | + res = model_trt(input) |
| 243 | + assert torch.allclose( |
| 244 | + res, model(input), rtol=1e-2, atol=1e-2 |
| 245 | + ), "Results do not match!" |
| 246 | + |
| 247 | + print("Inference successful!") |
0 commit comments