From d2951fdeb51d5f657fcaa121c9aa0208f71b9994 Mon Sep 17 00:00:00 2001 From: wufeisheng Date: Thu, 2 Apr 2026 21:53:51 +0800 Subject: [PATCH] add blockwise cuda graph --- custom_ops/gpu_ops/helper.h | 28 +- fastdeploy/envs.py | 10 + .../graph_optimization/cuda_graph_op.py | 284 ++++++++++++++++++ fastdeploy/model_executor/layers/linear.py | 4 + .../model_executor/layers/normalization.py | 4 + fastdeploy/worker/gpu_model_runner.py | 52 ++++ fastdeploy/worker/gpu_worker.py | 3 + 7 files changed, 383 insertions(+), 2 deletions(-) create mode 100644 fastdeploy/model_executor/graph_optimization/cuda_graph_op.py diff --git a/custom_ops/gpu_ops/helper.h b/custom_ops/gpu_ops/helper.h index 83f3ad1077d..2ff623e8b83 100644 --- a/custom_ops/gpu_ops/helper.h +++ b/custom_ops/gpu_ops/helper.h @@ -52,6 +52,8 @@ namespace cub = hipcub; #include "env.h" #include "paddle/extension.h" #include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/memory/allocation/allocator_facade.h" +#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/phi/backends/custom/custom_context.h" #else @@ -371,7 +373,18 @@ inline json readJsonFromFile(const std::string &filePath) { inline paddle::Tensor GetEmptyTensor(const common::DDim &dims, const paddle::DataType &dtype, const paddle::Place &place) { - auto *allocator = paddle::GetAllocator(place); + phi::Allocator *allocator = nullptr; +#if defined(PADDLE_WITH_CUDA) + if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + allocator = paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place) + .get(); + } else { + allocator = paddle::GetAllocator(place); + } +#else + allocator = paddle::GetAllocator(place); +#endif phi::DenseTensor dense_tensor; dense_tensor.Resize(dims); dense_tensor.AllocateFrom( @@ -383,7 +396,18 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims, const common::DDim &strides, const paddle::DataType &dtype, const paddle::Place &place) { - 
auto *allocator = paddle::GetAllocator(place); + phi::Allocator *allocator = nullptr; +#if defined(PADDLE_WITH_CUDA) + if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) { + allocator = paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(place) + .get(); + } else { + allocator = paddle::GetAllocator(place); + } +#else + allocator = paddle::GetAllocator(place); +#endif phi::DenseTensor dense_tensor; dense_tensor.Resize(dims); dense_tensor.AllocateFrom( diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 0c7ac3e22b1..e4d3106be62 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -266,6 +266,16 @@ def _validate_split_kv_size(value: int) -> int: "FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool( int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1")) ), + # Whether to enable block-wise CUDA Graph capture/replay. + # When enabled, individual layer forward methods decorated with @block_wise_cuda_graph_wrap + # will be captured and replayed as CUDA Graphs for improved performance. + # Set to 1 to enable; defaults to 0 (disabled). + "FD_USE_BLOCK_WISE_CUDA_GRAPH": lambda: bool(int(os.getenv("FD_USE_BLOCK_WISE_CUDA_GRAPH", "0"))), + # Comma-separated list of token counts to pre-capture for block-wise CUDA Graphs. + # Used during the warmup phase to pre-capture graphs for these specific sizes. + # At runtime, token counts not in this list fall back to eager execution. + # Example: "1,2,4,8,16,32,64,128,256,512" + "FD_BLOCK_WISE_CUDA_GRAPH_SIZES": lambda: os.getenv("FD_BLOCK_WISE_CUDA_GRAPH_SIZES", "128,256,512,1024,2048"), } diff --git a/fastdeploy/model_executor/graph_optimization/cuda_graph_op.py b/fastdeploy/model_executor/graph_optimization/cuda_graph_op.py new file mode 100644 index 00000000000..3352e576c46 --- /dev/null +++ b/fastdeploy/model_executor/graph_optimization/cuda_graph_op.py @@ -0,0 +1,284 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
import functools
import inspect
from typing import Callable, Optional, Sequence

# ---- Module-level state for pre-captured block-wise CUDA graphs ----

# When True, the wrapper is in the capture phase (during dummy_run) and
# will capture new graphs. When False, uncached keys fall back to eager.
_BLOCK_WISE_CAPTURING: bool = False

# Registry of all shared-mode graph caches, for bulk clearing.
# Each entry is a (graphs, captured_inputs, captured_outputs) tuple of dicts.
_ALL_SHARED_CACHES: list = []

# Plain scalar argument types whose *value* is folded into the default cache
# key, so that calls differing only in such an argument never share a graph.
_KEYED_SCALAR_TYPES = (bool, int, float, str)


def set_block_wise_capturing(capturing: bool):
    """Toggle the capture phase flag. Only capture graphs when this is True."""
    global _BLOCK_WISE_CAPTURING
    _BLOCK_WISE_CAPTURING = capturing


def clear_all_block_wise_graphs():
    """Clear all shared block-wise graph caches (e.g. for RL weight updates).

    Only the closure-level caches registered in ``_ALL_SHARED_CACHES`` are
    emptied; per-instance caches (used when ``self_attrs`` is empty) live in
    each instance's ``__dict__`` and are not touched here.
    """
    for graphs, cinputs, coutputs in _ALL_SHARED_CACHES:
        graphs.clear()
        cinputs.clear()
        coutputs.clear()


def block_wise_cuda_graph_wrap(
    inputs: Sequence[str],
    self_attrs: Sequence[str] = (),
    key_fn: Optional[Callable[..., tuple]] = None,
):
    """
    Method decorator that wraps a forward method with CUDA Graph capture/replay.

    While the capture phase is active (see ``set_block_wise_capturing``), the
    first call for a given cache key captures the decorated method into a CUDA
    Graph. Later calls with the same key replay that graph after replacing the
    recorded input data pointers with the current ones. Outside the capture
    phase, keys that were never captured fall back to eager execution.

    When ``self_attrs`` is non-empty, the named tensor attributes of ``self``
    (e.g. ``weight``) are also tracked for pointer replacement and the graph
    cache is shared across all instances (closure-level). When it is empty,
    graphs are cached per instance in ``self.__dict__``.

    The object returned by the method during capture is stored and returned
    again on every replay, so callers must consume the output before the next
    replay of the same graph.

    The wrapper is a pass-through when ``FD_USE_BLOCK_WISE_CUDA_GRAPH`` is
    disabled or when any tensor argument has a zero-sized dimension.

    Args:
        inputs: Names of parameters of the decorated method that are input
            tensors to track for pointer replacement. Only non-None tensor
            arguments are tracked. A single name may be passed as a string.
        self_attrs: Attribute names on ``self`` holding tensors to track
            (e.g. ``["weight"]``). Non-empty enables cross-instance sharing.
        key_fn: Optional callable producing the cache key. It is called as
            ``key_fn(arg0, arg1, ...)`` with the method's arguments in
            declaration order (excluding ``self``); missing arguments are
            filled with their defaults, or ``None`` when there is no default.
            When omitted, the key is built from each argument's name plus its
            tensor shape/dtype, its value for plain scalars, ``None``, or a
            presence marker for any other object, followed by the
            shapes/dtypes of ``self_attrs``.

    Raises:
        ValueError: If a name in ``inputs`` is not a parameter of the
            decorated method (or is ``self``).

    Example:
        class MyNorm(nn.Layer):
            @block_wise_cuda_graph_wrap(
                inputs=["x", "residual"],
                self_attrs=["weight"],  # all layers share one graph
            )
            def forward(self, x, residual=None):
                return rms_norm(x, self.weight), residual
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(inputs, str):
        inputs = (inputs,)
    if isinstance(self_attrs, str):
        self_attrs = (self_attrs,)
    inputs = tuple(inputs)
    self_attrs = tuple(self_attrs)

    def decorator(method: Callable) -> Callable:
        sig = inspect.signature(method)
        params = list(sig.parameters.keys())  # ["self", "x", "residual_input", ...]

        for name in inputs:
            if name not in params or name == "self":
                raise ValueError(
                    f"block_wise_cuda_graph_wrap: input '{name}' is not a parameter of "
                    f"{method.__qualname__}. Available: {[p for p in params if p != 'self']}"
                )

        # Imported at decoration time (once per decorated method) so that the
        # argument validation above does not require the GPU stack.
        import paddle as _paddle

        import fastdeploy as _fd

        # ---- Pre-compute at decoration time (runs once) ----

        _EMPTY = inspect.Parameter.empty
        _Tensor = _paddle.Tensor

        # For each non-self param: (name, args_index, default_value)
        # args_index is position in *args (0-based, since self is consumed by Python)
        _param_info = tuple((p, i - 1, sig.parameters[p].default) for i, p in enumerate(params) if p != "self")

        # For each declared input tensor: (name, args_index)
        _input_info = tuple((name, params.index(name) - 1) for name in inputs)

        _self_attr_names = self_attrs
        _shared = len(_self_attr_names) > 0

        _use_custom_key = key_fn is not None

        # --- Cache storage ---
        # When self_attrs is provided: closure-level (shared across all instances)
        # When not: per-instance (stored in self.__dict__)
        if _shared:
            _shared_graphs = {}
            _shared_cinputs = {}
            _shared_coutputs = {}  # stores actual result tensors (reused across replays)
            _ALL_SHARED_CACHES.append((_shared_graphs, _shared_cinputs, _shared_coutputs))

        # Per-instance attribute key names
        _g = f"_cg_{method.__name__}_g"
        _ci = f"_cg_{method.__name__}_ci"
        _co = f"_cg_{method.__name__}_co"

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if not _fd.envs.FD_USE_BLOCK_WISE_CUDA_GRAPH:
                return method(self, *args, **kwargs)

            nargs = len(args)

            # Skip CUDA graph if any input tensor has a 0 in its shape
            for a in args:
                if isinstance(a, _Tensor) and 0 in a.shape:
                    return method(self, *args, **kwargs)
            for v in kwargs.values():
                if isinstance(v, _Tensor) and 0 in v.shape:
                    return method(self, *args, **kwargs)

            # === Key generation: inline, no sig.bind ===
            if _use_custom_key:
                # Resolve all args for custom key_fn
                resolved = []
                for pname, aidx, default in _param_info:
                    if pname in kwargs:
                        resolved.append(kwargs[pname])
                    elif aidx < nargs:
                        resolved.append(args[aidx])
                    elif default is not _EMPTY:
                        resolved.append(default)
                    else:
                        resolved.append(None)
                key = key_fn(*resolved)
            else:
                # Default key. Every entry is tagged with its parameter name so
                # that two calls populating different parameters cannot collide.
                _kp = []
                for pname, aidx, default in _param_info:
                    if pname in kwargs:
                        v = kwargs[pname]
                    elif aidx < nargs:
                        v = args[aidx]
                    else:
                        v = default
                    if isinstance(v, _Tensor):
                        _kp.append((pname, tuple(v.shape), v.dtype))
                    elif v is None:
                        _kp.append((pname, None))
                    elif isinstance(v, _KEYED_SCALAR_TYPES):
                        # Scalars may change the computation, so key on the value.
                        _kp.append((pname, v))
                    else:
                        # Callables and opaque objects: record presence only, so
                        # unhashable or per-call objects do not defeat the cache.
                        _kp.append((pname, True))
                # Include self_attrs shapes/dtypes in key
                for attr_name in _self_attr_names:
                    attr = getattr(self, attr_name, None)
                    if attr is not None and isinstance(attr, _Tensor):
                        _kp.append((attr_name, tuple(attr.shape), attr.dtype))
                    else:
                        _kp.append((attr_name, None))
                key = tuple(_kp)

            # === Get cache (shared or per-instance) ===
            if _shared:
                graphs = _shared_graphs
                cinputs = _shared_cinputs
                coutputs = _shared_coutputs
            else:
                _d = self.__dict__
                try:
                    graphs = _d[_g]
                    cinputs = _d[_ci]
                    coutputs = _d[_co]
                except KeyError:
                    graphs = {}
                    cinputs = {}
                    coutputs = {}
                    _d[_g] = graphs
                    _d[_ci] = cinputs
                    _d[_co] = coutputs

            if key not in graphs:
                # === First encounter: only capture during capture phase ===
                if not _BLOCK_WISE_CAPTURING:
                    # Not in capture phase -- fall back to eager
                    return method(self, *args, **kwargs)

                # === Capture ===
                # Record the data pointers the graph is about to be captured with.
                ci = {}
                for name, aidx in _input_info:
                    v = kwargs[name] if name in kwargs else (args[aidx] if aidx < nargs else None)
                    if v is not None and isinstance(v, _Tensor):
                        ci[name] = v.data_ptr()

                # Record self_attrs pointers for cross-instance replacement
                for attr_name in _self_attr_names:
                    attr = getattr(self, attr_name, None)
                    if attr is not None and isinstance(attr, _Tensor):
                        ci[f"__attr_{attr_name}"] = attr.data_ptr()

                graph = _paddle.device.cuda.graphs.CUDAGraph(enable_replace=True)
                graph.capture_begin()
                try:
                    result = method(self, *args, **kwargs)
                finally:
                    # Always close the capture, even when the method raises, so
                    # a failed capture does not leave capture mode open.
                    graph.capture_end()

                graph.replay()

                # Publish to the caches only after a successful capture, so a
                # failure never leaves a half-initialised entry to be replayed.
                graphs[key] = graph
                cinputs[key] = ci
                # Store the actual result for reuse: the same object is
                # returned on every replay of this key.
                coutputs[key] = result
                return result
            else:
                # === Replay path (HOT PATH) ===
                old_ptrs = []
                new_ptrs = []
                ci = cinputs[key]

                for name, aidx in _input_info:
                    v = kwargs[name] if name in kwargs else (args[aidx] if aidx < nargs else None)
                    if v is not None and name in ci:
                        old_ptrs.append(ci[name])
                        new_ptr = v.data_ptr()
                        new_ptrs.append(new_ptr)
                        ci[name] = new_ptr

                # Replace self_attrs pointers (e.g. weight)
                for attr_name in _self_attr_names:
                    attr_key = f"__attr_{attr_name}"
                    if attr_key in ci:
                        attr = getattr(self, attr_name, None)
                        if attr is not None:
                            old_ptrs.append(ci[attr_key])
                            new_ptr = attr.data_ptr()
                            new_ptrs.append(new_ptr)
                            ci[attr_key] = new_ptr

                if old_ptrs:
                    graphs[key].replace_input_ptrs(old_ptrs, new_ptrs)
                graphs[key].replay()

                # Reuse the output object stored at capture time.
                return coutputs[key]

        return wrapper

    return decorator
def capture_block_wise_graphs(self) -> None:
    """
    Independent capture loop for block-wise CUDA graphs.

    Runs one ``_dummy_run`` per token count listed in
    ``FD_BLOCK_WISE_CUDA_GRAPH_SIZES`` with the block-wise capture flag set,
    largest size first. Does nothing when ``FD_USE_BLOCK_WISE_CUDA_GRAPH`` is
    disabled. Duplicate sizes are captured once and non-positive sizes are
    skipped with a warning.

    Raises:
        ValueError: If an entry of ``FD_BLOCK_WISE_CUDA_GRAPH_SIZES`` is not
            an integer.
    """
    if not fastdeploy.envs.FD_USE_BLOCK_WISE_CUDA_GRAPH:
        return

    from fastdeploy.model_executor.graph_optimization.cuda_graph_op import (
        set_block_wise_capturing,
    )

    # Parse capture sizes from env var
    sizes_str = fastdeploy.envs.FD_BLOCK_WISE_CUDA_GRAPH_SIZES
    requested_sizes = set()  # a set, so a repeated size is captured only once
    for token in sizes_str.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            size = int(token)
        except ValueError as err:
            # Same exception type as a bare int() failure, but naming the
            # offending variable and value.
            raise ValueError(
                f"FD_BLOCK_WISE_CUDA_GRAPH_SIZES must be a comma-separated list of integers, got {token!r}"
            ) from err
        if size < 1:
            logger.warning(f"Ignoring non-positive block-wise CUDA graph capture size: {size}")
            continue
        requested_sizes.add(size)

    capture_sizes = sorted(requested_sizes, reverse=True)
    if not capture_sizes:
        logger.warning("FD_BLOCK_WISE_CUDA_GRAPH_SIZES is empty, skipping block-wise CUDA graph capture")
        return

    logger.info(f"Block-wise CUDA graph capture starting for sizes: {sorted(capture_sizes)}")
    time_before_capture = time.perf_counter()

    set_block_wise_capturing(True)
    try:
        for num_tokens in capture_sizes:
            # The batch cannot hold more sequences than tokens, nor exceed the
            # scheduler limit; keep it at least 1.
            batch_size = max(1, min(num_tokens, self.scheduler_config.max_num_seqs))
            self._dummy_run(
                num_tokens=num_tokens,
                batch_size=batch_size,
                in_capturing=False,
            )
            logger.info(f"Block-wise CUDA graph captured for num_tokens={num_tokens}")
    finally:
        # Always leave the capture phase, even if a dummy run fails.
        set_block_wise_capturing(False)

    time_after_capture = time.perf_counter()
    logger.info(
        f"Block-wise CUDA graph capturing took {time_after_capture - time_before_capture:.3f} seconds "
        f"for {len(capture_sizes)} sizes"
    )