[PyTorch] torch.compile support for permutation functions #2686
pggPL wants to merge 8 commits into NVIDIA:main
Conversation
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Force-pushed from 41e22ef to 8159d26
for more information, see https://pre-commit.ci
/te-ci pytorch
Greptile Summary
Converted MoE permutation operations from `torch.autograd.Function` to PyTorch custom operators (`torch.library.custom_op`) so they work under `torch.compile(fullgraph=True)`. Key changes:
- `moe_permute`, `moe_unpermute`, and `moe_sort_chunks_by_index` now dispatch through `torch.ops.te_moe.*` custom ops
- `register_fake` implementations provide shape inference for compile-time tracing
- backward passes are registered via `register_autograd`, with `setup_context` saving state
Note: FP8 QuantizedTensor inputs are explicitly not supported under torch.compile (a runtime error is raised), as stated in the PR description.
Confidence Score: 4/5
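A guard of this kind might look like the following minimal sketch. The helper name, error message, and the `QuantizedTensor` import path are assumptions for illustration, not the PR's actual code:

```python
import torch
from transformer_engine.pytorch.tensor import QuantizedTensor  # assumed import path

def _check_compile_input(inp: torch.Tensor) -> None:
    # Hypothetical guard: reject FP8 QuantizedTensor inputs up front when
    # tracing under torch.compile, rather than miscompiling silently.
    if torch.compiler.is_compiling() and isinstance(inp, QuantizedTensor):
        raise RuntimeError(
            "QuantizedTensor inputs are not supported under torch.compile"
        )
```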
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User calls moe_permute/moe_unpermute] --> B{torch.compile?}
    B -->|No| C[Direct custom op call]
    B -->|Yes| D[torch.ops.te_moe.* custom ops]
    D --> E[Forward Op]
    E --> F[register_fake for shape inference]
    E --> G[Actual implementation]
    G --> H{FP8 QuantizedTensor?}
    H -->|Yes| I[Passthrough in __torch_dispatch__]
    H -->|No| J[Normal tensor processing]
    I --> K[Handle FP8 internally]
    J --> K
    K --> L[setup_context saves state]
    L --> M[Backward Op via register_autograd]
    M --> N[Custom backward implementation]
    subgraph "Custom Ops"
        E
        F
        G
        L
        M
        N
    end
    subgraph "QuantizedTensor Handling"
        H
        I
        K
    end
```
Last reviewed commit: 981beb4
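The forward/fake/backward structure in the flowchart corresponds to PyTorch's `torch.library` custom-operator API. Here is a minimal sketch of that pattern on a toy permutation op; the op name echoes the `te_moe` namespace above, but the body and helpers are illustrative, not the PR's actual kernels:

```python
import torch

@torch.library.custom_op("te_moe::toy_permute", mutates_args=())
def toy_permute(inp: torch.Tensor, row_id_map: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real kernel; row_id_map is assumed to be a
    # permutation of the input rows.
    return inp.index_select(0, row_id_map)

@toy_permute.register_fake
def _(inp, row_id_map):
    # Shape/dtype inference so torch.compile can trace without running the kernel.
    return inp.new_empty((row_id_map.numel(), inp.shape[1]))

def _setup_context(ctx, inputs, output):
    _, row_id_map = inputs
    ctx.save_for_backward(row_id_map)

def _backward(ctx, grad_out):
    (row_id_map,) = ctx.saved_tensors
    grad_in = torch.zeros_like(grad_out)
    grad_in.index_copy_(0, row_id_map, grad_out)  # scatter gradients back to input rows
    return grad_in, None  # no gradient for the integer row_id_map

toy_permute.register_autograd(_backward, setup_context=_setup_context)
```

The `register_fake` step is what lets the compiler infer output shapes without executing the op; without it, `fullgraph=True` tracing would fail on the opaque custom-op call.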
/te-ci pytorch
_moe_permute_index_map_workspace = None
_moe_permute_index_map_max_expanded_token_num = 0
Global mutable state may cause issues with concurrent execution or re-compilation.
The workspace variables are shared module-level state that gets mutated. Under torch.compile with multiple threads or parallel compilation, this could lead to race conditions. Consider thread-local storage or passing these as function arguments if possible; see the sketch below.
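One way to realize the thread-local suggestion (a hedged sketch: the variable names mirror the diff above, but the wrapper function and its signature are hypothetical):

```python
import threading
import torch

_local = threading.local()

def _get_index_map_workspace(expanded_token_num: int, device) -> torch.Tensor:
    # Each thread lazily allocates and grows its own workspace, avoiding
    # races on the module-level globals flagged in the comment above.
    ws = getattr(_local, "moe_permute_index_map_workspace", None)
    if ws is None or ws.numel() < expanded_token_num:
        ws = torch.empty(expanded_token_num, dtype=torch.int32, device=device)
        _local.moe_permute_index_map_workspace = ws
    return ws
```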
Description
This PR adds `torch.compile(fullgraph=True)` support for MoE permutation operations (`moe_permute`, `moe_unpermute`, `moe_sort_chunks_by_index`) by converting all `torch.autograd.Function` implementations to PyTorch custom operators using `torch.library.custom_op`. Note that this PR does not add torch.compile support for QuantizedTensor as an input.
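With this change, a routing step like the following should compile without graph breaks. This is a sketch, not code from the PR: it assumes `moe_permute`/`moe_unpermute` are importable from `transformer_engine.pytorch`, that `moe_permute` returns the permuted tokens plus a row-id map, and it omits optional arguments:

```python
import torch
import transformer_engine.pytorch as te

@torch.compile(fullgraph=True)
def route(tokens: torch.Tensor, routing_map: torch.Tensor):
    # moe_permute now dispatches through a torch.library custom op,
    # so fullgraph tracing no longer falls back to eager here.
    permuted, row_id_map = te.moe_permute(tokens, routing_map)
    return te.moe_unpermute(permuted, row_id_map)
```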
Type of change
Checklist: