Feat/backward mul mat #22704
Conversation
Adds a new ggml op GGML_OP_OUT_PROD_ID — scattered outer-product for
the MUL_MAT_ID backward pass (MoE LoRA training).
result[:,:,e] += sum_{(i,t): ids[i,t]==e} a[:,i,t] ⊗ b[:,i,t]
This computes the gradient w.r.t. expert weight matrices (src0) of
MUL_MAT_ID, accumulating per-expert gradients from dispatched tokens.
- ggml.h: enum GGML_OP_OUT_PROD_ID, ggml_out_prod_id() declaration
- ggml.c: op name/symbol, GGML_OP_COUNT 96→97, ggml_out_prod_id() impl
Enables gradient computation through MUL_MAT_ID (the expert-dispatched matrix multiply used in MoE LoRA training). Backward w.r.t. activations (src1): MUL_MAT_ID with transposed experts. Backward w.r.t. expert weights (src0, F32 only): the OUT_PROD_ID scattered outer-product accumulates per-expert gradients from dispatched tokens. Quantized src0 is treated as frozen (stop-gradient).
Also:
- ggml_get_rows_back: support 3D output (needed for 3D gradient accumulation)
- sigmoid backward: d/dx sigmoid(x) = sigmoid(x)*(1-sigmoid(x)) = tensor - tensor^2
- ggml_build_backward_expand: ignore integer/frozen srcs for MUL_MAT_ID, SET_ROWS, SSM_CONV, SSM_SCAN, FLASH_ATTN_EXT; replace the assert on unsupported inplace ops with a warning+skip to avoid crashes
Adds ggml_cuda_out_prod_id() — the CUDA kernel for OUT_PROD_ID, which computes scattered outer-products for the MUL_MAT_ID backward pass (gradient w.r.t. MoE expert weight matrices). For each expert e, the kernel gathers the dispatched token vectors into contiguous GPU buffers and calls cublasSgemm (beta=1 accumulation) to compute grad_weight[:,:,e] += sum_t a[:,t] ⊗ b[:,t]. Also extends ggml_cuda_out_prod() to accept quantized src0 by dequantizing to a temporary F32 buffer before the Sgemm call, allowing it to serve the transposed-expert gradient path in the MUL_MAT_ID backward for setups that use mixed quantized+F32 LoRA adapters.
Extends the AdamW optimizer with an element-wise gradient clipping parameter (gclip, pars[7]). When gclip > 0, each gradient element is clamped to [-gclip, gclip] before the first/second moment update, preventing outlier gradients from corrupting the momentum state. Set gclip = 0 (default) to disable — existing behavior is preserved.
Also adds two training-infrastructure improvements to ggml-opt:
- Moment tensors (m, v) are now allocated on the same backend buffer type as their param tensor, avoiding cross-device mismatches when LoRA adapters span CPU and GPU (partial-offload scenarios).
- Non-static graphs: gradient accumulators are explicitly zeroed at the start of each accumulation cycle, fixing stale-gradient carry-over that occurred because ggml_graph_reset is not called between evals.
And gradient checkpointing support: ggml_opt_params.grad_checkpoint_interval marks every Nth forward node as OUTPUT so the allocator keeps activations alive through the backward pass, trading compute for VRAM reduction.
Reads pars[7] (gclip) from the AdamW parameter tensor and clamps each gradient element to [-gclip, gclip] before the moment update. gclip == 0 disables clipping (preserves existing behavior).
Reads params[8] (gclip) from the AdamW parameter buffer in the GLSL shader and clamps each gradient element to [-gclip, gclip] before the moment update. Also extends the params array declaration from [7] to [8] and updates the matching assert in ggml-vulkan.cpp. gclip == 0 disables clipping (preserves existing behavior).
…feat/backward-mul-mat-id
Hi @srossitto79, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Overview
This PR adds gradient computation through MUL_MAT_ID (the MoE expert-dispatched matrix multiply), enabling fine-tuning of Mixture-of-Experts models. It also extends the AdamW optimizer with per-element gradient clipping and fixes several correctness issues in the backward graph builder.
Additional information
This is the foundational PR for MoE training. A follow-up PR will add the full QLoRA training example
on top of these primitives.
Requirements