
Add cuda compatibility check for using grouped_mm #45001

Open
Sai-Suraj-27 wants to merge 5 commits into huggingface:main from Sai-Suraj-27:add_cc_check

Conversation

Contributor

@Sai-Suraj-27 commented Mar 25, 2026

What does this PR do?

For torch>=2.10.0, the minimum CUDA compute capability requirement for torch.nn.functional.grouped_mm is 8.0.
For torch==2.8.0, the minimum CUDA compute capability requirement for torch._grouped_mm() is 9.0.
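The two requirements above can be modeled as a small pure-Python helper (an illustrative sketch only; the function name `supports_grouped_mm` is hypothetical, and a real check would consult `torch.__version__` and `torch.cuda.get_device_capability` instead of taking tuples):

```python
def supports_grouped_mm(torch_version: tuple, capability: tuple) -> bool:
    """Model of the compute-capability rules above (illustrative only).

    torch >= 2.10.0 (torch.nn.functional.grouped_mm) needs SM >= 8.0;
    older builds with only torch._grouped_mm (e.g. 2.8.0) need SM >= 9.0.
    """
    if torch_version >= (2, 10):
        return capability >= (8, 0)
    return capability >= (9, 0)

# A100 is SM 8.0, H100 is SM 9.0:
print(supports_grouped_mm((2, 10, 0), (8, 0)))  # True
print(supports_grouped_mm((2, 8, 0), (8, 0)))   # False
print(supports_grouped_mm((2, 8, 0), (9, 0)))   # True
```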

Code Agent Policy

The Transformers repo is currently being overwhelmed by a large number of PRs and issue comments written by
code agents. We are currently bottlenecked by our ability to review and respond to them. As a result,
we ask that new users do not submit pure code agent PRs at this time.
You may use code agents in drafting or to help you diagnose issues. We'd also ask autonomous "OpenClaw"-like agents
not to open any PRs or issues for the moment.

PRs that appear to be fully agent-written will probably be closed without review, and we may block users who do this
repeatedly or maliciously.

This is a rapidly-evolving situation that's causing significant shockwaves in the open-source community. As a result,
this policy is likely to be updated regularly in the near future. For more information, please read CONTRIBUTING.md.

  • I confirm that this is not a pure code agent PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@vasqu

Contributor

vasqu commented Mar 25, 2026

cc @IlyasMoutawwakil: something we discovered elsewhere, but it definitely makes sense to add these imo

will properly check it out tomorrow

@IlyasMoutawwakil
Member

hi! thanks for investigating this! if I understand correctly, these are not hard constraints (i.e. not breaking); they are just the conditions for the optimised triton/cutedsl paths, right? otherwise torch just uses the fallback path, no? or does it actually fail?

Contributor

vasqu commented Mar 25, 2026

It can actually fail on lower torch versions, e.g. iirc 2.8: while the op is available there, some SM compute capabilities won't be able to use it, hence the extra guarding here.

We discovered those during some qwen moe tests, where @Sai-Suraj-27 ran on torch 2.8, see #44848

@IlyasMoutawwakil
Member

i see, yeah, originally we used to just raise an error if grouped_mm was requested and the torch version was less than 2.9.
With the manual fallback, that condition became obsolete, because the fallback is still a better alternative than eager.
I would just suggest we make the guards more explicit in this case, because the current one:

        if hasattr(torch, "_grouped_mm"):
            return torch.cuda.get_device_capability(weight.device) >= (9, 0)

will trigger the manual fallback on torch 2.9 + A100, which is slower than torch._grouped_mm in the same setting.
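The distinction raised here can be sketched in pure Python (an assumption-laden illustration, not the PR's actual code: `has_grouped_mm` stands in for `hasattr(torch, "_grouped_mm")`, and the 2.9 cutoff is taken from this thread):

```python
def hasattr_guard(has_grouped_mm: bool, capability: tuple) -> bool:
    # Capability-only check: any build exposing _grouped_mm
    # is required to have SM >= 9.0, regardless of torch version.
    if has_grouped_mm:
        return capability >= (9, 0)
    return False

def version_aware_guard(torch_version: tuple, capability: tuple) -> bool:
    # Explicit check: torch >= 2.9 is assumed (per this thread)
    # to relax the requirement to SM 8.0.
    min_cap = (8, 0) if torch_version >= (2, 9) else (9, 0)
    return capability >= min_cap

# torch 2.9 on an A100 (SM 8.0): the capability-only guard pushes the
# slower manual fallback, while the explicit guard keeps torch._grouped_mm.
print(hasattr_guard(True, (8, 0)))          # False
print(version_aware_guard((2, 9), (8, 0)))  # True
```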

@Sai-Suraj-27
Contributor Author

> i see, yeah originally we used to just raise an error if grouped_mm is requested and the version is less than 2.9 with the manual fallback that condition became obsolete because it is still a better alternative than eager. i would just suggest we make the guards more explicit in this case because the current one :
>
>         if hasattr(torch, "_grouped_mm"):
>             return torch.cuda.get_device_capability(weight.device) >= (9, 0)
>
> will trigger the manual fallback on torch 2.9 + A100 which is slower than torch._grouped_mm in the same setting.

Thanks for the review @IlyasMoutawwakil @vasqu. Made it more explicit now.

Contributor

@vasqu left a comment


LGTM, but let's wait on @IlyasMoutawwakil to confirm that it's what he meant

# issue: https://github.com/pytorch/pytorch/issues/172440
return False

if weight.device.type == "cuda":
Contributor


I think we should add a small comment here for clarification

Contributor Author


Done!
