Commit edb8b43
Resource Partitioner -- Cpu memory graph break (#3886)
1 parent a1d932a commit edb8b43

15 files changed: +1449 −4 lines

.github/workflows/build-test-linux-x86_64.yml

Lines changed: 3 additions & 1 deletion
@@ -136,7 +136,7 @@ jobs:
           cd tests/py
           cd dynamo
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
-          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
+          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
           popd

@@ -229,6 +229,8 @@ jobs:
           pushd .
           cd tests/py/dynamo
           python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
+          python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
+
           popd

   L1-dynamo-compile-tests:

.github/workflows/build-test-windows.yml

Lines changed: 2 additions & 1 deletion
@@ -135,7 +135,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_runtime_tests_results.xml runtime/test_000_*
-          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/
+          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_partitioning_tests_results.xml partitioning/test_000_*
           ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l0_dynamo_core_lowering_tests_results.xml lowering/
           popd

@@ -219,6 +219,7 @@ jobs:
           pushd .
           cd tests/py/dynamo
           ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_tests_results.xml runtime/test_001_*
+          ../../../packaging/vc_env_helper.bat python -m pytest -ra -n 8 --junitxml=${RUNNER_TEST_RESULTS_DIR}/l1_dynamo_core_partitioning_tests_results.xml partitioning/test_001_*
           popd

   L1-dynamo-compile-tests:
Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@ (new file)
"""
.. _low_cpu_memory_compilation:

Low CPU Memory Compilation Example
==================================

This example demonstrates compiling a model with a bounded CPU (host) memory
budget using Torch-TensorRT Dynamo. Limiting host RAM use is helpful on
memory-constrained machines or when compiling very large models.

Key notes:
- The toy model below has roughly 430 MB of parameters. We set the CPU
  memory budget to 2 GiB. At compile time, only about 900 MB of host RAM
  may remain available. We expect at most 403 * 4 = 1612 MB of memory to be
  used by the model, so the model is partitioned into two subgraphs to fit
  the memory budget.

- Performance impact varies by model. When the number of TensorRT engines
  created is small, the impact is typically minimal.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.conversion import CompilationSettings


class net(nn.Module):
    def __init__(self):
        super().__init__()
        # Intentionally large layers to stress host memory during compilation.
        self.conv1 = nn.Conv2d(1024, 4096, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(4096)
        self.conv2 = nn.Conv2d(4096, 1024, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.fc1 = nn.Linear(1024 * 56 * 56, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))
        x = torch.flatten(x, 1)
        return self.fc1(x)


model = net().eval()
model.to("cuda")
inputs = [torch.randn((1, 1024, 224, 224)).to("cuda")]

enabled_precisions = {torch.float}
use_python_runtime = False

compilation_options = {
    "use_python_runtime": use_python_runtime,
    "enabled_precisions": enabled_precisions,
    "min_block_size": 1,
    "immutable_weights": True,
    "reuse_cached_engines": False,
    "enable_resource_partitioning": True,
    "cpu_memory_budget": 2 * 1024 * 1024 * 1024,  # 2 GiB in bytes
}

settings = CompilationSettings(**compilation_options)
with torchtrt.dynamo.Debugger(
    log_level="debug",
    logging_dir="/home/profile/logging/moe",
    engine_builder_monitor=False,
):
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        **compilation_options,
    )

# Expect two back-to-back TensorRT engines due to partitioning under the memory budget.
print(trt_gm)


"""
You should be able to see two back-to-back TensorRT engines in the graph:

Graph Structure:

Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
...
TRT Engine #1 - Submodule name: _run_on_acc_0_resource_split_0
    Engine Inputs: List[Tensor: (1, 1024, 224, 224)@float32]
    Number of Operators in Engine: 9
    Engine Outputs: List[Tensor: (1, 1024, 112, 112)@float32]
...
TRT Engine #2 - Submodule name: _run_on_acc_0_resource_split_1
    Engine Inputs: List[Tensor: (1, 1024, 112, 112)@float32]
    Number of Operators in Engine: 3
    Engine Outputs: List[Tensor: (1, 10)@float32]
...
Outputs: List[Tensor: (1, 10)@float32]

------------------------- Aggregate Stats -------------------------

Average Number of Operators per TRT Engine: 6.0
Most Operators in a TRT Engine: 9

********** Recommendations **********

- For minimal graph segmentation, select min_block_size=9 which would generate 1 TRT engine(s)
- For moderate graph segmentation, select min_block_size=6 which would generate 1 TRT engine(s)
- The current level of graph segmentation is equivalent to selecting min_block_size=3 which generates 2 TRT engine(s)

GraphModule(
  (_run_on_acc_0_resource_split_0): TorchTensorRTModule()
  (_run_on_acc_0_resource_split_1): TorchTensorRTModule()
)

def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    _run_on_acc_0_resource_split_0 = self._run_on_acc_0_resource_split_0(x);  x = None
    _run_on_acc_0_resource_split_1 = self._run_on_acc_0_resource_split_1(_run_on_acc_0_resource_split_0);  _run_on_acc_0_resource_split_0 = None
    return pytree.tree_unflatten((_run_on_acc_0_resource_split_1,), self._out_spec)
"""

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 23 additions & 1 deletion
@@ -40,6 +40,9 @@
     post_lowering,
     pre_export_lowering,
 )
+from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
+    resource_partition,
+)
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
     get_cpu_memory_usage,

@@ -105,6 +108,8 @@ def cross_compile_for_windows(
     l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     offload_module_to_cpu: bool = _defaults.OFFLOAD_MODULE_TO_CPU,
     use_distributed_mode_trace: bool = _defaults.USE_DISTRIBUTED_MODE_TRACE,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows

@@ -179,6 +184,8 @@ def cross_compile_for_windows(
         tiling_optimization_level (str): The optimization level of tiling strategies. A higher level allows TensorRT to spend more time searching for better tiling strategy. We currently support ["none", "fast", "moderate", "full"].
         l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         use_distributed_mode_trace (bool): Using aot_autograd to trace the graph. This is enabled when DTensors or distributed tensors are present in distributed model
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT

@@ -334,6 +341,8 @@ def cross_compile_for_windows(
         "tiling_optimization_level": tiling_optimization_level,
         "l2_limit_for_tiling": l2_limit_for_tiling,
         "use_distributed_mode_trace": use_distributed_mode_trace,
+        "enable_resource_partitioning": enable_resource_partitioning,
+        "cpu_memory_budget": cpu_memory_budget,
     }

     # disable the following settings is not supported for cross compilation for windows feature

@@ -448,6 +457,8 @@ def compile(
     autocast_calibration_dataloader: Optional[
         torch.utils.data.DataLoader
     ] = _defaults.AUTOCAST_CALIBRATION_DATALOADER,
+    cpu_memory_budget: Optional[int] = _defaults.CPU_MEMORY_BUDGET,
+    enable_resource_partitioning: bool = _defaults.ENABLE_RESOURCE_PARTITIONING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT

@@ -532,6 +543,8 @@ def compile(
         autocast_max_output_threshold (float): Maximum absolute value for node outputs, nodes with outputs greater than this value will remain in FP32. Default is 512.
         autocast_max_depth_of_reduction (Optional[int]): Maximum depth of reduction allowed in low precision. Nodes with higher reduction depths will remain in FP32. This helps prevent excessive accuracy loss in operations particularly sensitive to reduced precision, as higher-depth reductions may amplify computation errors in low precision formats. If not provided, infinity will be used. Default is None.
         autocast_calibration_dataloader (Optional[torch.utils.data.DataLoader]): The dataloader to use for autocast calibration. Default is None.
+        enable_resource_partitioning (bool): Enable resource-aware partitioning. This is useful when the model is large and the CPU memory is limited.
+        cpu_memory_budget (Optional[int]): The maximum amount of CPU memory to use for the compilation. If the compilation requires more memory than this budget, the compilation will fail.
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT

@@ -732,6 +745,8 @@ def compile(
         "autocast_max_output_threshold": autocast_max_output_threshold,
         "autocast_max_depth_of_reduction": autocast_max_depth_of_reduction,
         "autocast_calibration_dataloader": autocast_calibration_dataloader,
+        "enable_resource_partitioning": enable_resource_partitioning,
+        "cpu_memory_budget": cpu_memory_budget,
     }
     logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)

@@ -905,6 +920,12 @@ def preserve_module_specs(
         require_full_compilation=settings.require_full_compilation,
     )

+    if settings.enable_resource_partitioning:
+        partitioned_module = resource_partition(
+            partitioned_module,
+            cpu_memory_budget=settings.cpu_memory_budget,
+        )
+
     dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

     # The global partitioner leaves non-TRT nodes as-is

@@ -928,6 +949,7 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule

@@ -1390,7 +1412,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )

     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
     )[0]

     try:
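
cpu_memory_budget is specified in bytes (the new example passes 2 * 1024**3). A hedged sketch of deriving a budget from the RAM actually free on the host; psutil is used purely for illustration, and exp_program and inputs are assumed to exist as in the example above:

# Sketch only: pick a CPU memory budget from currently available host RAM,
# keeping some headroom for the rest of the process. psutil is an illustrative
# choice here; the commit itself does not depend on it.
import psutil
import torch_tensorrt as torchtrt

headroom = 1 * 1024**3  # leave 1 GiB free for everything else (arbitrary choice)
budget = max(psutil.virtual_memory().available - headroom, 512 * 1024**2)

trt_gm = torchtrt.dynamo.compile(
    exp_program,  # an ExportedProgram, as in the example above
    inputs=inputs,
    min_block_size=1,
    enable_resource_partitioning=True,
    cpu_memory_budget=budget,  # bytes
)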

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@
 AUTOCAST_MAX_OUTPUT_THRESHOLD = 512
 AUTOCAST_MAX_DEPTH_OF_REDUCTION = None
 AUTOCAST_CALIBRATION_DATALOADER = None
+ENABLE_RESOURCE_PARTITIONING = False
+CPU_MEMORY_BUDGET = None

 if platform.system() == "Linux":
     import pwd
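
The two new defaults are plain module-level constants, so they can be referenced the same way _compiler.py does; a trivial sketch:

# Sketch: the new defaults added above are ordinary constants in _defaults.py,
# read by _compiler.py as _defaults.ENABLE_RESOURCE_PARTITIONING and
# _defaults.CPU_MEMORY_BUDGET.
from torch_tensorrt.dynamo import _defaults

assert _defaults.ENABLE_RESOURCE_PARTITIONING is False
assert _defaults.CPU_MEMORY_BUDGET is None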

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
     AUTOCAST_MAX_DEPTH_OF_REDUCTION,
     AUTOCAST_MAX_OUTPUT_THRESHOLD,
     CACHE_BUILT_ENGINES,
+    CPU_MEMORY_BUDGET,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,

@@ -22,6 +23,7 @@
     ENABLE_AUTOCAST,
     ENABLE_CROSS_COMPILE_FOR_WINDOWS,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENABLE_RESOURCE_PARTITIONING,
     ENABLE_WEIGHT_STREAMING,
     ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,

@@ -168,6 +170,8 @@ class CompilationSettings:
     autocast_calibration_dataloader: Optional[torch.utils.data.DataLoader] = (
         AUTOCAST_CALIBRATION_DATALOADER
     )
+    enable_resource_partitioning: bool = ENABLE_RESOURCE_PARTITIONING
+    cpu_memory_budget: int = CPU_MEMORY_BUDGET

     def __getstate__(self) -> dict[str, Any]:
         from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
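
Like any other field on the dataclass, the new options can also be set when constructing CompilationSettings directly; a minimal sketch using the import path from the example script above:

# Sketch: the new CompilationSettings fields default to
# ENABLE_RESOURCE_PARTITIONING (False) and CPU_MEMORY_BUDGET (None) and can be
# overridden per compilation.
from torch_tensorrt.dynamo.conversion import CompilationSettings

settings = CompilationSettings(
    enable_resource_partitioning=True,
    cpu_memory_budget=2 * 1024**3,  # 2 GiB in bytes
)
print(settings.enable_resource_partitioning, settings.cpu_memory_budget)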

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -230,7 +230,7 @@ def partition_graph(self) -> torch.fx.GraphModule:

         # Tag the accelerated nodes and split the graph accordingly
         self.tag(subgraphs)
-        return self.split()
+        return self.split(remove_tag=True)

     def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
         """Generates starter nodes for partitioning + segmentation"""
