Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions custom_ops/gpu_ops/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ namespace cub = hipcub;
#include "env.h"
#include "paddle/extension.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/memory/allocation/allocator_facade.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/custom/custom_context.h"
#else
Expand Down Expand Up @@ -371,7 +373,18 @@ inline json readJsonFromFile(const std::string &filePath) {
inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
const paddle::DataType &dtype,
const paddle::Place &place) {
auto *allocator = paddle::GetAllocator(place);
phi::Allocator *allocator = nullptr;
#if defined(PADDLE_WITH_CUDA)
if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
allocator = paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place)
.get();
} else {
allocator = paddle::GetAllocator(place);
}
#else
allocator = paddle::GetAllocator(place);
#endif
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
Expand All @@ -383,7 +396,18 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
const common::DDim &strides,
const paddle::DataType &dtype,
const paddle::Place &place) {
auto *allocator = paddle::GetAllocator(place);
phi::Allocator *allocator = nullptr;
#if defined(PADDLE_WITH_CUDA)
if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
allocator = paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place)
.get();
} else {
allocator = paddle::GetAllocator(place);
}
#else
allocator = paddle::GetAllocator(place);
#endif
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
Expand Down
10 changes: 10 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,16 @@ def _validate_split_kv_size(value: int) -> int:
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
),
# Whether to enable block-wise CUDA Graph capture/replay.
# When enabled, individual layer forward methods decorated with @block_wise_cuda_graph_wrap
# will be captured and replayed as CUDA Graphs for improved performance.
# Set to 1 to enable; defaults to 0 (disabled).
"FD_USE_BLOCK_WISE_CUDA_GRAPH": lambda: bool(int(os.getenv("FD_USE_BLOCK_WISE_CUDA_GRAPH", "0"))),
# Comma-separated list of token counts to pre-capture for block-wise CUDA Graphs.
# Used during the warmup phase to pre-capture graphs for these specific sizes.
# At runtime, token counts not in this list fall back to eager execution.
# Example: "1,2,4,8,16,32,64,128,256,512"
"FD_BLOCK_WISE_CUDA_GRAPH_SIZES": lambda: os.getenv("FD_BLOCK_WISE_CUDA_GRAPH_SIZES", "128,256,512,1024,2048"),
}


Expand Down
284 changes: 284 additions & 0 deletions fastdeploy/model_executor/graph_optimization/cuda_graph_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import functools
import inspect
from typing import Callable, Optional, Sequence

import paddle

import fastdeploy

# ---- Module-level state for pre-captured block-wise CUDA graphs ----

# When True, the wrapper is in the capture phase (during dummy_run) and
# will capture new graphs. When False, uncached keys fall back to eager.
# NOTE(review): module-level, per-process state — assumes one capture phase
# per process (i.e. one GPU per process); confirm for any setup that drives
# multiple devices from a single process.
_BLOCK_WISE_CAPTURING: bool = False
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 全局变量 _BLOCK_WISE_CAPTURING 和 _ALL_SHARED_CACHES 是模块级状态,在多进程场景下(如 Tensor Parallel)每个进程会有独立副本,这是预期行为。

但请确认:

  1. 是否存在同一进程内多 GPU 的场景?
  2. 如果有,是否需要 per-device 的状态隔离?

当前实现假设每个进程只负责一个 GPU,如果这是设计约束,建议在文档中明确说明。


# Registry of every shared-mode graph cache, kept so they can be cleared in bulk.
_ALL_SHARED_CACHES: list = []


def set_block_wise_capturing(capturing: bool):
    """Enable or disable the capture phase.

    New CUDA graphs are captured only while this flag is True; otherwise
    uncached keys fall back to eager execution.
    """
    global _BLOCK_WISE_CAPTURING
    _BLOCK_WISE_CAPTURING = capturing


def clear_all_block_wise_graphs():
    """Drop every shared block-wise graph cache (e.g. after RL weight updates)."""
    # Each registry entry is a (graphs, captured_inputs, captured_outputs) triple
    # of dicts; empty all three in place so live closures see the cleared state.
    for cache_group in _ALL_SHARED_CACHES:
        for cache in cache_group:
            cache.clear()


def block_wise_cuda_graph_wrap(
    inputs: Sequence[str],
    self_attrs: Sequence[str] = (),
    key_fn: Optional[Callable[..., tuple]] = None,
):
    """
    Method decorator that wraps a forward method with CUDA Graph capture/replay.

    On the first call for a given cache key (derived from tensor shapes/dtypes),
    the decorated method is captured into a CUDA Graph. Subsequent calls with the
    same key will replay the graph after updating input data pointers.

    When ``_BLOCK_WISE_CAPTURING`` is managed via ``set_block_wise_capturing``,
    new graphs are only captured during the capture phase (dummy_run). At runtime,
    uncached keys fall back to eager execution, avoiding expensive on-the-fly captures.

    When ``self_attrs`` is provided, the named tensor attributes of ``self``
    (e.g. ``weight``) are also tracked for pointer replacement, and the graph
    cache is **shared across all instances** (closure-level). This allows layers
    with identical computation but different weights to share a single captured
    graph, dramatically reducing the total number of graphs from O(num_layers)
    to O(num_unique_shapes).

    When ``self_attrs`` is empty (default), graphs are cached per instance.

    Output tensors from the capture phase are reused across replays — the graph
    always writes to the same output memory. This avoids per-replay allocation
    overhead. Callers must consume the output before the next replay of the same
    graph (which is naturally satisfied in sequential layer-by-layer forward).

    Args:
        inputs: Names of parameters that are input tensors to be tracked for
            CUDA Graph pointer replacement. These must be parameter names of the
            decorated method. Only non-None tensor arguments are tracked.
        self_attrs: Attribute names on ``self`` that are tensor parameters to be
            replaced via pointer replacement (e.g. ``["weight"]``). When non-empty,
            enables cross-instance graph sharing. Attributes that are absent or
            None on a given instance are keyed as ``(name, None)`` and skipped
            during pointer replacement.
        key_fn: Optional callable to generate the cache key from method arguments.
            Signature: key_fn(arg0, arg1, ...) with args in declaration order
            (excluding self). Defaults to a key based on tensor shapes/dtypes.

    Raises:
        ValueError: At decoration time, if any name in ``inputs`` is not a
            parameter of the decorated method.

    Example:
        class MyNorm(nn.Layer):
            @block_wise_cuda_graph_wrap(
                inputs=["x", "residual"],
                self_attrs=["weight"],  # all layers share one graph
            )
            def forward(self, x, residual=None):
                return rms_norm(x, self.weight), residual
    """

    def decorator(method: Callable) -> Callable:
        sig = inspect.signature(method)
        params = list(sig.parameters.keys())  # ["self", "x", "residual_input", ...]

        # Validate declared inputs eagerly so misuse fails at import time,
        # not on the first forward pass.
        for name in inputs:
            if name not in params or name == "self":
                raise ValueError(
                    f"cuda_graph_wrap: input '{name}' is not a parameter of "
                    f"{method.__qualname__}. Available: {[p for p in params if p != 'self']}"
                )

        # ---- Pre-compute at decoration time (runs once) ----

        _EMPTY = inspect.Parameter.empty
        _Tensor = paddle.Tensor

        # For each non-self param: (name, args_index, default_value)
        # args_index is position in *args (0-based, since self is consumed by Python)
        _param_info = tuple((p, i - 1, sig.parameters[p].default) for i, p in enumerate(params) if p != "self")

        # For each declared input tensor: (name, args_index)
        _input_info = tuple((name, params.index(name) - 1) for name in inputs)

        _self_attr_names = tuple(self_attrs)
        _shared = len(_self_attr_names) > 0

        _use_custom_key = key_fn is not None

        # --- Cache storage ---
        # When self_attrs is provided: closure-level (shared across all instances)
        # When not: per-instance (stored in self.__dict__)
        if _shared:
            _shared_graphs = {}
            _shared_cinputs = {}
            _shared_coutputs = {}  # stores actual result tensors (reused across replays)
            _ALL_SHARED_CACHES.append((_shared_graphs, _shared_cinputs, _shared_coutputs))

        # Per-instance attribute key names
        _g = f"_cg_{method.__name__}_g"
        _ci = f"_cg_{method.__name__}_ci"
        _co = f"_cg_{method.__name__}_co"

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if not fastdeploy.envs.FD_USE_BLOCK_WISE_CUDA_GRAPH:
                return method(self, *args, **kwargs)

            nargs = len(args)

            # Skip CUDA graph if any input tensor has a 0 in its shape
            for a in args:
                if isinstance(a, _Tensor) and 0 in a.shape:
                    return method(self, *args, **kwargs)
            for v in kwargs.values():
                if isinstance(v, _Tensor) and 0 in v.shape:
                    return method(self, *args, **kwargs)

            # === Key generation: inline, no sig.bind ===
            if _use_custom_key:
                # Resolve all args for custom key_fn
                resolved = []
                for pname, aidx, default in _param_info:
                    if pname in kwargs:
                        resolved.append(kwargs[pname])
                    elif aidx < nargs:
                        resolved.append(args[aidx])
                    elif default is not _EMPTY:
                        resolved.append(default)
                    else:
                        resolved.append(None)
                key = key_fn(*resolved)
            else:
                # Default: fast inline key from shapes/dtypes.
                # NOTE(review): non-tensor, non-None, non-callable values (e.g.
                # plain ints) are deliberately excluded from the key — confirm
                # no decorated method varies behavior on such an argument.
                _kp = []
                for pname, aidx, default in _param_info:
                    if pname in kwargs:
                        v = kwargs[pname]
                    elif aidx < nargs:
                        v = args[aidx]
                    else:
                        v = default
                    if isinstance(v, _Tensor):
                        _kp.append((tuple(v.shape), v.dtype))
                    elif v is None:
                        _kp.append(None)
                    elif callable(v):
                        _kp.append(True)
                # Include self_attrs shapes/dtypes in key
                for attr_name in _self_attr_names:
                    attr = getattr(self, attr_name, None)
                    if attr is not None and isinstance(attr, _Tensor):
                        _kp.append((attr_name, tuple(attr.shape), attr.dtype))
                    else:
                        _kp.append((attr_name, None))
                key = tuple(_kp)

            # === Get cache (shared or per-instance) ===
            if _shared:
                graphs = _shared_graphs
                cinputs = _shared_cinputs
                coutputs = _shared_coutputs
            else:
                _d = self.__dict__
                try:
                    graphs = _d[_g]
                    cinputs = _d[_ci]
                    coutputs = _d[_co]
                except KeyError:
                    graphs = {}
                    cinputs = {}
                    coutputs = {}
                    _d[_g] = graphs
                    _d[_ci] = cinputs
                    _d[_co] = coutputs

            if key not in graphs:
                # === First encounter: only capture during capture phase ===
                if not _BLOCK_WISE_CAPTURING:
                    # Not in capture phase -- fall back to eager
                    return method(self, *args, **kwargs)

                # === Capture ===
                graph = paddle.device.cuda.graphs.CUDAGraph(enable_replace=True)

                # Record baseline data pointers for later replacement.
                ci = {}
                for name, aidx in _input_info:
                    v = kwargs[name] if name in kwargs else (args[aidx] if aidx < nargs else None)
                    if v is not None and isinstance(v, _Tensor):
                        ci[name] = v.data_ptr()

                # Record self_attrs pointers for cross-instance replacement
                for attr_name in _self_attr_names:
                    attr = getattr(self, attr_name, None)
                    if attr is not None and isinstance(attr, _Tensor):
                        ci[f"__attr_{attr_name}"] = attr.data_ptr()

                graph.capture_begin()
                result = method(self, *args, **kwargs)
                graph.capture_end()

                # Replay once so the reusable output tensors hold real data.
                graph.replay()

                # Publish to the caches ONLY after capture + first replay
                # succeeded. If capture_end()/replay() raises (e.g. CUDA OOM
                # mid-capture), nothing is cached and later calls for this key
                # fall back to eager instead of replaying a broken graph.
                graphs[key] = graph
                cinputs[key] = ci
                # Store the actual result for reuse. The graph always writes to
                # the same output memory, so we return the same tensors on replay.
                coutputs[key] = result
                return result
            else:
                # === Replay path (HOT PATH) ===
                old_ptrs = []
                new_ptrs = []
                ci = cinputs[key]

                for name, aidx in _input_info:
                    v = kwargs[name] if name in kwargs else (args[aidx] if aidx < nargs else None)
                    if v is not None and name in ci:
                        old_ptrs.append(ci[name])
                        new_ptr = v.data_ptr()
                        new_ptrs.append(new_ptr)
                        ci[name] = new_ptr

                # Replace self_attrs pointers (e.g. weight)
                for attr_name in _self_attr_names:
                    attr_key = f"__attr_{attr_name}"
                    if attr_key in ci:
                        attr = getattr(self, attr_name, None)
                        if attr is not None:
                            old_ptrs.append(ci[attr_key])
                            new_ptr = attr.data_ptr()
                            new_ptrs.append(new_ptr)
                            ci[attr_key] = new_ptr

                if old_ptrs:
                    graphs[key].replace_input_ptrs(old_ptrs, new_ptrs)
                graphs[key].replay()

                # Reuse the output tensors from capture — graph wrote fresh
                # data to the same memory, no allocation needed.
                return coutputs[key]

        return wrapper

    return decorator
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
decode_alltoall_transpose,
tensor_model_parallel_all_reduce,
)
from fastdeploy.model_executor.graph_optimization.cuda_graph_op import (
block_wise_cuda_graph_wrap,
)
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.utils import (
default_weight_loader,
Expand Down Expand Up @@ -245,6 +248,7 @@ def load_state_dict(self, state_dict: dict):
bias_tensor = paddle.to_tensor(get_tensor(state_dict.pop(self.bias_key)))
self.bias.set_value(bias_tensor)

@block_wise_cuda_graph_wrap(inputs=["x"], self_attrs=["weight", "weight_scale_inv", "bias"])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 self_attrs 包含 weight_scale_inv,但该属性仅在量化场景下存在(由 QuantMethodBase.create_weights 创建)。

在非量化模式下,getattr(self, 'weight_scale_inv', None) 会返回 None,装饰器会将 (attr_name, None) 加入 key。这本身不会导致错误,但:

  1. 会导致量化/非量化 Linear 层无法共享同一个 captured graph(即使 shape 相同)
  2. 语义上 self_attrs 应该只包含确实存在的属性

建议考虑将量化层使用单独的装饰器配置,或在运行时动态过滤不存在的属性。

def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
"""
Forward function for Linear.
Expand Down
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from paddle import nn

from fastdeploy.model_executor.forward_meta import ForwardMeta
from fastdeploy.model_executor.graph_optimization.cuda_graph_op import (
block_wise_cuda_graph_wrap,
)
from fastdeploy.platforms import current_platform

if current_platform.is_gcu():
Expand Down Expand Up @@ -203,6 +206,7 @@ def allgather(self, out, token_num):
paddle.distributed.all_gather(multi_outs, out, self.tp_group)
return multi_outs[:token_num, :]

@block_wise_cuda_graph_wrap(inputs=["x", "residual_input"], self_attrs=["weight"])
def forward(
self,
x,
Expand Down
Loading
Loading