
tools: add llama-pshard-plan-params for token-tiered placement planning #22691

Open
aukarande wants to merge 1 commit into ggml-org:master from aukarande:pshard/planning

Conversation


@aukarande aukarande commented May 4, 2026

Overview

Pipelined sharding (pshard) is a CPU/GPU scheduling approach for VRAM-constrained inference. It combines prioritized VRAM placement, sub-layer sharding, CPU offload, and pipelined copy/compute. It has three stages: profiling (#22495), planning, and inference (#22692). First, we benchmark representative CPU and GPU kernels on the target machine. Then the planner uses those measurements to compare placement strategies within each token tier and writes the chosen schedules to a registry. At inference time, the runtime selects the smallest token tier that can cover the current n_tokens and uses that tier's schedule.
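
As a sketch of the end-to-end flow (the profiler invocations and output paths here are assumptions, not the documented interface; see #22495 for the actual profiler tools):

./build/bin/llama-profiler-cpu                      # stage 1: benchmark CPU kernels
./build/bin/llama-profiler-gpu                      # stage 1: benchmark GPU kernels
./build/bin/llama-pshard-plan-params -m model.gguf  # stage 2: write per-tier plans
# stage 3 (#22692): inference reads model.gguf.tensor_overrides.pshard_registry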

This is meant to work without manual tuning. The runtime can discover an optimal ubatch size, adapt across context processing, decode, and multi-user decode, and choose schedules based on the current VRAM budget, CPU thread count, and PCIe bandwidth, without requiring manual tensor overrides. The full set of TTFT, TPS, and end-to-end results for a variety of models is included in the inference PR (#22692).

This work was proposed in a recent meeting with @ggerganov and @JohannesGaessler, and more details can be found in our paper: Efficient, VRAM-Constrained xLM Inference on Clients.

This PR adds the second phase: planning. Given a model, target VRAM budget, token tiers, and roofline profiles from llama-profiler-{cpu,gpu} (#22495), the planner predicts TPS for each candidate placement and writes the selected plans to a plain-text registry next to the model:

<model>.gguf.tensor_overrides.pshard_registry

The implementation in this PR builds on @JohannesGaessler's llama-fit-params (#16653) and @am17an's tensor-override prefetching work (#21067). This keeps the design vendor-agnostic and considerably simpler than our original from-scratch implementation for the paper.

Additional information

User Interface

New tool:

./build/bin/llama-pshard-plan-params \
    -m /path/to/model.gguf \
    -c 8192 \
    -fa on \
    -mva 12000

Useful options:

  • -mva, --max-vram-alloc <MiB>: planning budget. 0 or omitted uses currently free VRAM.
  • --pshard-tier-max <N>: largest batch-size tier to probe. The default is min(max(n_batch, 16384), n_ctx).
  • PSHARD_STRATEGY: force one strategy (STATIC_FITPARAMS_DENSEPRIO_MOEONLY, STATIC_ATTNPRIO_ALLMODELS, DYNAMIC_FFNCPU_ATTNSTREAM, GPUONLY_LAYERPIN_LAYERSTREAM, or GPUONLY_ATTNPIN_FFNSTREAM).
  • PSHARD_CPU_PROFILE / PSHARD_GPU_PROFILE: override the profiler inputs used by the throughput predictor.
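
For example, to force a single strategy and point the throughput predictor at specific profiles (paths illustrative):

PSHARD_STRATEGY=GPUONLY_ATTNPIN_FFNSTREAM \
PSHARD_CPU_PROFILE=/tmp/cpu.profile \
PSHARD_GPU_PROFILE=/tmp/gpu.profile \
./build/bin/llama-pshard-plan-params -m /path/to/model.gguf -c 8192 -fa on -mva 12000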

The registry is plain text. Each tier stores the chosen strategy, predicted TPS, VRAM requirement, and an extended tensor override list:

[fingerprint=0x9041bb5b253cf89f]
# n_ctx=8192 n_seq_max=1 n_threads=8 fa=on type_k=1 type_v=1 strategy=auto

[variant budget=12000 cache_ubatch=8192]
[tier 0 bs=1]
strategy=DYNAMIC_FFNCPU_ATTNSTREAM n_pinned=54 n_attn_pinned=0 overflow=NONE tps=12.85 vram=11975.7 output_on_gpu=0 pin_from_back=0
ot=^output=CUDA_Host:3,^token_embd=CUDA_Host:3,...

[tier 6 bs=8192]
strategy=GPUONLY_ATTNPIN_FFNSTREAM n_pinned=33 n_attn_pinned=0 overflow=NONE tps=2148.21 vram=11937.2 output_on_gpu=0 pin_from_back=0
ot=...
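
To make the tier rule concrete, here is a minimal C++ sketch of runtime tier selection as described in the overview (pick the smallest tier whose bs covers n_tokens); the types and names are hypothetical, not the PR's actual code:

#include <string>
#include <vector>

struct pshard_tier { int bs; std::string strategy; double tps; };

// tiers sorted by bs ascending, as they appear in the registry
const pshard_tier * select_tier(const std::vector<pshard_tier> & tiers, int n_tokens) {
    for (const auto & t : tiers) {
        if (t.bs >= n_tokens) return &t; // smallest tier covering n_tokens
    }
    return tiers.empty() ? nullptr : &tiers.back(); // fall back to the largest tier
}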

backend_id in ot= entries lets the planner target a specific scheduler backend instead of only a buffer type. In the current pshard layout:

  • 0: GPU pinned compute
  • 1, 2: shard compute lanes used for overlap
  • 3: CPU
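
For illustration, a toy parser for a single ot= entry in the pattern=buft:backend_id form shown above (struct and function names are hypothetical, and a well-formed entry is assumed):

#include <string>

struct ot_entry { std::string pattern; std::string buft; int backend_id; };

// "^output=CUDA_Host:3" -> { "^output", "CUDA_Host", 3 }
ot_entry parse_ot_entry(const std::string & s) {
    const size_t eq    = s.find('=');  // separates the tensor-name pattern from the buffer type
    const size_t colon = s.rfind(':'); // backend_id follows the last ':'
    return { s.substr(0, eq),
             s.substr(eq + 1, colon - eq - 1),
             std::stoi(s.substr(colon + 1)) };
}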

Implementation Details

  • The main entry point is llama_params_fit_pshard(). It uses no-alloc probe models and contexts to estimate model, cache, and compute memory without loading weights for inference.
  • The planner creates batch-size tiers (bs=1, bs=16, bs=512, ...) and picks the best placement for each tier.
  • Strategies: In the strategy names, ATTN refers to the attention/dense side of the layer, as opposed to FFN/MoE weights.
    • Static schedules: GPU-resident tensors run on GPU, CPU-resident tensors run on CPU, with no streamed GPU execution for host-resident weights.

      • STATIC_FITPARAMS_DENSEPRIO_MOEONLY: static llama_params_fit placement with front-to-back fill mode enabled. This is the baseline schedule: MoE models first keep dense-only parts on GPU with expert tensors on CPU, then convert dense-only layers to full layers as budget allows. Dense models get full layers.
      • STATIC_ATTNPRIO_ALLMODELS: static attention-priority placement. It pins the attention/dense side across as many layers as fit, then uses the remaining budget to pin full layers. FFN/MoE that does not fit remains on CPU. Unlike llama_params_fit, this attention-priority placement applies to dense models too.
    • Dynamic schedules: split the layer between CPU and GPU execution. Some host-resident tensors are streamed to GPU scratch for execution, while other parts of the layer remain on CPU.

      • DYNAMIC_FFNCPU_ATTNSTREAM: pin as many full layers as fit, keep FFN/MoE on CPU in the remaining layers, and stream the attention/dense side for GPU execution.
    • GPU-only schedules: execute repeating-layer compute on GPU. Weights that do not fit in VRAM stay resident in host memory and are streamed to GPU scratch before use.

      • GPUONLY_LAYERPIN_LAYERSTREAM: pin as many full layers as fit, then stream the remaining layers for GPU execution.
      • GPUONLY_ATTNPIN_FFNSTREAM: pin the attention/dense side for all layers, pin as many complete layers as the remaining budget allows, and stream FFN/MoE weights for GPU execution.
  • For each strategy, the search does a binary search on n_pinned under the VRAM budget, then tries one-layer fractional overflow in the same priority order as llama_params_fit: attention, up, gate, then MoE tensors (see the binary-search sketch after this list).
  • If profiler data is available, llama_benchmark_predictor ranks viable plans by predicted TPS. Without profiles, tps remains unset and the planner picks the viable plan with the most pinned layers under the VRAM budget, using measured VRAM as the final tie-breaker.
  • The largest tier is probed first. If baseline static placement (llama_params_fit) already fits there, the registry records pshard_disabled=1 baseline_vram=<MiB> so the runtime can skip pshard setup when that baseline fits the current budget.
  • Registry sections are keyed by a fingerprint over plan-compatible inputs: n_ctx, n_seq_max, thread count, Flash Attention mode, KV cache types, GGUF file size, and forced PSHARD_STRATEGY. Budget and cache sizing are stored as [variant budget=... cache_ubatch=...], so multiple budget/cache variants can share one model/context section (a toy fingerprint sketch follows this list).
  • Pshard planning is disabled when the user has fixed the placement (SPLIT_MODE_TENSOR, SPLIT_MODE_ROW, n_gpu_layers, tensor_split, tensor_buft_overrides) or disabled KV/KQV offload.
  • Tensor overrides carry a backend_id, and llama_model records tensor/layer backend maps so probe graphs can use the selected placement.
  • llama_memory_pshard and llama_memory_pipe_shard_i let probe contexts size pinned versus streamed KV/RS memory.
  • The model loader can duplicate tied tensors so token_embd and output.weight may receive different placements when pshard needs them on different backends.
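
As referenced in the search bullet above, a minimal sketch of the binary search on n_pinned, assuming a hypothetical vram_required() probe (the real planner measures this with no-alloc probe models and contexts):

#include <functional>

// Largest n_pinned whose probed VRAM use fits the budget. Assumes
// vram_required() grows monotonically with n_pinned.
int max_pinned_under_budget(int n_layers, double budget_mib,
                            const std::function<double(int)> & vram_required) {
    int lo = 0, hi = n_layers; // pinning zero layers is assumed to always fit
    while (lo < hi) {
        const int mid = (lo + hi + 1) / 2;
        if (vram_required(mid) <= budget_mib) {
            lo = mid;     // mid fits: keep it as the best-known answer
        } else {
            hi = mid - 1; // mid does not fit: discard [mid, n_layers]
        }
    }
    return lo; // the planner then tries one-layer fractional overflow
}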
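
And a toy version of the fingerprint over plan-compatible inputs, using FNV-1a purely for concreteness (the PR's actual hash and field encoding may differ):

#include <cstdint>
#include <string>

uint64_t plan_fingerprint(int n_ctx, int n_seq_max, int n_threads, bool fa,
                          int type_k, int type_v, uint64_t gguf_bytes,
                          const std::string & forced_strategy) {
    const std::string key =
        std::to_string(n_ctx) + "|" + std::to_string(n_seq_max) + "|" +
        std::to_string(n_threads) + "|" + (fa ? "1" : "0") + "|" +
        std::to_string(type_k) + "|" + std::to_string(type_v) + "|" +
        std::to_string(gguf_bytes) + "|" + forced_strategy;
    uint64_t h = 14695981039346656037ULL; // FNV-1a 64-bit offset basis
    for (unsigned char c : key) {
        h ^= c;
        h *= 1099511628211ULL;            // FNV-1a 64-bit prime
    }
    return h;
}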

Backend Changes

  • ggml_backend_sched_get_split_info() exposes per-split input, activation, and writeback bytes for the planner cost model.
  • The scheduler emits pshard-only lifetime nodes around each split so gallocr reserves memory for streamed weights and writeback buffers at the right time.
  • GGML_TENSOR_FLAG_WRITEBACK marks KV/RS staging tensors that must stay alive until post-compute writeback.
  • copy_stream is added as a backend capability. The scheduler only creates copy backends and prefetch reservations for devices that report a separate copy stream.
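
For intuition, a toy per-split cost estimate of the kind these byte counts enable (field and function names are hypothetical; the actual predictor uses the roofline profiles from #22495):

#include <algorithm>

struct split_cost_inputs {
    double copy_bytes; // input + writeback bytes reported for the split
    double flops;      // compute work in the split
};

// With pipelined copy/compute, transfers overlap execution on a separate
// stream, so the split time is bounded by the slower of the two.
double split_time_s(const split_cost_inputs & s, double pcie_bytes_per_s, double gpu_flops_per_s) {
    const double copy_s    = s.copy_bytes / pcie_bytes_per_s;
    const double compute_s = s.flops      / gpu_flops_per_s;
    return std::max(copy_s, compute_s);
}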

Requirements

  • I have read and agree with the contributing guidelines.
  • AI usage disclosure: Yes, I used AI for cleanup work and while debugging the planner search heuristics. The strategy set, per-tier search structure, fit-params integration, and cost model are ours. I have manually reviewed the code and tested the planner on Llama, Qwen, Kimi, and GLM-family models across several VRAM budgets.

@aukarande aukarande requested review from a team, CISC and ggerganov as code owners May 4, 2026 20:17

ggml-gh-bot Bot commented May 4, 2026

Hi @aukarande, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 2 open PRs.

  • Multiple backend changes in one PR: When adding support for a new model or feature, focus on CPU support only in the initial PR. Add support for other backends like CUDA in follow-up PRs. If you have a good reason to modify multiple backends in one PR, please explain it.

  • Large PR: Large changes require prior discussion (e.g. an issue or RFC) and maintainers may not be able to review this PR as-is. Consider splitting it into smaller, focused PRs.


Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@github-actions github-actions Bot added the Nvidia GPU, examples, and ggml labels May 4, 2026