Export torch_scaled_dot_product_attention #1404

Merged
dfalbel merged 4 commits into main from copilot/export-sdpa-function on Jan 27, 2026
Conversation

Contributor

Copilot AI commented Jan 26, 2026

torch_scaled_dot_product_attention was implemented but not exported, forcing users to access it via get(..., envir = asNamespace("torch")). This prevented straightforward use of the fused CUDA kernels, which provide a 2-3x speedup over manual attention computation.

Changes

  • R/gen-namespace-docs.R: Added roxygen documentation with @export tag covering all 8 parameters, mathematical formula, and examples for basic usage, causal masking, and custom attention masks
  • NAMESPACE: Added export declaration
  • tests/testthat/test-gen-namespace.R: Added tests for basic usage, causal masking, attention masks, and dropout (a sketch of such a test follows below)
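
For illustration, here is a minimal sketch of a test in the same spirit as the ones added (names and exact contents are assumptions, not the merged code), checking the fused result against a manual softmax(QK^T / sqrt(d_k)) V reference:

library(torch)
library(testthat)

test_that("torch_scaled_dot_product_attention matches manual attention", {
  q <- torch_randn(2, 4, 6, 8)  # (batch, heads, seq_len, head_dim)
  k <- torch_randn(2, 4, 6, 8)
  v <- torch_randn(2, 4, 6, 8)

  out <- torch_scaled_dot_product_attention(q, k, v)

  # reference: softmax(Q K^T / sqrt(d_k)) V computed with plain tensor ops
  scores <- torch_matmul(q, k$transpose(-2, -1)) / sqrt(8)
  expected <- torch_matmul(nnf_softmax(scores, dim = -1), v)

  expect_true(torch_allclose(out, expected, atol = 1e-5))
})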

Usage

library(torch)

query <- torch_randn(2, 8, 10, 64)  # (batch, heads, seq_len, dim)
key <- torch_randn(2, 8, 10, 64)
value <- torch_randn(2, 8, 10, 64)

# Now accessible directly
output <- torch_scaled_dot_product_attention(query, key, value)

# Supports causal masking for autoregressive models
output <- torch_scaled_dot_product_attention(query, key, value, is_causal = TRUE)
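
An explicit attention mask is supported as well. A small sketch (assuming the boolean-mask convention of the underlying operator, where TRUE marks positions that may be attended to):

# Custom boolean mask: TRUE positions are attended to, FALSE positions are masked out
mask <- torch_ones(10, 10, dtype = torch_bool())$tril()
output <- torch_scaled_dot_product_attention(query, key, value, attn_mask = mask)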
Original prompt

This section details the original issue you should resolve

<issue_title>Export torch_scaled_dot_product_attention (SDPA)</issue_title>
<issue_description>

Summary

torch_scaled_dot_product_attention exists in the torch namespace but is not exported. Exporting it would provide significant performance benefits for transformer inference.

Current Workaround

# Have to access unexported function
sdpa <- get("torch_scaled_dot_product_attention", envir = asNamespace("torch"))
output <- sdpa(query, key, value, attn_mask = mask, is_causal = FALSE)

Why This Matters

SDPA uses fused CUDA kernels that are 2-3x faster than manual attention:

# Manual attention
scores <- torch_matmul(q, k$transpose(-2, -1)) / sqrt(head_dim)
attn_weights <- nnf_softmax(scores, dim = -1)
output <- torch_matmul(attn_weights, v)

# SDPA (fused kernel, 2.7x faster in benchmarks)
output <- torch_scaled_dot_product_attention(q, k, v)

For transformer models with 30+ layers, this adds up to meaningful speedups.
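
As a rough illustration (not the benchmark behind the 2.7x figure), the two paths could be timed like this, assuming a CUDA device is available:

library(torch)

head_dim <- 64
q <- torch_randn(8, 16, 512, head_dim, device = "cuda")
k <- torch_randn(8, 16, 512, head_dim, device = "cuda")
v <- torch_randn(8, 16, 512, head_dim, device = "cuda")

manual_attention <- function() {
  scores <- torch_matmul(q, k$transpose(-2, -1)) / sqrt(head_dim)
  torch_matmul(nnf_softmax(scores, dim = -1), v)
}
fused_attention <- function() torch_scaled_dot_product_attention(q, k, v)

# cuda_synchronize() ensures all queued kernels finish before the timer stops
system.time({ for (i in 1:100) manual_attention(); cuda_synchronize() })
system.time({ for (i in 1:100) fused_attention(); cuda_synchronize() })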

Proposed Change

Export torch_scaled_dot_product_attention in the package NAMESPACE so users can call it directly:

torch::torch_scaled_dot_product_attention(query, key, value)
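
For context, a minimal sketch of how the exported function could slot into a multi-head self-attention layer (the module and helper names below are hypothetical, not part of the package or this PR):

library(torch)

sdpa_self_attention <- nn_module(
  "sdpa_self_attention",
  initialize = function(embed_dim, n_heads) {
    self$n_heads <- n_heads
    self$head_dim <- embed_dim %/% n_heads
    self$q_proj <- nn_linear(embed_dim, embed_dim)
    self$k_proj <- nn_linear(embed_dim, embed_dim)
    self$v_proj <- nn_linear(embed_dim, embed_dim)
    self$out_proj <- nn_linear(embed_dim, embed_dim)
  },
  forward = function(x) {
    # x: (batch, seq_len, embed_dim)
    b <- x$shape[1]
    s <- x$shape[2]
    split_heads <- function(h) {
      # (batch, seq, embed_dim) -> (batch, heads, seq, head_dim)
      h$view(c(b, s, self$n_heads, self$head_dim))$transpose(2, 3)
    }
    q <- split_heads(self$q_proj(x))
    k <- split_heads(self$k_proj(x))
    v <- split_heads(self$v_proj(x))
    out <- torch_scaled_dot_product_attention(q, k, v, is_causal = TRUE)
    # merge heads: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim)
    out <- out$transpose(2, 3)$reshape(c(b, s, -1))
    self$out_proj(out)
  }
)

layer <- sdpa_self_attention(embed_dim = 64, n_heads = 8)
y <- layer(torch_randn(2, 10, 64))  # output shape: (2, 10, 64)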

Environment

  • torch version: 0.13.0
  • R version: 4.4.x
  • Discovered while porting Chatterbox TTS to native R torch</issue_description>

<agent_instructions>To export, make sure you add some docs and the #' @export tag. Also make sure the docs are not in the file that's autogenerated. Also please add a test!</agent_instructions>

Comments on the Issue (you are @copilot in this section)



Copilot AI and others added 2 commits January 26, 2026 14:02
Co-authored-by: dfalbel <4706822+dfalbel@users.noreply.github.com>
Co-authored-by: dfalbel <4706822+dfalbel@users.noreply.github.com>
Copilot AI changed the title from [WIP] Export torch_scaled_dot_product_attention for performance benefits to Export torch_scaled_dot_product_attention on Jan 26, 2026
Copilot AI requested a review from dfalbel January 26, 2026 14:06
Member

@dfalbel dfalbel left a comment


lg

@dfalbel dfalbel marked this pull request as ready for review January 27, 2026 17:18
@dfalbel dfalbel merged commit 5eb52ab into main Jan 27, 2026
1 of 6 checks passed
@dfalbel dfalbel deleted the copilot/export-sdpa-function branch January 27, 2026 17:19


Development

Successfully merging this pull request may close these issues.

Export torch_scaled_dot_product_attention (SDPA)

2 participants