Commit 78e7d1a

author: niushengxiao
feat: fp8kv support
1 parent a087340 · commit 78e7d1a

19 files changed · 1829 additions, 1708 deletions

docs/CN/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ Lightllm integrates the strengths of many open-source solutions, including but not limited to FasterTran
    :caption: Deployment Tutorials

    DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+   FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
    Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
    Multimodal Deployment <tutorial/multimodal>
    Reward Model Deployment <tutorial/reward_model>
docs/CN/source/tutorial/fp8_kv_quantization.rst

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@

.. _tutorial/fp8_kv_quantization_cn:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes how to use FP8 KV inference in LightLLM, including:

- Running inference with a calibration file (``fp8kv``)
- Quantization granularity differences between the FA3 and FlashInfer backends
- Common errors and troubleshooting tips

Overview
--------

FP8 KV inference in LightLLM requires a prepared calibration file (``kv_cache_calib.json``),
loaded via ``--kv_quant_calibration_config_path``.
You can use the calibration files shipped under ``test/advanced_config/``,
export one with the `LightCompress <https://github.com/ModelTC/LightCompress>`_ tool, or supply your own compatible file.

Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: uses ``per_head`` (an independent scale per head)
- ``flashinfer``: uses ``per_tensor`` (one scalar scale each for K and V)

Calibration files are therefore tightly coupled to the backend:

- A ``per_head`` calibration file corresponds to ``fa3`` and should be used with ``fa3`` inference.
- A ``per_tensor`` calibration file corresponds to ``flashinfer`` and should be used with ``flashinfer`` inference.

Do not mix calibration files across backends.

Starting FP8 Inference with a Calibration File
----------------------------------------------

Inference example (FA3):

.. code-block:: console

   $ python -m lightllm.server.api_server \
       --model_dir /path/to/model \
       --llm_kv_type fp8kv \
       --llm_prefill_att_backend fa3 \
       --llm_decode_att_backend fa3 \
       --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference example (FlashInfer):

.. code-block:: console

   $ python -m lightllm.server.api_server \
       --model_dir /path/to/model \
       --llm_kv_type fp8kv \
       --llm_prefill_att_backend flashinfer \
       --llm_decode_att_backend flashinfer \
       --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` mode requires ``--kv_quant_calibration_config_path``.
- Keep the attention backend used at inference time consistent with the one the calibration file was built for.
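
Once the server is up, a quick request confirms that FP8 KV inference is serving normally. A minimal smoke test, assuming the server listens on the default host and port (adjust to your ``--host``/``--port`` settings):

.. code-block:: console

   $ curl http://127.0.0.1:8000/generate \
       -X POST \
       -H 'Content-Type: application/json' \
       -d '{"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}}'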

Calibration File Format
-----------------------

The main fields of ``kv_cache_calib.json`` are:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 range parameters

When the calibration file is loaded, the model architecture, layer count, head count, and quantization type are validated against it.
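
For illustration, a small ``per_tensor`` file could look like the sketch below. The field names follow the list above; the concrete values and the exact ``scales_shape`` layout (here one K scale and one V scale per layer, for a hypothetical 2-layer model) are assumptions for the example. The ``qmin``/``qmax`` shown are the finite range of ``float8_e4m3fn``:

.. code-block:: json

   {
       "quant_type": "per_tensor",
       "num_layers": 2,
       "num_head": 32,
       "scales_shape": [2, 2],
       "scales": [[0.0123, 0.0456], [0.0117, 0.0521]],
       "qmin": -448.0,
       "qmax": 448.0
   }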

Multi-GPU Notes
---------------

In multi-GPU (TP) deployments, the system automatically slices out the scales for the heads owned by the current rank.
You still only need to provide a single full ``kv_cache_calib.json``.

Common Issues
-------------

1. Startup error saying ``--kv_quant_calibration_config_path`` is required

   You used ``--llm_kv_type fp8kv`` without passing a calibration file path.

2. ``quant_type not match`` error

   Usually the backend and the calibration file type disagree, for example running ``flashinfer`` with a ``per_head`` file.

3. Degraded results after switching backends

   Use a calibration file that matches the target backend; do not reuse an incompatible file across backends.

docs/EN/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ Documentation List
    :caption: Deployment Tutorials

    DeepSeek R1 Deployment <tutorial/deepseek_deployment>
+   FP8 KV Quantization and Calibration <tutorial/fp8_kv_quantization>
    Multi-Level Cache Deployment <tutorial/multi_level_cache_deployment>
    Multimodal Deployment <tutorial/multimodal>
    Reward Model Deployment <tutorial/reward_model>
docs/EN/source/tutorial/fp8_kv_quantization.rst

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@

.. _tutorial/fp8_kv_quantization_en:

FP8 KV Quantization and Calibration Guide
=========================================

This chapter describes FP8 KV inference in LightLLM, including:

- Running inference with a calibration file (``fp8kv``)
- Quantization granularity differences between FA3 and FlashInfer
- Common errors and troubleshooting

Overview
--------

LightLLM FP8 KV inference requires a prepared calibration file (``kv_cache_calib.json``),
loaded via ``--kv_quant_calibration_config_path``.
You can use the calibration files provided in ``test/advanced_config/``,
export one with `LightCompress <https://github.com/ModelTC/LightCompress>`_, or use your own compatible file.

Backend and Quantization Granularity
------------------------------------

Current behavior:

- ``fa3``: ``per_head`` scales (an independent scale per head)
- ``flashinfer``: ``per_tensor`` scales (one scalar for K and one scalar for V)

Calibration files are backend-dependent:

- ``per_head`` files are produced for ``fa3`` and should be used with ``fa3`` inference.
- ``per_tensor`` files are produced for ``flashinfer`` and should be used with ``flashinfer`` inference.

Avoid mixing calibration files across backends.

Starting FP8 Inference with a Calibration File
----------------------------------------------

Inference example (FA3):

.. code-block:: console

   $ python -m lightllm.server.api_server \
       --model_dir /path/to/model \
       --llm_kv_type fp8kv \
       --llm_prefill_att_backend fa3 \
       --llm_decode_att_backend fa3 \
       --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Inference example (FlashInfer):

.. code-block:: console

   $ python -m lightllm.server.api_server \
       --model_dir /path/to/model \
       --llm_kv_type fp8kv \
       --llm_prefill_att_backend flashinfer \
       --llm_decode_att_backend flashinfer \
       --kv_quant_calibration_config_path /path/to/kv_cache_calib.json

Notes:

- ``fp8kv`` mode requires ``--kv_quant_calibration_config_path``.
- Keep the inference backend consistent with the backend expected by the calibration file.
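
Once the server is running, a quick request confirms that FP8 KV inference responds normally. A minimal sketch using the ``/generate`` endpoint, assuming the default host and port (adjust to your ``--host``/``--port`` settings):

.. code-block:: python

   # Minimal smoke test for the launched server; host and port are assumptions.
   import requests

   resp = requests.post(
       "http://127.0.0.1:8000/generate",
       json={"inputs": "What is AI?", "parameters": {"max_new_tokens": 32}},
   )
   print(resp.json())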

Calibration File Schema
-----------------------

Key fields in ``kv_cache_calib.json``:

- ``quant_type``: ``per_head`` or ``per_tensor``
- ``num_layers``: number of layers
- ``num_head``: total number of heads
- ``scales_shape``: shape of the scale tensor
- ``scales``: the actual scale values
- ``qmin`` / ``qmax``: FP8 numeric range parameters

At load time, LightLLM validates architecture, layer count, head count, and quantization type against the file.
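
The load-time validation can be pictured as a few equality checks between the file and the model configuration. A minimal sketch of the idea; the function and the config keys (``num_hidden_layers``, ``num_key_value_heads``) are hypothetical and this is not the actual LightLLM loader:

.. code-block:: python

   import json

   def load_kv_calibration(path: str, model_cfg: dict, expected_quant_type: str) -> dict:
       # Hypothetical sketch of the checks described above, not LightLLM's loader.
       with open(path) as f:
           calib = json.load(f)
       if calib["quant_type"] != expected_quant_type:
           raise ValueError("quant_type not match")  # e.g. a per_head file on flashinfer
       if calib["num_layers"] != model_cfg["num_hidden_layers"]:
           raise ValueError("num_layers does not match the model")
       if calib["num_head"] != model_cfg["num_key_value_heads"]:
           raise ValueError("num_head does not match the model")
       return calib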

Multi-GPU Note
--------------

In multi-GPU (TP) setups, LightLLM automatically slices the global scales down to the heads owned by the local rank.
You only need to provide one full ``kv_cache_calib.json`` file.
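
Conceptually, each rank keeps a contiguous block of heads, so per-head scales can be sliced by rank index. A minimal sketch of the idea, with hypothetical names (not the actual LightLLM code):

.. code-block:: python

   import torch

   def local_scales(global_scales: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
       """Slice per-head scales of shape [num_layers, num_head] to this rank's heads."""
       num_head = global_scales.shape[1]
       heads_per_rank = num_head // tp_size  # assumes num_head is divisible by tp_size
       start = tp_rank * heads_per_rank
       return global_scales[:, start : start + heads_per_rank]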

Common Issues
-------------

1. Error saying ``--kv_quant_calibration_config_path`` is required

   You are using ``--llm_kv_type fp8kv`` without a calibration file path.

2. ``quant_type not match`` error

   Usually caused by a backend/file mismatch (for example, using a ``per_head`` file with ``flashinfer``).

3. Abnormal quality after switching backends

   Use a calibration file that matches the target backend instead of reusing an incompatible file.

lightllm/common/basemodel/attention/create_utils.py

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@
         # "fa3": Fp8Fa3AttBackend,
         # "flashinfer": Fp8FlashInferAttBackend,
     },
+    "fp8kv": {
+        "fa3": Fp8Fa3AttBackend,
+        "flashinfer": Fp8FlashInferAttBackend,
+    },
 }

 mla_data_type_to_backend = {

lightllm/common/basemodel/attention/fa3/fp8.py

Lines changed: 10 additions & 6 deletions
@@ -89,19 +89,21 @@ def _fp8_prefill_att(
     ) -> torch.Tensor:
         self.backend: Fp8Fa3AttBackend = self.backend  # for typing

+        q_head_num = q.shape[1]
+        q_head_dim = q.shape[2]
+        k_head_num = k.shape[1]
         q, q_scale = q_per_head_fp8_quant(
-            q,
+            q.reshape(q.shape[0], k_head_num, -1),
             self.infer_state.b_seq_len,
             self.cu_seqlens_q,
-            self.mid_token_batch_ids,
+            token_batch_ids=self.mid_token_batch_ids,
         )
-        k_head_num = k.shape[1]
         k_head_dim = k.shape[2]
         cache_k = k.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         cache_v = v.view(-1, 1, k_head_num, k_head_dim).view(torch.float8_e4m3fn)
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)
         o = flash_attn_with_kvcache(
-            q=q,
+            q=q.reshape(-1, q_head_num, q_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,

@@ -200,9 +202,11 @@ def _fp8_decode_att(
         layer_index = self.backend._find_layer_index(k=cache_k, v=cache_v, att_state=self)

         q_head_num = q.shape[1]
-        q, q_scale = scaled_fp8_quant(q.view(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
+        if scaled_fp8_quant is None:
+            raise ImportError("scaled_fp8_quant is unavailable. Please install vllm to enable FP8 decode attention.")
+        q, q_scale = scaled_fp8_quant(q.reshape(q.shape[0] * k_head_num, -1), use_per_token_if_dynamic=True)
         o = flash_attn_with_kvcache(
-            q=q.view(-1, q_head_num, k_head_dim),
+            q=q.reshape(-1, q_head_num, k_head_dim),
             k_cache=cache_k,
             v_cache=cache_v,
             page_table=self.page_table,

lightllm/common/basemodel/attention/flashinfer/fp8.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@ def create_att_decode_state(self, infer_state) -> "Fp8FlashInferDecodeAttState":

 @dataclasses.dataclass
 class Fp8FlashInferPrefillAttState(FlashInferPrefillAttState):
-    offline_scales: torch.Tensor = None
+    offline_scales: list = None

     def init_state(self):
         super().init_state()

@@ -68,7 +68,7 @@ def _fp8_prefill_att(

 @dataclasses.dataclass
 class Fp8FlashInferDecodeAttState(FlashInferDecodeAttState):
-    offline_scales: torch.Tensor = None
+    offline_scales: list = None

     def init_state(self):
         super().init_state()

lightllm/common/kv_cache_mem_manager/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,6 +1,5 @@
 from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
 from .calibration_fp8kv_mem_manager import CalibrationFP8KVMemoryManager
-from .export_calibration_mem_manager import ExportCalibrationMemoryManager
 from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from .deepseek2_mem_manager import Deepseek2MemoryManager

@@ -10,7 +9,6 @@
     "MemoryManager",
     "ReadOnlyStaticsMemoryManager",
     "CalibrationFP8KVMemoryManager",
-    "ExportCalibrationMemoryManager",
     "PPLINT4KVMemoryManager",
     "PPLINT8KVMemoryManager",
     "Deepseek2MemoryManager",
lightllm/common/kv_cache_mem_manager/calibration_fp8kv_mem_manager.py

Lines changed: 23 additions & 1 deletion

@@ -1,6 +1,28 @@
+import torch
+from typing import Tuple, Any
 from .offline_fp8_quant_mem_manager import OfflineFP8QuantMemManager


 class CalibrationFP8KVMemoryManager(OfflineFP8QuantMemManager):
     def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
-        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction, is_export_mode=False)
+        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy, mem_fraction)
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
+        """
+        Inference mode: quantize kv with the precomputed FP8 scales and copy it into kv_buffer.
+        """
+        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv_fp8 import destindex_copy_kv_fp8
+
+        scales = self.scales
+        destindex_copy_kv_fp8(
+            kv,
+            mem_index,
+            scales[layer_index] if scales is not None else None,
+            self.kv_buffer[layer_index].view(torch.float8_e4m3fn),
+        )
+        return
+
+    def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
+        k = self.kv_buffer[layer_index][:, : self.head_num, :]
+        v = self.kv_buffer[layer_index][:, self.head_num :, :]
+        return k, v

lightllm/common/kv_cache_mem_manager/export_calibration_mem_manager.py

Lines changed: 0 additions & 28 deletions
This file was deleted.
