[PyTorch] torch.compile support for permutation functions #2686

Open

pggPL wants to merge 8 commits into NVIDIA:main from pggPL:moe_torch_compile

Conversation

Collaborator

pggPL commented Feb 17, 2026

Description

This PR adds torch.compile(fullgraph=True) support for MoE permutation operations (moe_permute, moe_unpermute, moe_sort_chunks_by_index) by converting all torch.autograd.Function implementations to PyTorch custom operators using torch.library.custom_op.

Note that this PR does not add torch.compile support for QuantizedTensor inputs.
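For context, the conversion follows the standard custom-op pattern shown below. This is a minimal sketch only: the demo_moe namespace, op name, and permutation logic are illustrative placeholders, not the PR's actual te_moe:: ops.

# Minimal sketch of the custom-op pattern (illustrative demo_moe namespace,
# not the PR's actual te_moe:: ops). A custom op plus a fake implementation
# lets dynamo trace the call with fullgraph=True instead of graph-breaking
# on an opaque autograd.Function.
import torch

@torch.library.custom_op("demo_moe::permute", mutates_args=())
def demo_permute(inp: torch.Tensor, row_map: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real permutation kernel.
    return inp.index_select(0, row_map)

@demo_permute.register_fake
def _(inp: torch.Tensor, row_map: torch.Tensor) -> torch.Tensor:
    # Shape/dtype-only version used while torch.compile traces the graph.
    return inp.new_empty((row_map.shape[0], inp.shape[1]))

@torch.compile(fullgraph=True)
def run(x, row_map):
    return torch.ops.demo_moe.permute(x, row_map)

In eager mode the same torch.ops.demo_moe.permute call dispatches straight to the real implementation, so one code path serves both compiled and uncompiled use.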

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL force-pushed the moe_torch_compile branch from 41e22ef to 8159d26 on February 18, 2026 17:31
pre-commit-ci bot and others added 4 commits February 18, 2026 17:32
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL marked this pull request as ready for review February 19, 2026 15:45
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

Contributor

greptile-apps bot commented Feb 19, 2026

Greptile Summary

Converted MoE permutation operations from torch.autograd.Function to torch.library.custom_op to enable torch.compile(fullgraph=True) support. The refactor maintains backward compatibility while adding compile support for moe_permute, moe_unpermute, and moe_sort_chunks_by_index.

Key changes:

  • Replaced 6 autograd.Function classes with custom ops registered under the te_moe:: namespace
  • Added fake implementations for shape inference during compilation tracing
  • Implemented proper autograd registration with setup_context and backward wrappers (see the sketch after this summary)
  • Added _quantized_tensor_passthrough_ops to prevent unwrapping FP8 tensors in __torch_dispatch__
  • Configured torch._dynamo.config.reorderable_logging_functions to allow warnings without graph breaks
  • Added test coverage via a use_torch_compile parameter (limited to a single config to keep test time down)

Note: FP8 QuantizedTensor inputs are explicitly not supported under torch.compile (a runtime error is raised), as stated in the PR description.
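A rough sketch of the autograd and logging pieces mentioned above, continuing the illustrative demo_moe::permute op from the description. This is not the PR's registration code; the backward assumes row_map is a plain permutation of the rows.

# Continuing the demo_moe::permute sketch above (illustrative, not the PR's code).
import warnings
import torch
import torch._dynamo

def _setup_context(ctx, inputs, output):
    # Runs during the forward pass; stashes what backward needs.
    _, row_map = inputs
    ctx.save_for_backward(row_map)

def _backward(ctx, grad_out):
    (row_map,) = ctx.saved_tensors
    # Assuming row_map is a full permutation of the rows: scatter each output
    # gradient row back to its source position. The index argument gets None.
    grad_in = torch.zeros_like(grad_out)
    grad_in.index_add_(0, row_map, grad_out)
    return grad_in, None

demo_permute.register_autograd(_backward, setup_context=_setup_context)

# Let warnings.warn inside the ops be reordered instead of graph-breaking.
torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)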

Confidence Score: 4/5

  • Safe to merge with minor consideration for global state thread-safety in highly concurrent scenarios
  • The implementation correctly converts autograd functions to custom ops with proper fake implementations and autograd registration. Tests provide good coverage. One minor concern: global workspace state could theoretically cause issues in concurrent compilation scenarios, though this is likely acceptable for the current use case.
  • No files require special attention - the implementation is well-structured and tested

Important Files Changed

  • transformer_engine/pytorch/permutation.py: Converted the torch.autograd.Function implementations to torch.library.custom_op for torch.compile support; added fake implementations and proper autograd registration; includes global state management that may have thread-safety implications
  • tests/pytorch/test_permutation.py: Added torch.compile test coverage via a use_torch_compile parameter (limited to a single configuration to keep test time down); includes proper dynamo reset and functorch config (pattern sketched below)
  • transformer_engine/pytorch/quantized_tensor.py: Added a passthrough mechanism so custom ops can handle quantized tensors without unwrapping them in __torch_dispatch__
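The compile-enabled test path follows the usual eager-vs-compiled parametrization pattern. The sketch below is illustrative only: the function under test and its arguments are placeholders, not the TE permutation API or the PR's actual test.

# Illustrative compile-vs-eager parametrization (not the PR's actual test).
import pytest
import torch
import torch._dynamo

def _permute_rows(x, row_map):
    # Placeholder for the op under test.
    return x.index_select(0, row_map)

@pytest.mark.parametrize("use_torch_compile", [False, True])
def test_permute_rows(use_torch_compile):
    torch._dynamo.reset()  # clear compile caches between parametrizations
    fn = torch.compile(_permute_rows, fullgraph=True) if use_torch_compile else _permute_rows
    x = torch.randn(8, 4, requires_grad=True)
    row_map = torch.randperm(8)
    out = fn(x, row_map)
    out.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape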

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User calls moe_permute/moe_unpermute] --> B{torch.compile?}
    B -->|No| C[Direct custom op call]
    B -->|Yes| D[torch.ops.te_moe.* custom ops]
    
    D --> E[Forward Op]
    E --> F[register_fake for shape inference]
    E --> G[Actual implementation]
    
    G --> H{FP8 QuantizedTensor?}
    H -->|Yes| I[Passthrough in __torch_dispatch__]
    H -->|No| J[Normal tensor processing]
    
    I --> K[Handle FP8 internally]
    J --> K
    
    K --> L[setup_context saves state]
    L --> M[Backward Op via register_autograd]
    
    M --> N[Custom backward implementation]
    
    subgraph "Custom Ops"
        E
        F
        G
        L
        M
        N
    end
    
    subgraph "QuantizedTensor Handling"
        H
        I
        K
    end

Last reviewed commit: 981beb4

Contributor

greptile-apps bot left a comment


3 files reviewed, no comments


pggPL and others added 2 commits February 19, 2026 15:57
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Collaborator Author

pggPL commented Feb 19, 2026

/te-ci pytorch

Contributor

greptile-apps bot left a comment


3 files reviewed, 1 comment


Comment on lines +33 to +34
_moe_permute_index_map_workspace = None
_moe_permute_index_map_max_expanded_token_num = 0
Contributor

global mutable state may cause issues with concurrent execution or re-compilation

The workspace variables are shared module-level state that gets mutated. In torch.compile with multiple threads or parallel compilation, this could lead to race conditions. Consider thread-local storage or passing these as function arguments if possible.
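One possible shape of that suggestion, purely as a sketch: cache the growable workspace in thread-local storage instead of module-level globals. The names and dtype below are hypothetical, not the PR's code.

# Hypothetical thread-local workspace cache (sketch of the reviewer's suggestion).
import threading
import torch

_workspace_tls = threading.local()

def _get_index_map_workspace(expanded_token_num: int, device: torch.device) -> torch.Tensor:
    # Each thread owns its own cached buffer, so concurrent compilation or
    # execution cannot race on a shared module-level variable.
    buf = getattr(_workspace_tls, "buffer", None)
    if buf is None or buf.numel() < expanded_token_num or buf.device != device:
        buf = torch.empty(expanded_token_num, dtype=torch.int32, device=device)
        _workspace_tls.buffer = buf
    return buf[:expanded_token_num]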
