-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathflashinfer_mla_decode.py
More file actions
227 lines (196 loc) · 6.46 KB
/
flashinfer_mla_decode.py
File metadata and controls
227 lines (196 loc) · 6.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py
import argparse
import os
import sys
import pandas as pd
import torch
import torch.utils.benchmark as benchmark
from flashinfer import BatchMLAPagedAttentionWrapper
parent_dir = os.path.join(os.path.dirname(__file__), "..")
sys.path.append(os.path.abspath(parent_dir))
from config.model_config import ModelConfig # noqa E402
from flops.flops import get_mla_absorb_gflops # noqa E402
def benchmark_forward(
    fn,
    *inputs,
    repeats=10,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Time ``fn(*inputs, **kwinputs)`` with ``torch.utils.benchmark.Timer``.

    Each call is wrapped in a CUDA ``torch.autocast`` context that is only
    enabled when ``amp`` is true.

    Args:
        fn: Callable to benchmark; its return value is discarded.
        *inputs: Positional arguments forwarded to ``fn``.
        repeats: ``number`` argument passed to ``Timer.timeit``.
        amp: Whether autocast is enabled around each call.
        amp_dtype: Autocast dtype used when ``amp`` is true.
        **kwinputs: Keyword arguments forwarded to ``fn``.

    Returns:
        ``(timer, measurement)``: the ``benchmark.Timer`` instance and the
        ``Measurement`` produced by ``timer.timeit(repeats)``.
    """

    def _call_with_autocast(*call_args, **call_kwargs):
        # With enabled=False the autocast context is a no-op.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*call_args, **call_kwargs)

    # Names visible to the Timer's `stmt` string.
    timer_globals = {
        "fn_amp": _call_with_autocast,
        "inputs": inputs,
        "kwinputs": kwinputs,
    }
    timer = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals=timer_globals,
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.timeit(repeats)
    return timer, measurement
def time_fwd(func, *args, **kwargs):
    """Return the mean latency of ``func(*args, **kwargs)`` in microseconds.

    All arguments are handed to ``benchmark_forward`` unchanged, so
    ``repeats``/``amp``/``amp_dtype`` configure the timer and any other
    keywords reach ``func``. The measurement's ``mean`` (seconds) is
    converted to microseconds.
    """
    _timer, measurement = benchmark_forward(func, *args, **kwargs)
    mean_seconds = measurement.mean
    return mean_seconds * 1e6
def decode_attention_flashinfer():
    """Build an autograd ``Function`` that plans and times a FlashInfer MLA decode.

    One int8 workspace buffer and one ``BatchMLAPagedAttentionWrapper`` are
    created here and shared by every ``forward`` call of the returned class.
    The workspace size in bytes defaults to 384 MiB and can be overridden
    with the ``FLASHINFER_WORKSPACE_SIZE`` environment variable.

    Returns:
        The ``FlashinferAttention`` class; callers invoke its ``.apply``.
    """
    # Environment variables are always strings; convert explicitly because
    # torch.empty rejects a str size.
    flashinfer_workspace_size = int(
        os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
    )
    workspace_buffer = torch.empty(
        flashinfer_workspace_size, dtype=torch.int8, device="cuda"
    )
    flashinfer_decode_wrapper = BatchMLAPagedAttentionWrapper(
        workspace_buffer, backend="auto"
    )

    class FlashinferAttention(torch.autograd.Function):
        @staticmethod
        def forward(
            ctx,
            q_nope,
            q_pe,
            ckv,
            kpe,
            batch_size,
            kv_len,
            num_local_heads,
            kv_lora_rank,
            qk_rope_head_dim,
            warmup=10,
        ):
            """Plan a uniform decode batch, warm up, then time the wrapper's ``run``.

            Each request gets one query token and exactly ``kv_len`` cached
            tokens. With ``page_size = 1`` and ``kv_indices = 0..B*kv_len-1``,
            request ``i`` owns cache rows ``[i * kv_len, (i + 1) * kv_len)``.

            Args:
                ctx: Autograd context (unused; this is a forward-only benchmark).
                q_nope, q_pe, ckv, kpe: Tensors passed to the wrapper's ``run``.
                batch_size: Number of requests.
                kv_len: Cached tokens per request.
                num_local_heads: Head count passed to ``plan``.
                kv_lora_rank: Head dim passed to ``plan`` for q_nope/ckv.
                qk_rope_head_dim: Head dim passed to ``plan`` for q_pe/kpe.
                warmup: Untimed ``run`` calls before timing. At least one call
                    is always made so there is an output tensor to return.

            Returns:
                ``(latency_us, output)``: the mean latency in microseconds from
                ``time_fwd`` and the output of the last untimed ``run`` call.
            """
            # Build index tensors directly on the inputs' device as int32,
            # rather than building them on the CPU and copying to device 0.
            device = q_nope.device
            # One query token per request: query offsets are 0, 1, ..., B.
            q_indptr = torch.arange(
                0, batch_size + 1, dtype=torch.int32, device=device
            )
            total_tokens = batch_size * kv_len
            kv_lens = torch.full(
                (batch_size,), kv_len, dtype=torch.int32, device=device
            )
            kv_indptr = (
                torch.arange(0, batch_size + 1, dtype=torch.int32, device=device)
                * kv_len
            )
            kv_indices = torch.arange(
                0, total_tokens, dtype=torch.int32, device=device
            )
            page_size = 1
            # Softmax scale uses the head dimension before matrix absorption.
            # NOTE(review): 128 is a hard-coded qk_nope_head_dim; confirm it
            # matches the model config being benchmarked.
            sm_scale = 1.0 / ((128 + qk_rope_head_dim) ** 0.5)
            flashinfer_decode_wrapper.plan(
                q_indptr,
                kv_indptr,
                kv_indices,
                kv_lens,
                num_local_heads,
                kv_lora_rank,
                qk_rope_head_dim,
                page_size,
                False,  # causal
                sm_scale,
                q_nope.dtype,
                ckv.dtype,
            )
            # Run at least once so `o` is bound even when warmup == 0
            # (previously this raised UnboundLocalError).
            for _ in range(max(warmup, 1)):
                o = flashinfer_decode_wrapper.run(
                    q_nope, q_pe, ckv, kpe, return_lse=False
                )
            f = time_fwd(
                flashinfer_decode_wrapper.run,
                q_nope,
                q_pe,
                ckv,
                kpe,
                return_lse=False,
            )
            return f, o

    return FlashinferAttention
def main(args):
    """Benchmark FlashInfer MLA decode over a grid of (batch_size, kv_len) configs.

    For each configuration, random inputs are generated in bf16 and cast to
    the requested KV-cache dtype, then the decode kernel is timed. The
    latency and MFU are printed and collected, and all rows are written to
    ``attention_benchmark.csv`` in the current working directory.

    Args:
        args: Namespace with ``config_path`` (path to the HF ``config.json``)
            and ``kv_cache_dtype`` (``"bf16"`` or ``"fp8"``; any other value
            falls back to bf16).
    """
    config = ModelConfig(args.config_path)
    # Peak TFLOPS used as the MFU denominator.
    # NOTE(review): hard-coded for one specific GPU; confirm for your hardware.
    fp16_tflops = 148
    num_heads = config.num_attention_heads
    kv_lora_rank = config.kv_lora_rank
    qk_rope_head_dim = config.qk_rope_head_dim
    dtype = torch.bfloat16
    kv_cache_dtype = dtype
    if args.kv_cache_dtype == "bf16":
        kv_cache_dtype = torch.bfloat16
    elif args.kv_cache_dtype == "fp8":
        kv_cache_dtype = torch.float8_e4m3fn

    # Larger batches are swept over shorter KV lengths.
    batch_kv_mapping = {
        1: [1024, 4096, 8192, 16384, 32768, 65536, 131072],
        16: [1024, 4096, 8192, 16384, 32768, 65536, 131072],
        32: [1024, 4096, 8192, 16384, 32768, 65536, 131072],
        64: [1024, 4096, 8192, 16384, 32768, 65536, 131072],
        128: [1024, 4096, 8192, 16384, 32768, 65536],
        256: [1024, 4096, 8192, 16384],
        512: [1024, 4096, 8192],
    }
    configs = [
        (batch_size, kv_len)
        for batch_size, kv_len_range in batch_kv_mapping.items()
        for kv_len in kv_len_range
    ]
    results = []

    attn_flashinfer = decode_attention_flashinfer().apply
    for batch_size, kv_len in configs:
        q_nope = torch.randn(
            batch_size, num_heads, kv_lora_rank, dtype=dtype, device="cuda"
        ).to(kv_cache_dtype)
        q_pe = torch.randn(
            batch_size, num_heads, qk_rope_head_dim, dtype=dtype, device="cuda"
        ).to(kv_cache_dtype)
        ckv = torch.randn(
            batch_size * kv_len, 1, kv_lora_rank, dtype=dtype, device="cuda"
        ).to(kv_cache_dtype)
        kpe = torch.randn(
            batch_size * kv_len, 1, qk_rope_head_dim, dtype=dtype, device="cuda"
        ).to(kv_cache_dtype)

        # FLOPs are computed for a single sequence, then scaled by batch size.
        attn_core_gflops, _ = get_mla_absorb_gflops(config, 1, kv_len)
        attn_core_gflops = attn_core_gflops * batch_size
        us_flashinfer, _ = attn_flashinfer(
            q_nope,
            q_pe,
            ckv,
            kpe,
            batch_size,
            kv_len,
            num_heads,
            kv_lora_rank,
            qk_rope_head_dim,
        )
        # (GFLOP / us) * 1e3 = TFLOP/s; divide by peak TFLOPS for utilisation.
        mfu = attn_core_gflops * 1e3 / (fp16_tflops * us_flashinfer)
        print(
            "MLA",
            " ",
            num_heads,
            " ",
            kv_lora_rank,
            " ",
            batch_size,
            " ",
            kv_len,
            " ",
            us_flashinfer,
            " ",
            mfu,
        )
        results.append(
            {
                "dtype": "bf16",
                "kv_dtype": args.kv_cache_dtype,
                "batch_size": batch_size,
                "kv_len": kv_len,
                "latency_us": round(us_flashinfer, 3),
                "mfu": round(mfu, 3),
            }
        )
        # Release this config's inputs before the next config allocates its
        # own, so the old and new KV caches are never live at the same time.
        del q_nope, q_pe, ckv, kpe

    df = pd.DataFrame(results)
    df.to_csv("attention_benchmark.csv", index=False)
if __name__ == "__main__":
    # Command-line entry point: parse options, then run the benchmark sweep.
    parser = argparse.ArgumentParser()
    # Required: path to the Hugging Face model config.json.
    parser.add_argument(
        "--config-path",
        required=True,
        type=str,
        help="The path of the hf model config.json",
    )
    # Optional: storage dtype for the KV cache (and query tensors).
    parser.add_argument(
        "--kv-cache-dtype",
        required=False,
        default="bf16",
        type=str,
        choices=["bf16", "fp8"],
        help="dtype of KV Cache,choices: bf16, fp8",
    )
    args = parser.parse_args()
    main(args)