metal : promote mul_mv/mul_mm batch divisors to function constants#22711
Open
guyfischman wants to merge 1 commit into ggml-org:master from
Conversation
Hi @guyfischman, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
Author
For reference, I made (using AI) this PoC to see when a literal vs an FC wins, and by how much, over kernel args/uniform buffers: https://github.com/SovereignSoft/agx-idiv-demo
Overview
The batch-dimension fold i12 = im % args.ne12; i13 = im / args.ne12; and the GQA divisors i12 / args.r2, i13 / args.r3 execute on every dispatch of every mat-vec / mat-mat kernel. On Apple AGX, a runtime integer divmod is a ~700-cycle software path, while a compile-time-known divisor folds to a no-op (ne12 = 1), a shift (power-of-two divisor), or a magic multiply (any other constant).
This PR promotes ne12, r2, r3 (and ne13 for mul_mm) to function constants so they are baked into the PSO at compile time:
FC_mul_mm_ne12/ne13/r2/r3 - bound by mul_mm pipeline from op shape
FC_mul_mv_ne12/r2/r3 - bound by mul_mv and mul_mv_ext pipelines from op shape; mul_mv_id binds to 1 since kernel_mul_mv_id wraps the impls with args0 = { ne12=1, r2=1, r3=1 }
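In MSL terms the promotion looks roughly like the sketch below. The FC identifier names are taken from this description; the function-constant indices and the surrounding kernel body are assumptions for illustration, not quoted from the patch.

```metal
// Sketch only -- simplified, not the actual patch.
// Each divisor is a function constant, so its value is baked into the
// PSO at pipeline-compile time; the divides below then fold to a no-op,
// a shift, or a magic multiply instead of a runtime divmod.
constant int16_t FC_mul_mv_ne12 [[function_constant(0)]]; // index assumed
constant int16_t FC_mul_mv_r2   [[function_constant(1)]]; // index assumed
constant int16_t FC_mul_mv_r3   [[function_constant(2)]]; // index assumed

// Inside a mul_mv kernel, the batch fold becomes:
//   const int i12 = im % FC_mul_mv_ne12;
//   const int i13 = im / FC_mul_mv_ne12;
// and the GQA head mapping uses i12 / FC_mul_mv_r2 and i13 / FC_mul_mv_r3.
```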
The pipeline-cache key now includes ne12/r2/r3 so a Gemma-2 GQA PSO is not reused for a TinyLlama (ne12=1) op or vice versa.
This is the metal analog of #22650.
Additional information
Replaced args.ne12/r2/r3 with the FCs at 23 mul_mv impl sites, in both mul_mv_ext impls, and in both kernel_mul_mm variants (tensor and non-tensor branches).
Other kernels that use args.ne12 (kernel_bin_fuse, kernel_soft_max, kernel_set_rows) are built by other pipelines that do not bind these FCs, so they are unchanged.
The four FC values are bound as int16_t to match the existing convention (nsg, nxpsg). Pipeline-cache call sites assert ne12/ne13/r2/r3 ≤ INT16_MAX before the cast, to guard against silent wraparound on edge cases like very large prompt batches; the guard is a no-op for typical inference shapes.
Correctness
./build/bin/test-backend-ops -b Metal
3/3 backends passed
Also verified bit-identical output with llama-perplexity on all models listed below.
Performance
Apple M4 Pro, llama-bench -p 512 -n 128 -r 20, tg128 in t/s, master → patched:
TinyLlama Q4_0: 239.05 ± 1.73 → 246.57 ± 0.93 (+3.15%)
Llama 3.2 1B Q4_0: 227.58 ± 1.84 → 230.96 ± 2.65 (+1.49%, within noise)
Gemma 3 1B Q4_K_M (GQA): 164.67 ± 4.59 → 173.52 ± 0.75 (+5.37%)
Gemma 2 2B Q4_K_M (GQA): 103.41 ± 0.39 → 106.66 ± 0.24 (+3.14%)
Mistral 7B Q4_0 (GQA): 52.10 ± 0.08 → 52.73 ± 0.07 (+1.21%)
pp512 is unchanged across all five models (within ±1% noise): mul_mv is not on the prompt path, and the mul_mm shapes that hit the new FCs are too few to surface above bench noise.
Requirements