
llama: add pshard runtime for plan switching and streamed weights #22692

Open

aukarande wants to merge 1 commit into ggml-org:master from aukarande:pshard/inference

Conversation

@aukarande

Overview

Pipelined sharding (pshard) is a CPU/GPU scheduling approach for VRAM-constrained inference. It combines prioritized VRAM placement, sub-layer sharding, CPU offload, and pipelined copy/compute. It has three stages: profiling (#22495), planning (#22691), and inference. First, we benchmark representative CPU and GPU kernels on the target machine. Then the planner uses those measurements to compare placement strategies within each token tier and writes the chosen schedules to a registry. At inference time, the runtime selects the smallest token tier that can cover the current n_tokens and uses that tier's schedule.

This is meant to work without manual tuning. The runtime can discover an optimal ubatch size, adapt across context processing, single-user decode, and multi-user decode, and choose schedules based on the current VRAM budget, CPU thread count, and PCIe bandwidth, without requiring manual tensor overrides.

This work was proposed in a recent meeting with @ggerganov and @JohannesGaessler, and more details can be found in our paper: Efficient, VRAM-Constrained xLM Inference on Clients.

This PR adds the third phase: inference. It loads the registry written by the planner, picks the smallest tier that covers the current n_tokens, and applies that tier's schedule for decode.
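To make the tier-selection step concrete, here is a minimal sketch, not the code in this PR, of picking the smallest registered tier whose token bound covers the current batch and falling back to the largest tier otherwise. The pshard_tier struct and pshard_select_tier helper are hypothetical names invented for the example.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical tier registry entry; not the PR's actual data structure.
struct pshard_tier {
    uint32_t n_tokens_max; // largest batch this tier was planned for (1, 16, 512, ...)
    int      plan_id;      // which schedule to apply when this tier is selected
};

// Pick the smallest tier that covers n_tokens; if none covers it, use the
// largest planned tier. Returns -1 if the registry is empty.
static int pshard_select_tier(const std::vector<pshard_tier> & tiers, uint32_t n_tokens) {
    int      best_id  = -1;
    uint32_t best_max = 0;
    // first pass: smallest tier that covers the batch
    for (const auto & t : tiers) {
        if (t.n_tokens_max >= n_tokens && (best_id == -1 || t.n_tokens_max < best_max)) {
            best_id  = t.plan_id;
            best_max = t.n_tokens_max;
        }
    }
    if (best_id != -1) {
        return best_id;
    }
    // fallback: no tier covers the batch, use the largest one available
    for (const auto & t : tiers) {
        if (best_id == -1 || t.n_tokens_max > best_max) {
            best_id  = t.plan_id;
            best_max = t.n_tokens_max;
        }
    }
    return best_id;
}
```

In the runtime, the selected tier would then be applied (or a previously warmed state reused) along the lines of the pshard_apply_plan() and pshard_warmup_plans() steps described under Implementation Details below.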

The implementation in this PR builds on @JohannesGaessler's llama-fit-params (#16653) and @am17an's tensor-override prefetching work (#21067), which keeps the changes here compact compared to our original prototype for the paper.

Additional information

Usage

New runtime flag:

./build/bin/llama-cli \
    -m /path/to/model.gguf \
    -pshard \
    --no-mmap \
    -mva 12000 \
    -c 8192 \
    -fa on

Useful options and behavior:

  • -pshard: enable pipelined sharding in the common tools.
  • -mva, --max-vram-alloc <MiB>: runtime device-memory budget. 0 or omitted uses currently free VRAM.
  • --no-mmap: load all model weights into backend-managed host buffers, enabling copy/compute overlap.
  • The runtime looks for <model>.gguf.tensor_overrides.pshard_registry, checks the fingerprint, and loads the best budget variant for the current budget (see the sketch after this list).
  • During inference, the runtime selects the smallest tier that covers the current batch size and switches plans when needed.
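For illustration, here is a minimal sketch of the registry lookup and budget selection described above. The registry path matches the file name mentioned in this PR, but the budget_variant struct, the helper names, and the idea of storing one variant per planned budget are assumptions made for the example, not the PR's actual registry format.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical per-budget registry entry; the real format is defined by the planner PR.
struct budget_variant {
    uint64_t vram_budget_mib; // VRAM budget this variant was planned against
    int      variant_index;   // where its tier plans live in the registry
};

// The registry sits next to the model file, e.g. model.gguf.tensor_overrides.pshard_registry.
static std::string pshard_registry_path(const std::string & model_path) {
    return model_path + ".tensor_overrides.pshard_registry";
}

// Pick the largest planned budget that still fits inside the runtime budget (-mva),
// so the chosen plan never assumes more VRAM than is available. Returns -1 if none fit.
static int pshard_select_budget(const std::vector<budget_variant> & variants, uint64_t budget_mib) {
    int      best     = -1;
    uint64_t best_mib = 0;
    for (const auto & v : variants) {
        if (v.vram_budget_mib <= budget_mib && v.vram_budget_mib >= best_mib) {
            best     = v.variant_index;
            best_mib = v.vram_budget_mib;
        }
    }
    return best;
}
```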

Example log:

common_init_result: pshard enabled, probing and loading plan cache
pshard_registry_load: loaded 7 tier plans from /path/to/model.gguf.tensor_overrides.pshard_registry
llama_params_fit_pshard: plan: GPUONLY_ATTNPIN_FFNSTREAM, n_pinned=33/64, vram=12000 MiB, n_gpu_layers=65
load_tensors: preloaded 476 weights (5456.07 MiB) into 12000.00 MiB device buffer
pshard_pack_cache_region: layout: [weights 0..5456.07 | scratch 5456.07..11777.19 | cache 11777.19..12000.00 MiB]
pshard_warmup_plans: pre-computed 7 tiers in ... ms

Implementation Details

  • common_init_result creates a pshard registry, calls llama_params_fit_pshard(), and disables pshard if the plan is unavailable or marked unnecessary.
  • Model loading preloads pinned weights into one bounded device buffer. Sharded weights stay in host memory and are prefetched into compute scratch prior to their split runs.
  • The runtime buffer layout is [ pinned weights | compute scratch / streamed tensors | pinned KV/RS cache ]. Pinned cache is packed from the right side so plan switches can resize the scratch range (sketched after this list).
  • A registry contains batch-size tiers (bs=1, bs=16, bs=512, ...). Decode and prefill can use different tiers because they have different scratch and cache pressure.
  • pshard_apply_plan() converts the selected tier into tensor and layer backend maps, applies external cache addresses, sets gallocr allocation ranges, and restores cached allocator/backend state when possible.
  • pshard_warmup_plans() reserves and saves gallocr state for viable tiers before generation starts. Later switches can reuse the saved state instead of doing a full reserve again.
  • pshard_switch_plan() downloads KV/RS that are being unpinned, applies the new tier, uploads newly pinned layers, and logs the download/upload counts.
  • llama_memory_pshard handles KV, SWA KV, recurrent state, and hybrid memory. It keeps the CPU copy authoritative for host operations and syncs GPU copies when a tier pins those layers.
  • Split callbacks do the streaming work around each scheduler split: upload current inputs, optionally prefetch the next split, and download/write back modified KV/RS after compute.
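To make the buffer layout concrete, below is a minimal sketch, not the PR's implementation, of how a single bounded device buffer could be partitioned into pinned weights, compute scratch, and a right-packed pinned cache. The struct, the function names, and the byte-offset interface are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical byte-offset view of the bounded device buffer:
//   [0, weights_end)           pinned weights
//   [weights_end, cache_begin) compute scratch / staging for streamed tensors
//   [cache_begin, total)       pinned KV/RS cache, packed from the right edge
struct pshard_buffer_layout {
    size_t total;
    size_t weights_end;
    size_t cache_begin;
};

static pshard_buffer_layout pshard_pack_layout(size_t total, size_t weights_bytes, size_t cache_bytes) {
    assert(weights_bytes + cache_bytes <= total && "plan does not fit in the device budget");
    pshard_buffer_layout l;
    l.total       = total;
    l.weights_end = weights_bytes;
    l.cache_begin = total - cache_bytes; // packing the cache from the right keeps weight offsets
                                         // stable when a plan switch resizes the cache region
    return l;
}

// Whatever remains in the middle is the scratch range handed to the allocator.
static size_t pshard_scratch_size(const pshard_buffer_layout & l) {
    return l.cache_begin - l.weights_end;
}
```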

Backend Changes

  • Gallocr can now allocate from an externally owned buffer through ggml_gallocr_set_buffer() and can constrain allocations to a subrange through ggml_gallocr_set_alloc_range().
  • Gallocr exposes the chunk count and maximum chunk size so pshard can check that saved scratch layouts fit inside the bounded buffer.
  • Gallocr state can be saved and restored. The scheduler also saves node backend ids so a warmed tier can be reapplied without recomputing placement.
  • The scheduler has split callbacks for upload, prefetch, and writeback. A separate prefetch callback lets the next split start copying while the current split computes (see the sketch after this list).
  • Async prefetch uses a copy backend and events when the backend reports copy_stream support. The CUDA backend wires up this capability for copy/compute overlap.
  • ggml_backend_sched_add_writeback() and GGML_TENSOR_FLAG_WRITEBACK mark KV/RS staging tensors that are written during a split and must stay alive until post-compute writeback.
  • Streamed weight copies are kept alive only until their last consumer, so later intermediates in the same split can reuse the memory.
  • Same-device backend redirection lets pshard use multiple logical scheduler backends on one physical GPU while still executing on the correct underlying backend.
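As a rough illustration of the split-callback flow referenced above, the sketch below runs an upload, an optional prefetch of the next split, compute, and a post-compute writeback for each split. The callback signatures and the split/tensor types are hypothetical stand-ins, not the ggml scheduler API added by this PR.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-ins for scheduler splits and the tensors they touch.
struct tensor_ref {
    int  id;
    bool writeback; // set for KV/RS staging tensors modified during the split
};

struct sched_split {
    std::vector<tensor_ref> inputs; // streamed weights and KV/RS staging needed by this split
};

struct split_callbacks {
    std::function<void(const sched_split &)> upload;    // copy this split's inputs into device scratch
    std::function<void(const sched_split &)> prefetch;  // start an async copy of the next split's inputs
    std::function<void(const sched_split &)> writeback; // copy modified KV/RS back to host after compute
};

// Drive the splits: upload, kick off the next prefetch, compute, then write back.
static void run_splits(const std::vector<sched_split> & splits,
                       const split_callbacks & cb,
                       const std::function<void(const sched_split &)> & compute) {
    for (size_t i = 0; i < splits.size(); ++i) {
        cb.upload(splits[i]);
        if (i + 1 < splits.size()) {
            cb.prefetch(splits[i + 1]); // overlaps with the compute below when a copy stream exists
        }
        compute(splits[i]);
        cb.writeback(splits[i]);        // only tensors flagged for writeback are copied back
    }
}
```

The actual implementation additionally synchronizes the copy and compute streams with backend events, as noted in the list above.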

Results

Tested on an RTX 5080 (16 GB, PCIe Gen5) with an Intel Xeon w5-3425 (12 cores) and 512 GB of system RAM (peak ~90 GB/s). The baseline is llama-fit-params (#16653).
Single-user runs use 16K-256K context with OSL = 200 and ISL = context - OSL.

Summary (full result set in comments):

| Scenario | Workload | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- |
| Large hybrid, single user | Qwen3.5-397B-A17B-Q4 | 12.6-13.9 | 1.3-8.1 | 9.7-13.4 |
| Dense, single user | 3 dense models | 3.3-6.9 | 1.4-10.0 | 2.1-6.2 |
| Dense, batched | same 3 dense models, 4-64 users, ISL/OSL=4K/200 | 4.1-9.1 | 1.6-4.9 | 2.2-6.4 |
| MoE/hybrid, 16 GB, single user | 3 MoE/hybrid models | 4.5-17.8 | 0.9-1.0 | 4.4-11.1 |
| MoE/hybrid, 8 GB, single user | 3 MoE/hybrid models | 4.4-22.8 | 0.9-6.7 | 4.3-12.4 |

Models used:

  • Large hybrid: Qwen3.5-397B-A17B-Q4
  • Dense: Qwen3-32B-Q4, Llama3-70B-Q4, Devstral-123B-Q4
  • MoE/hybrid: Qwen3.5-122B-A10B-Q4, Qwen3-Next-80B-A3B-Q4, GPT-OSS-120B-Q4

Qwen3.5-397B-A17B-Q4 (227 GB)

Single-user runs (OSL=200) with the rest of the context filled by the input prompt.

| Context | Base TTFT (s) | pshard TTFT (s) | Base TPS | pshard TPS | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 16K | 469.6 | 35.0 | 9.89 | 12.99 | 13.4 | 1.3 | 9.7 |
| 64K | 1961.3 | 141.2 | 3.76 | 12.13 | 13.9 | 3.2 | 12.8 |
| 128K | 4045.2 | 294.1 | 1.67 | 11.74 | 13.8 | 7.0 | 13.4 |
| 256K | 8311.1 | 659.4 | 0.70 | 5.64 | 12.6 | 8.1 | 12.4 |

Requirements

  • I have read and agree with the contributing guidelines.
  • AI usage disclosure: Yes. I used AI for debugging and review, especially around scheduler/gallocr allocation behavior, plan-vs-runtime VRAM mismatches, tensor lifetime issues for streamed weights and KV/RS writeback, async prefetch/copy overlap, and MoE tail backend placement. All code changes were reviewed and tested by me.

@ggml-gh-bot

ggml-gh-bot Bot commented May 4, 2026

Hi @aukarande, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 3 open PRs.

  • Multiple backend changes in one PR: When adding support for a new model or feature, focus on CPU support only in the initial PR. Add support for other backends like CUDA in follow-up PRs. If you have a good reason to modify multiple backends in one PR, please explain it.

  • Large PR: Large changes require prior discussion (e.g. an issue or RFC) and maintainers may not be able to review this PR as-is. Consider splitting it into smaller, focused PRs.


Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@aukarande
Author

Full benchmark tables used for the summary in the PR description:

Dense Models (Single-User)

Single-user runs (OSL=200) with the rest of the context filled by the input prompt.

| Model | Context | Base TTFT (s) | pshard TTFT (s) | Base TPS | pshard TPS | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Qwen3-32B-Q4 | 16K | 34.2 | 9.8 | 3.45 | 6.51 | 3.5 | 1.9 | 2.3 |
| Qwen3-32B-Q4 | 64K | 306.1 | 75.2 | 0.43 | 3.26 | 4.1 | 7.5 | 5.6 |
| Qwen3-32B-Q4 | 128K | 1044.2 | 266.4 | 0.18 | 1.45 | 3.9 | 8.0 | 5.3 |
| Qwen3-32B-Q4 | 256K | 3661.6 | 1114.5 | 0.08 | 0.59 | 3.3 | 7.4 | 4.3 |
| Llama3-70B-Q4 | 16K | 114.6 | 18.5 | 1.27 | 1.76 | 6.2 | 1.4 | 2.1 |
| Llama3-70B-Q4 | 64K | 646.8 | 135.8 | 0.26 | 1.72 | 4.8 | 6.6 | 5.6 |
| Llama3-70B-Q4 | 128K | 1806.5 | 425.1 | 0.13 | 0.95 | 4.2 | 7.5 | 5.3 |
| Llama3-70B-Q4 | 256K | 5648.6 | 1638.9 | 0.06 | 0.44 | 3.4 | 7.3 | 4.3 |
| Devstral-123B-Q4 | 16K | 234.0 | 33.9 | 0.67 | 1.32 | 6.9 | 2.0 | 2.9 |
| Devstral-123B-Q4 | 64K | 1189.3 | 223.1 | 0.14 | 1.02 | 5.3 | 7.2 | 6.2 |
| Devstral-123B-Q4 | 128K | 3165.1 | 725.8 | 0.07 | 0.69 | 4.4 | 9.5 | 5.8 |
| Devstral-123B-Q4 | 256K | 9438.2 | 2694.7 | 0.04 | 0.35 | 3.5 | 10.0 | 4.6 |

Dense Models (Batched)

Batched runs with ISL/OSL=4K/200 with 4 to 64 concurrent users.

| Model | Users | Base TTFT (s) | pshard TTFT (s) | Base TPS | pshard TPS | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Qwen3-32B-Q4 | 4 | 32.0 | 7.8 | 12.08 | 24.50 | 4.1 | 2.0 | 2.4 |
| Qwen3-32B-Q4 | 16 | 191.4 | 32.7 | 7.23 | 32.65 | 5.9 | 4.5 | 4.8 |
| Qwen3-32B-Q4 | 32 | 488.5 | 66.4 | 7.59 | 35.46 | 7.4 | 4.7 | 5.4 |
| Qwen3-32B-Q4 | 64 | 1106.9 | 137.2 | 7.88 | 34.46 | 8.1 | 4.4 | 5.4 |
| Llama3-70B-Q4 | 4 | 114.1 | 19.7 | 4.41 | 6.90 | 5.8 | 1.6 | 2.2 |
| Llama3-70B-Q4 | 16 | 497.8 | 83.6 | 3.87 | 10.80 | 6.0 | 2.8 | 3.5 |
| Llama3-70B-Q4 | 32 | 1141.9 | 170.9 | 4.34 | 12.60 | 6.7 | 2.9 | 3.9 |
| Llama3-70B-Q4 | 64 | 2358.1 | 346.2 | 5.42 | 13.13 | 6.8 | 2.4 | 3.6 |
| Devstral-123B-Q4 | 4 | 223.8 | 28.0 | 2.38 | 4.32 | 8.0 | 1.8 | 2.6 |
| Devstral-123B-Q4 | 16 | 961.7 | 113.1 | 2.17 | 9.92 | 8.5 | 4.6 | 5.6 |
| Devstral-123B-Q4 | 32 | 2031.1 | 229.7 | 2.90 | 14.35 | 8.8 | 4.9 | 6.3 |
| Devstral-123B-Q4 | 64 | 4205.0 | 463.5 | 3.96 | 18.34 | 9.1 | 4.6 | 6.4 |

MoE/Hybrid (16 GB VRAM Budget)

Single-user runs (OSL=200) with the rest of the context filled by the input prompt.

| Model | Context | Base TTFT (s) | pshard TTFT (s) | Base TPS | pshard TPS | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Qwen3.5-122B-A10B-Q4 | 16K | 144.1 | 11.1 | 25.87 | 25.95 | 13.0 | 1.0 | 8.1 |
| Qwen3.5-122B-A10B-Q4 | 64K | 613.9 | 47.9 | 24.32 | 24.34 | 12.8 | 1.0 | 11.1 |
| Qwen3.5-122B-A10B-Q4 | 128K | 1260.5 | 111.6 | 23.49 | 22.00 | 11.3 | 0.9 | 10.5 |
| Qwen3.5-122B-A10B-Q4 | 256K | 2713.2 | 288.9 | 20.17 | 20.43 | 9.4 | 1.0 | 9.1 |
| Qwen3-Next-80B-A3B-Q4 | 16K | 44.0 | 6.8 | 56.82 | 53.73 | 6.5 | 0.9 | 4.5 |
| Qwen3-Next-80B-A3B-Q4 | 64K | 193.0 | 29.9 | 49.19 | 47.53 | 6.5 | 1.0 | 5.8 |
| Qwen3-Next-80B-A3B-Q4 | 128K | 400.2 | 69.4 | 45.45 | 42.58 | 5.8 | 0.9 | 5.5 |
| Qwen3-Next-80B-A3B-Q4 | 256K | 925.8 | 206.2 | 37.08 | 33.14 | 4.5 | 0.9 | 4.4 |
| GPT-OSS-120B-Q4 | 16K | 65.9 | 3.7 | 34.69 | 32.21 | 17.8 | 0.9 | 7.2 |
| GPT-OSS-120B-Q4 | 64K | 284.2 | 20.1 | 31.62 | 29.39 | 14.1 | 0.9 | 10.8 |
| GPT-OSS-120B-Q4 | 128K | 674.8 | 58.8 | 27.79 | 25.71 | 11.5 | 0.9 | 10.2 |
| GPT-OSS-120B-Q4 | 256K | 1603.2 | 188.3 | 22.60 | 21.54 | 8.5 | 1.0 | 8.2 |

MoE/Hybrid (8 GB VRAM Budget)

Single-user runs (OSL=200) with the rest of the context filled by the input prompt.

| Model | Context | Base TTFT (s) | pshard TTFT (s) | Base TPS | pshard TPS | TTFT speedup (x) | TPS speedup (x) | E2E speedup (x) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| Qwen3.5-122B-A10B-Q4 | 16K | 171.1 | 11.5 | 14.59 | 23.88 | 14.8 | 1.6 | 9.3 |
| Qwen3.5-122B-A10B-Q4 | 64K | 689.3 | 49.2 | 4.40 | 19.57 | 14.0 | 4.5 | 12.4 |
| Qwen3.5-122B-A10B-Q4 | 128K | 1446.9 | 115.3 | 1.80 | 10.64 | 12.6 | 5.9 | 11.6 |
| Qwen3.5-122B-A10B-Q4 | 256K | 3108.3 | 452.1 | 0.73 | 4.88 | 6.9 | 6.7 | 6.9 |
| Qwen3-Next-80B-A3B-Q4 | 16K | 57.2 | 7.1 | 49.55 | 46.95 | 8.0 | 0.9 | 5.4 |
| Qwen3-Next-80B-A3B-Q4 | 64K | 244.4 | 30.8 | 46.19 | 42.24 | 7.9 | 0.9 | 7.0 |
| Qwen3-Next-80B-A3B-Q4 | 128K | 524.7 | 70.7 | 39.74 | 39.56 | 7.4 | 1.0 | 7.0 |
| Qwen3-Next-80B-A3B-Q4 | 256K | 1241.6 | 285.0 | 2.31 | 12.97 | 4.4 | 5.6 | 4.4 |
| GPT-OSS-120B-Q4 | 16K | 88.0 | 3.9 | 30.75 | 29.12 | 22.8 | 0.9 | 8.8 |
| GPT-OSS-120B-Q4 | 64K | 374.3 | 23.3 | 27.78 | 26.74 | 16.1 | 1.0 | 12.4 |
| GPT-OSS-120B-Q4 | 128K | 817.6 | 89.8 | 5.21 | 24.00 | 9.1 | 4.6 | 8.7 |
| GPT-OSS-120B-Q4 | 256K | 1877.6 | 431.3 | 1.40 | 5.36 | 4.4 | 3.8 | 4.3 |

@github-actions Bot added the Nvidia GPU, examples, and ggml labels on May 4, 2026
@JohannesGaessler
Contributor

JohannesGaessler commented May 4, 2026

In a quick test I first checked out the planning branch, since to my understanding the planning code is not in this PR. I ran llama-pshard-plan-params for Qwen 3.5 35b a3b q8_0 on a single RTX 4090. After that I checked out the inference branch again and ran llama-completion with --ignore-eos -no-cnv -n 200 -pshard, since there seems to be no llama-bench support. The prompt length is 346 tokens. I got 213.49 t/s for the prompt and 14.78 t/s for the generation. By comparison, the baseline on master with -fit on -fitt 1024 is 525.21 t/s for the prompt and 43.09 t/s for generation.

@JohannesGaessler
Contributor

I forgot: the RTX 4090 is connected to my server with an EPYC 7742 CPU and 3200 "MHz" octa-channel RAM.
