The original definition of `dynamic_batching.MultiHeadAttention` can be found here.
The original definition of `dynamic_batching.KeyValueCache` can be found here.
`dynamic_batching.MultiHeadCacheAttention` simply fuses `dynamic_batching.MultiHeadAttention` and `dynamic_batching.KeyValueCache` together; a conceptual sketch follows.
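The following is a minimal NumPy sketch of what the fused operator computes for a single batch, assuming `cache_mode` is 0, `cache_layout` is 0, no quantization, and equal query/key-value head counts. The function name, `softmax` helper, and simplifications (no `attn_mask`, no paging) are illustrative, not the actual kernel.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def fused_cache_attention(q, cur_k, cur_v, cache, layer_idx, cache_start, start_pos):
    # q, cur_k, cur_v: (seqlen, heads, head_dim); cache: (MaxT, L, 2, heads, head_dim)
    seqlen, heads, head_dim = q.shape
    # KeyValueCache step: store current key/value into this batch's cache slots.
    s = cache_start + start_pos
    cache[s:s + seqlen, layer_idx, 0] = cur_k
    cache[s:s + seqlen, layer_idx, 1] = cur_v
    # MultiHeadAttention step: attend over all cached tokens of this batch.
    kvlen = start_pos + seqlen
    k = cache[cache_start:cache_start + kvlen, layer_idx, 0]
    v = cache[cache_start:cache_start + kvlen, layer_idx, 1]
    scores = np.einsum("shd,thd->hst", q, k) / np.sqrt(head_dim)
    if seqlen > 1:  # causal mask only needed in the prefill stage
        causal = np.triu(np.full((seqlen, kvlen), -np.inf), k=start_pos + 1)
        scores = scores + causal
    return np.einsum("hst,thd->shd", softmax(scores), v)
```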
Number of heads
Dimension of each head, where $num\_heads \times head\_dim = hidden\_dim$.
Whether to apply a causal mask when sequence length > 1.
Whether to apply the ALiBi mask within the operator. There is no need to set an ALiBi mask in `attn_mask` when this is `True`.
For Grouped-Query Attention. If `num_kv_heads` and `num_heads` are not equal, key and value should be repeated `num_heads/num_kv_heads` times before applying attention; `num_heads` must be divisible by `num_kv_heads`. Default is 0, in which case `num_heads` is used as `num_kv_heads`. A sketch of the repetition follows.
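A minimal sketch of the implied repetition, assuming tensors shaped `(tokens, heads, head_dim)` as used by this operator (the helper name is hypothetical):

```python
import numpy as np

def repeat_kv(x, num_heads, num_kv_heads):
    assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
    # (tokens, num_kv_heads, head_dim) -> (tokens, num_heads, head_dim)
    return np.repeat(x, num_heads // num_kv_heads, axis=1)
```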
Number of attention layers.
Attention layer index for cache and scale.
Quantization bit width for cache compression. For example, 8 means int8 compression; 0 means disabled.
Group size over which a quantization scale is shared. A possible scheme is sketched below.
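One plausible interpretation, sketched below, is int8 group quantization in which each group of `quant_group` values along the last (head dimension) axis shares one scale, matching the `scale(...,Dh/quant\_group)` shapes listed under `cache_layout`. The rounding and scale choice here are assumptions.

```python
import numpy as np

def quantize_cache(x, quant_group=8):
    # x: (..., head_dim); head_dim must be divisible by quant_group.
    g = x.reshape(*x.shape[:-1], x.shape[-1] // quant_group, quant_group)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)      # avoid division by zero
    q = np.clip(np.round(g / scale), -127, 127).astype(np.int8)
    return q.reshape(x.shape), scale.squeeze(-1)  # int8 cache, floating-point scales
```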
Defines the cache indexing mode. Default is zero. An indexing sketch covering both modes follows the list below.
- When `cache_mode` is `0`, the cache is indexed in offset mode. The shape of `cachestarts` is $(B)$. For each batch $b$, `cachestarts[b]` maps to the batch's beginning index in the $MaxT$ dimension of `cache` and `scale`. Note that `cachestarts[b+1]-cachestarts[b]` cannot be used to compute the cache length of batch $b$.
- When `cache_mode` is `1`, `cache` is indexed in page table mode, also known as Paged Attention. The shape of `cachestarts` is $(B, MaxP)$. For each batch $b$, `cachestarts[b, :]` contains the beginning indices of the batch's pages in the $MaxT$ dimension of `cache` and `scale`. Example for `batch = 2, page_size = 256`: $$cachestarts=[[0,256,\cdots],[1024,2048,\cdots]]$$
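Below is a hedged sketch of how a token's row in the $MaxT$ dimension might be resolved from `cachestarts` under each mode; `pos` (the token's position within the batch's cached sequence) and the function names are illustrative assumptions.

```python
def cache_row_offset_mode(cachestarts, b, pos):
    # cache_mode = 0: one contiguous region per batch.
    return cachestarts[b] + pos

def cache_row_paged_mode(cachestarts, b, pos, page_size):
    # cache_mode = 1: page table lookup (Paged Attention).
    page, slot = divmod(pos, page_size)
    return cachestarts[b][page] + slot
```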
Defines the data layout of `cache` and `scale`. Default is zero. A helper restating these shapes follows the list below.
Meaning of numbers:
- `0`: $cache(MaxT,L,2,H,Dh)$ and $scale(MaxT,L,2,H,Dh/quant\_group)$
- `1`: $cache(L,MaxT,2,H,Dh)$ and $scale(L,MaxT,2,H,Dh/quant\_group)$
- `2`: $cache(L,2,MaxT,H,Dh)$ and $scale(L,2,MaxT,H,Dh/quant\_group)$
- `3`: $cache(L,2,H,MaxT,Dh)$ and $scale(L,2,H,MaxT,Dh/quant\_group)$
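As a restatement of the table above, a small illustrative helper (the function name is hypothetical) returning the documented `cache` and `scale` shapes, where $MaxT$ is the cache capacity in tokens, $L$ is the number of layers, $H$ the number of key/value heads, and $Dh$ the head dimension:

```python
def cache_and_scale_shapes(cache_layout, max_t, L, H, Dh, quant_group):
    dims = {
        0: (max_t, L, 2, H, Dh),
        1: (L, max_t, 2, H, Dh),
        2: (L, 2, max_t, H, Dh),
        3: (L, 2, H, max_t, Dh),
    }[cache_layout]
    # scale replaces the last dim with Dh / quant_group
    return dims, dims[:-1] + (Dh // quant_group,)
```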
Page size in Paged Attention (when `cache_mode` is 1).
Input Query tensor
Shape:
Input Key tensor
Shape:
Input Value tensor
Shape:
`seqstarts[:B]` contains the position of the first token in `query` for each batch, and `seqstarts[B]` contains the total length of `query`.
Note that `seqstarts[b+1]-seqstarts[b]` gives the sequence length of batch $b$ (a bookkeeping sketch follows the `kvstarts` input below).
Shape:
`kvstarts[:B]` contains the position of the first token in `key = cat(past_key, current_key)` and `value = cat(past_value, current_value)` for each batch, where `key` and `value` are originally provided by the operator `KeyValueCache`. `kvstarts[B]` contains the total length of `key` and `value`.
Note that `kvstarts[b+1]-kvstarts[b]` gives the key and value length of batch $b$.
Shape:
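To make the bookkeeping concrete, a small sketch with made-up lengths showing how `seqstarts` and `kvstarts` relate to per-batch lengths (and to the `max_seqlen`/`max_kvlen` inputs below):

```python
import numpy as np

seqlens = np.array([1, 1, 5])    # query length per batch (two decoding, one prefill)
kvlens = np.array([9, 17, 5])    # key/value (past + current) length per batch
seqstarts = np.concatenate([[0], np.cumsum(seqlens)])  # [0, 1, 2, 7]
kvstarts = np.concatenate([[0], np.cumsum(kvlens)])    # [0, 9, 26, 31]

B = len(seqlens)
assert seqstarts[B] == seqlens.sum()                    # total query length
assert (kvstarts[1:] - kvstarts[:B] == kvlens).all()    # recover per-batch kv lengths
max_seqlen = (seqstarts[1:] - seqstarts[:B]).max()      # cf. the max_seqlen input
max_kvlen = (kvstarts[1:] - kvstarts[:B]).max()         # cf. the max_kvlen input
```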
Indexes the cache position in `cache` and `scale`. Behavior is determined by `cache_mode`.
Shape:
Sequence position at which `current_key` and `current_value` of each batch begin to be stored.
Shape:
Describes how many batches at the front are being decoded; the batches that are not being decoded need the causal mask.
Maximum sequence length of `query`, equal to `max(seqstarts[1:]-seqstarts[:B])`. Used for parallel computing.
Maximum sequence length of `key` and `value`, equal to `max(kvstarts[1:]-kvstarts[:B])`. Used for parallel computing.
Shape: determined by `cache_layout`.
Contains the key and value caches of the attention layers. When `cache_layout` is 0, subspace index 0 of the size-2 dimension holds the key cache and index 1 holds the value cache. Data in this tensor will be modified.
Shape: determined by `cache_layout`.
Contains the key and value cache quantization scales of the attention layers, laid out analogously to `cache` under the same `cache_layout`. Only used when `quant_bit` is not zero. Data in this tensor will be modified.
Optional custom mask.
`seqlens = seqstarts[1:]-seqstarts[:B]` is a sequence containing the query length of each batch.
`kvlens = kvstarts[1:]-kvstarts[:B]` is a sequence containing the key and value length of each batch.
Note: the last dim of the mask could be bigger than
Shape:
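For illustration, a hedged sketch of building a dense floating-point mask from `seqlens` and `kvlens` with a per-batch causal pattern; the block-diagonal layout and the use of `-inf` for masked positions are assumptions, not the operator's required format.

```python
import numpy as np

def build_causal_mask(seqlens, kvlens):
    B = len(seqlens)
    total_q, total_kv = int(np.sum(seqlens)), int(np.sum(kvlens))
    mask = np.full((total_q, total_kv), -np.inf, dtype=np.float32)
    q0 = k0 = 0
    for b in range(B):
        s, t = int(seqlens[b]), int(kvlens[b])
        past = t - s                      # tokens already in the cache
        for i in range(s):
            # each query token sees all past tokens plus itself
            mask[q0 + i, k0 : k0 + past + i + 1] = 0.0
        q0 += s
        k0 += t
    return mask
```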
Output features of the attention result.
Shape: