Feat/backward mul mat #22704

Draft
srossitto79 wants to merge 10 commits into ggml-org:master from srossitto79:feat/backward-mul-mat-id
Conversation

@srossitto79

Overview

This PR adds gradient computation through MUL_MAT_ID (MoE expert-dispatched matrix multiply), enabling
fine-tuning of Mixture-of-Experts models. It also extends the AdamW optimizer with per-element gradient
clipping and fixes several correctness issues in the backward graph builder.

  • Added a new op, `GGML_OP_OUT_PROD_ID`
  • Backward pass for MUL_MAT_ID
  • Element-wise gradient clipping (gclip) in AdamW
  • Backward graph builder fixes: GET_ROWS_BACK extended to 3D output, the missing SIGMOID backward added, the GGML_ASSERT in ggml_build_backward_expand replaced with a skip, and ignore_src rules added so quantized (frozen) expert weights in MUL_MAT_ID are not differentiated

Additional information

This is the foundational PR for MoE training. A follow-up PR will add the full QLoRA training example
on top of these primitives.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure: YES, AI assisted with formatting, diff organization, and applying consistent changes across backends

srossitto79 added 10 commits May 4, 2026 09:15
Adds a new ggml op GGML_OP_OUT_PROD_ID — scattered outer-product for
the MUL_MAT_ID backward pass (MoE LoRA training).

  result[:,:,e] += sum_{(i,t): ids[i,t]==e} a[:,i,t] ⊗ b[:,i,t]

This computes the gradient w.r.t. expert weight matrices (src0) of
MUL_MAT_ID, accumulating per-expert gradients from dispatched tokens.

- ggml.h: enum GGML_OP_OUT_PROD_ID, ggml_out_prod_id() declaration
- ggml.c: op name/symbol, GGML_OP_COUNT 96→97, ggml_out_prod_id() impl
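The accumulation rule above can be written as a naive standalone C reference. This is a hypothetical sketch; the tensor layout and argument names are my assumptions, not the ggml implementation:

```c
#include <stddef.h>

// Hypothetical reference for OUT_PROD_ID semantics:
//   result[:,:,e] += sum over (i,t) with ids[i,t]==e of a[:,i,t] (outer) b[:,i,t]
// Assumed layouts (mine): a is [k x n_used x n_tokens], b is [m x n_used x n_tokens],
// result is [k x m x n_experts] with result[e*k*m + c*k + r], ids is [n_used x n_tokens].
// result must be zero-initialized (or hold prior gradients) before the call.
static void out_prod_id_ref(float *result, const float *a, const float *b,
                            const int *ids, int k, int m,
                            int n_used, int n_tokens, int n_experts) {
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < n_used; ++i) {
            const int e = ids[t*n_used + i];
            if (e < 0 || e >= n_experts) continue;        // ignore invalid expert ids
            const float *av  = a + (size_t)(t*n_used + i)*k;
            const float *bv  = b + (size_t)(t*n_used + i)*m;
            float       *dst = result + (size_t)e*k*m;    // result[:,:,e]
            for (int c = 0; c < m; ++c)
                for (int r = 0; r < k; ++r)
                    dst[c*k + r] += av[r] * bv[c];        // per-expert outer-product accumulation
        }
    }
}
```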

Enables gradient computation through MUL_MAT_ID (expert-dispatched
matrix multiply used in MoE LoRA training).

Backward w.r.t. activations (src1): MUL_MAT_ID with transposed experts.
Backward w.r.t. expert weights (src0, F32 only): OUT_PROD_ID scattered
outer-product accumulates per-expert gradients from dispatched tokens.
Quantized src0 is treated as frozen (stop-gradient).

Also:
- ggml_get_rows_back: support 3D output (needed for 3D grad accumulation)
- sigmoid backward: d/dx = sigmoid(x)*(1-sigmoid(x)) = tensor - tensor^2
- ggml_build_backward_expand: ignore integer/frozen srcs for MUL_MAT_ID,
  SET_ROWS, SSM_CONV, SSM_SCAN, FLASH_ATTN_EXT; replace ASSERT on
  unsupported inplace ops with a warning+skip to avoid crashes

Adds ggml_cuda_out_prod_id() — the CUDA kernel for OUT_PROD_ID, which
computes scattered outer-products for the MUL_MAT_ID backward pass
(gradient w.r.t. MoE expert weight matrices).

For each expert e the kernel gathers the dispatched token vectors into
contiguous GPU buffers and calls cublasSgemm (beta=1 accumulation) to
compute grad_weight[:,:,e] += sum_t a[:,t] ⊗ b[:,t].

Also extends ggml_cuda_out_prod() to accept quantized src0 by
dequantizing to a temporary F32 buffer before the Sgemm call, allowing
it to serve the transposed-expert gradient path of the MUL_MAT_ID
backward in setups that mix quantized and F32 LoRA adapters.
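The gather-then-GEMM strategy can be sketched on the CPU as follows. This is a hypothetical standalone C version, simplified to one expert per token, with a plain triple loop standing in for the cublasSgemm beta=1 accumulation:

```c
#include <stdlib.h>
#include <string.h>

// Hypothetical CPU sketch of the per-expert gather strategy: for each expert e,
// copy the dispatched token vectors into contiguous buffers, then perform one
// dense rank-n_e update into grad_w[:,:,e]. The inner triple loop stands in for
// a single cublasSgemm call with beta=1 (accumulate). Layouts are my assumptions:
// a is [k x n_tokens], b is [m x n_tokens], grad_w is [k x m x n_experts].
static void out_prod_id_gather(float *grad_w, const float *a, const float *b,
                               const int *ids, int k, int m,
                               int n_tokens, int n_experts) {
    float *ga = malloc((size_t)k*n_tokens*sizeof(float));   // gathered a-columns
    float *gb = malloc((size_t)m*n_tokens*sizeof(float));   // gathered b-columns
    for (int e = 0; e < n_experts; ++e) {
        int n_e = 0;
        for (int t = 0; t < n_tokens; ++t) {
            if (ids[t] != e) continue;
            memcpy(ga + (size_t)n_e*k, a + (size_t)t*k, k*sizeof(float));
            memcpy(gb + (size_t)n_e*m, b + (size_t)t*m, m*sizeof(float));
            ++n_e;
        }
        if (n_e == 0) continue;                    // expert received no tokens
        float *dst = grad_w + (size_t)e*k*m;       // grad_w[:,:,e]
        for (int s = 0; s < n_e; ++s)              // dst += ga_s (outer) gb_s
            for (int c = 0; c < m; ++c)
                for (int r = 0; r < k; ++r)
                    dst[c*k + r] += ga[(size_t)s*k + r] * gb[(size_t)s*m + c];
    }
    free(ga);
    free(gb);
}
```

Gathering first turns many tiny rank-1 updates into one large GEMM per expert, which is the usual way to keep the GPU busy.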

Extends the AdamW optimizer with an element-wise gradient clipping
parameter (gclip, pars[7]).  When gclip > 0 each gradient element is
clamped to [-gclip, gclip] before the first/second moment update,
preventing outlier gradients from corrupting the momentum state.

Set gclip = 0 (default) to disable — existing behavior is preserved.

Also adds two training-infrastructure improvements to ggml-opt:
- Moment tensors (m, v) are now allocated on the same backend buffer
  type as their param tensor, avoiding cross-device mismatches when LoRA
  adapters span CPU and GPU (partial offload scenarios).
- Non-static graphs: gradient accumulators are explicitly zeroed at the
  start of each accumulation cycle, fixing stale-gradient carry-over
  that occurred because ggml_graph_reset is not called between evals.

And gradient checkpointing support: ggml_opt_params.grad_checkpoint_interval
marks every Nth forward node as OUTPUT so the allocator keeps activations
alive through the backward pass, trading compute for VRAM reduction.

Reads pars[7] (gclip) from the AdamW parameter tensor and clamps each
gradient element to [-gclip, gclip] before the moment update.
gclip == 0 disables clipping (preserves existing behavior).

Reads the gclip parameter (index 7) from the AdamW parameter buffer in
the GLSL shader and clamps each gradient element to [-gclip, gclip]
before the moment update.  Also extends the params array declaration
from [7] to [8] and updates the matching assert in ggml-vulkan.cpp.

gclip == 0 disables clipping (preserves existing behavior).
@ggml-gh-bot

ggml-gh-bot Bot commented May 5, 2026

Hi @srossitto79, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 2 open PRs.

  • Multiple backend changes in one PR: When adding support for a new model or feature, focus on CPU support only in the initial PR. Add support for other backends like CUDA in follow-up PRs. If you have a good reason to modify multiple backends in one PR, please explain it.

  • AI-generated content: This project does not accept PRs, descriptions or commit messages that are fully or predominantly AI-generated. If you have used AI to assist you in writing code, please make sure to disclose that explicitly.


Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@github-actions Bot added the Nvidia GPU, Vulkan, ggml, and Apple Metal labels May 5, 2026