From cd59efe443e92644e03a7e965cad27f9bf107a5d Mon Sep 17 00:00:00 2001 From: Feng Li Date: Mon, 4 Sep 2023 18:53:53 +0000 Subject: [PATCH 1/3] masked_tokens uses session_length --- .../models/multi_gpu_gpt/ParallelGpt.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc index 93b80ae6e..2423df5a2 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc @@ -158,7 +158,7 @@ void ParallelGpt::allocateBuffer(size_t batch_size, parent_ids_buf_ = (int*)(allocator_->reMalloc(parent_ids_buf_, sizeof(int) * batchxbeam * max_session_len, true)); seq_limit_len_ = (uint32_t*)(allocator_->reMalloc(seq_limit_len_, sizeof(uint32_t) * batch_size, false)); tiled_masked_tokens_ = - (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * memory_len, true)); + (bool*)(allocator_->reMalloc(tiled_masked_tokens_, sizeof(bool) * batchxbeam * max_session_len, true)); context_decoder_input_buf_ = (T*)(allocator_->reMalloc( context_decoder_input_buf_, sizeof(T) * batchxbeam * max_input_len * hidden_units_, false)); @@ -865,7 +865,7 @@ void ParallelGpt::forward(std::unordered_map* outp PUSH_RANGE("initialize output and parent ids"); cudaMemsetAsync(output_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_); cudaMemsetAsync(parent_ids_buf_, 0, sizeof(int) * batch_size * beam_width * session_len, stream_); - cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * memory_len, stream_); + cudaMemsetAsync(tiled_masked_tokens_, false, sizeof(bool) * batch_size * beam_width * session_len, stream_); cudaMemsetAsync(tiled_total_padding_count_, 0, sizeof(int) * batch_size * beam_width, stream_); if (beam_width > 1) { cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * memory_len, stream_); @@ -1180,7 +1180,7 @@ void ParallelGpt::forward(std::unordered_map* outp PUSH_RANGE("mask padding tokens"); invokeMaskPaddingTokens(tiled_masked_tokens_, input_tensors->at("input_lengths").getPtr(), - memory_len, + session_len, max_input_length, initial_step, batch_size, @@ -1316,8 +1316,8 @@ void ParallelGpt::forward(std::unordered_map* outp {"masked_tokens", Tensor(MEMORY_GPU, TYPE_BOOL, - {local_batch_size * beam_width, memory_len}, - tiled_masked_tokens_ + id_offset * memory_len)}}); + {local_batch_size * beam_width, session_len}, + tiled_masked_tokens_ + id_offset * session_len)}}); if (beam_width > 1) { decoder_input_tensors.insert({"cache_indirection", Tensor(MEMORY_GPU, From 72319c6453772166dcfed22c505d1b18d626e08b Mon Sep 17 00:00:00 2001 From: Feng Li Date: Mon, 4 Sep 2023 19:55:42 +0000 Subject: [PATCH 2/3] masked_tokens uses session length everywhere --- .../kernels/decoder_masked_multihead_attention.h | 4 +++- .../decoder_masked_multihead_attention_template.hpp | 5 +++-- .../layers/attention_layers/DecoderSelfAttentionLayer.cc | 7 ++++++- .../models/multi_gpu_gpt/ParallelGptDecoder.cc | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h index 5a768184c..0d3e67546 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h @@ -80,8 +80,10 @@ struct 
Multihead_attention_params_base { int batch_size = 0; // The beam width int beam_width = 0; - // The sequence length. + // The cache length. int memory_max_len = 0; + // The whole sequence length, which includes context and output. + int session_len = 0; // The number of heads (H). int num_heads = 0; // The hidden dimension per head (Dh). diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp index 8e7cb92a2..038f35ba4 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -1219,6 +1219,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params::forward(TensorMap* output_tens // finished [batch_size] (optional) // total_padding_tokens [batch_size] (optional) // max_input_length [1] on cpu (optional) - // masked_tokens [batch_size, memory_len], (optional) + // masked_tokens [batch_size, session_len], (optional) // cache_indirection [batch_size / beam_width, beam_width, memory_max_len] (optional) // d_prefix_prompt_lengths [batch_size] (optional) // max_prefix_prompt_length [1] on cpu (optional) @@ -504,6 +507,7 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens const int batch_size = input_tensors->at("input_query").shape[0]; const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1; const int memory_max_len = output_tensors->at("key_cache").shape[3]; + const int session_len = masked_tokens != nullptr ? input_tensors->at("masked_tokens").shape[1] : 0; const int* d_prefix_prompt_lengths = input_tensors->getPtr("d_prefix_prompt_lengths", nullptr); const int max_prefix_prompt_length = input_tensors->getVal("max_prefix_prompt_length", 0); @@ -596,6 +600,7 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens rotary_embedding_dim_, neox_rotary_style_, memory_max_len, + session_len, d_prefix_prompt_lengths, max_prefix_prompt_length, input_tensors->getVal("max_input_length", 0), diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc index 173c87b46..0fe65c94d 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc @@ -269,7 +269,7 @@ void ParallelGptDecoder::forward(std::unordered_map* // cache_indirection [local_batch_size / beam_width, beam_width, memory_len] // Here, local_batch_size contains the beam_width, so local_batch_size / beam_width // is real local_batch_size. (optional.) - // masked_tokens [local_batch_size, memory_len] + // masked_tokens [local_batch_size, session_len] // linear_bias_slopes [head_num], optional // output tensors: From 499149e5eec1bf45cd95e86e06addbbe77bdd71c Mon Sep 17 00:00:00 2001 From: Feng Li Date: Tue, 24 Oct 2023 18:37:15 +0000 Subject: [PATCH 3/3] Important KV Cache in auto-regressive decoder. 
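When the k/v cache is full, the decoder currently overwrites the oldest
entry (circular cache). This patch adds an optional "important" region of
important_kv_cache_size slots at the front of each head's cache: once the
cache is full, the slot in that region whose key scores lowest against the
current query is treated as least important and becomes the replacement
target. A new optional kv_indices output records, per layer, batch, head
and cache slot, the absolute timestep currently stored in that slot.

Behaviour by important_kv_cache_size (see the case list in the kernel):
- 0: unchanged circular (FIFO) cache.
- memory_max_len: pure important cache; the new token replaces the least
  important slot.
- in between: hybrid; the new token enters the FIFO region, and the token
  aged out of the FIFO window is promoted into the important region in
  place of the least important slot.

The host-side sketch below is for review only and is not part of the
patch; pick_write_slot is a hypothetical helper that mirrors the intended
eviction policy per head and step, ignoring the 16B-chunked K layout,
beams and masking that the fused kernel handles.

    #include <vector>

    // qk[i]        : attention score of the current query vs. cache slot i
    //                (only the first important_size entries are inspected).
    // kv_indices[i]: absolute timestep currently stored in slot i.
    // Returns the slot the new token's K/V are written to.
    int pick_write_slot(int memory_max_len, int important_size, int tlength,
                        const std::vector<float>& qk, std::vector<int>& kv_indices)
    {
        if (tlength < memory_max_len) {            // cache not full yet
            kv_indices[tlength] = tlength;
            return tlength;
        }
        const int fifo_size = memory_max_len - important_size;
        if (important_size == 0) {                 // plain circular cache
            const int slot = tlength % memory_max_len;
            kv_indices[slot] = tlength;
            return slot;
        }
        // Least important slot inside the important region.
        int min_idx = 0;
        for (int i = 1; i < important_size; ++i) {
            if (qk[i] < qk[min_idx]) {
                min_idx = i;
            }
        }
        if (fifo_size == 0) {                      // pure important cache
            kv_indices[min_idx] = tlength;
            return min_idx;
        }
        // Hybrid: the new token enters the FIFO region; the token aged out
        // of the FIFO window replaces the least important entry.
        const int fifo_slot = (tlength - important_size) % fifo_size + important_size;
        kv_indices[min_idx] = tlength - fifo_size;
        kv_indices[fifo_slot] = tlength;
        return fifo_slot;
    }

In the kernel the same decision is made alongside the qk_max reduction
(qk_min / qk_min_idx), and the evicted FIFO entry's K/V (fifo_out_k /
fifo_out_v) are copied into the chosen important slot rather than being
recomputed.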
--- .../decoder_masked_multihead_attention.h | 4 + ...er_masked_multihead_attention_template.hpp | 154 ++++++++++++++++-- src/fastertransformer/kernels/gpt_kernels.cu | 20 +++ src/fastertransformer/kernels/gpt_kernels.h | 5 + .../DecoderSelfAttentionLayer.cc | 15 ++ .../DecoderSelfAttentionLayer.h | 2 + .../models/multi_gpu_gpt/ParallelGpt.cc | 56 ++++++- .../models/multi_gpu_gpt/ParallelGpt.h | 3 + .../multi_gpu_gpt/ParallelGptDecoder.cc | 45 ++++- .../tf_op/decoder/FusedSelfAttentionOp.cc | 2 + 10 files changed, 290 insertions(+), 16 deletions(-) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h index 0d3e67546..2385a572b 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h @@ -93,6 +93,10 @@ struct Multihead_attention_params_base { bool neox_rotary_style = false; // The maximum length of input sentences. int max_input_length = 0; + // The number of oldest cache element to pick the least important to replace when cache is full. If 0, it will fall back to circular cache. + int important_kv_cache_size = 0; + // The buffer to store the indices of the keys and values in the cache. The shape is [B, H, memory_max_len]. + int* kv_indices = nullptr; // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? int timestep = 0; // The current timestep of each sentences (support different timestep for different sentences) diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp index 038f35ba4..d50ec13e6 100644 --- a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -1159,8 +1159,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params(smem_); - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + // The shared memory buffers for the block-wide reductions. One for max, one for sum, one for min. + __shared__ float red_smem[WARPS_PER_BLOCK * 3]; + // The shared memory buffers for the qk_min index block-wide reduction. + __shared__ int red_int_smem[WARPS_PER_BLOCK]; // A vector of Q or K elements for the current timestep. using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision @@ -1213,6 +1215,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params= params.memory_max_len && params.important_kv_cache_size > 0 ? + params.memory_max_len - params.important_kv_cache_size : + params.memory_max_len; + + // Make sure the following params have correct values in all 5 cases: + // 1. full cache: tlength < params.memory_max_len + // 1.1 params.important_kv_cache_size == 0 + // 1.2 params.important_kv_cache_size > 0 + // 2. fifo cache: tlength >= params.memory_max_len, params.important_kv_cache_size == 0 + // 3. important cache: tlength >= params.memory_max_len, params.important_kv_cache_size == params.memory_max_len + // 4. 
hybrid cache: tlength >= params.memory_max_len, 0 < params.important_kv_cache_size < params.memory_max_len + + // const int tlength_circ_offset = tlength > params.memory_max_len ? params.important_kv_cache_size : 0; + // const int tlength_circ = + // fifo_cache_size > 0 ? tlength % fifo_cache_size + tlength_circ_offset : tlength % params.memory_max_len; + + const int tlength_circ_with_important_cache = + fifo_cache_size > 0 && tlength >= params.memory_max_len ? + (tlength - params.important_kv_cache_size) % fifo_cache_size + params.important_kv_cache_size: + tlength % params.memory_max_len; + + // tlength_circ is the index relative to the beginning of the k/v cache. + const int tlength_circ = + tlength >= params.memory_max_len ? tlength_circ_with_important_cache : tlength % params.memory_max_len; + const bool do_important_cache = handle_kv && params.important_kv_cache_size > 0 && tlength >= params.memory_max_len; + int* kv_index_buffer = params.kv_indices == nullptr ? nullptr : params.kv_indices + bhi * params.memory_max_len; + + assert(!(do_important_cache && params.kv_indices == nullptr)); // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. const bool is_masked = tidx >= QK_VECS_PER_WARP; @@ -1240,6 +1272,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params(&q_smem[tidx * QK_VEC_SIZE]) = q; @@ -1413,9 +1454,17 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params(¶ms.k_cache[offset]) = vec_conversion(k); + if (!do_important_cache) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } + } + else if (fifo_cache_size > 0) { + fifo_out_k = + (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) ? + vec_conversion(*reinterpret_cast(¶ms.k_cache[offset])) : + fifo_out_k; } } @@ -1584,11 +1633,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params(params.linear_bias_slopes[hi], dist); } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + if (!is_mask && do_important_cache && ti_circ < params.important_kv_cache_size) { + qk_min_idx = qk_min > qk ? ti_circ : qk_min_idx; + qk_min = qk_min > qk ? qk : qk_min; + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = qk; } } @@ -1600,6 +1654,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + if (do_important_cache) { + float temp_qk_min = __shfl_xor_sync(uint32_t(-1), qk_min, mask); + int temp_qk_min_idx = __shfl_xor_sync(uint32_t(-1), qk_min_idx, mask); + qk_min_idx = qk_min > temp_qk_min ? temp_qk_min_idx : qk_min_idx; + qk_min = fminf(qk_min, temp_qk_min); + } } // Decompose the thread index into warp and lane. @@ -1609,6 +1669,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params= 1; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + if (do_important_cache) { + float temp_qk_min = __shfl_xor_sync(uint32_t(-1), qk_min, mask); + int temp_qk_min_idx = __shfl_xor_sync(uint32_t(-1), qk_min_idx, mask); + qk_min_idx = qk_min > temp_qk_min ? temp_qk_min_idx : qk_min_idx; + qk_min = fminf(qk_min, temp_qk_min); + } } // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + if (do_important_cache) { + qk_min_idx = __shfl_sync(uint32_t(-1), qk_min_idx, 0); + } // Compute the logits and start the sum. 
float sum = 0.f; @@ -1686,6 +1763,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params 0) { + // Make sure the same group of threads handles the cache storage as later in the process. + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + fifo_out_v = vec_conversion(*reinterpret_cast(&v_cache[tlength_circ * Dh])); + } + } + // One group of threads computes the product(s) for the current timestep. V_vec_k v_bias; zero(v_bias); @@ -1815,6 +1901,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + if (!do_important_cache || fifo_cache_size > 0) { + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + if (kv_index_buffer != nullptr && vi == 0) { + kv_index_buffer[tlength_circ] = tlength; + } + } + if (do_important_cache) { + assert(qk_min_idx < params.important_kv_cache_size); + assert(qk_min_idx >= 0); + // TO CHECK: If the whole cache has is_mask, qk_min_idx will be -1, and it will be writing to invalid cache space. + if (fifo_cache_size > 0) { + *reinterpret_cast(&v_cache[qk_min_idx * Dh]) = vec_conversion(fifo_out_v); + } else { + *reinterpret_cast(&v_cache[qk_min_idx * Dh]) = vec_conversion(v); + } + if (vi == 0) { + kv_index_buffer[qk_min_idx] = tlength - fifo_cache_size; + } + } } // Initialize the output value with the current timestep. @@ -1881,6 +1987,32 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params= 0); + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int base_offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B; + int fifo_offset = base_offset + tlength_circ * QK_ELTS_IN_16B + ci; + int important_offset = base_offset + qk_min_idx * QK_ELTS_IN_16B + ci; + + if (fifo_cache_size > 0) { + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[fifo_offset]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[important_offset]) = vec_conversion(fifo_out_k); + } + } else { + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[important_offset]) = vec_conversion(k); + } + } + } + // Make sure we can start writing to shared memory. 
__syncthreads(); diff --git a/src/fastertransformer/kernels/gpt_kernels.cu b/src/fastertransformer/kernels/gpt_kernels.cu index 7dc9af620..5f5880aa1 100644 --- a/src/fastertransformer/kernels/gpt_kernels.cu +++ b/src/fastertransformer/kernels/gpt_kernels.cu @@ -27,6 +27,26 @@ namespace fastertransformer { +__global__ void initiate_indices(int* kv_indices, const int loop_size, const size_t total_length) +{ + int idx = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (int i = idx; i < total_length; i += stride) { + kv_indices[i] = i % loop_size; + } +} + +void invokeInitiateIndices(int* kv_indices, + const int loop_size, + const size_t total_length, + cudaStream_t stream) +{ + const int block_size = 256; + const int grid_size = min(65535, static_cast((total_length + block_size - 1) / block_size)); + + initiate_indices<<>>(kv_indices, loop_size, total_length); +} + // PROMPT_SRC: 0 --> no prompts, 1 --> from loaded prompts, 2 --> from request prompts template __global__ void start_id_embedding_position_lookups_kernel(T* from_tensor, diff --git a/src/fastertransformer/kernels/gpt_kernels.h b/src/fastertransformer/kernels/gpt_kernels.h index d78224e0a..84e651647 100644 --- a/src/fastertransformer/kernels/gpt_kernels.h +++ b/src/fastertransformer/kernels/gpt_kernels.h @@ -25,6 +25,11 @@ namespace fastertransformer { +void invokeInitiateIndices(int* kv_indices, + const int loop_size, + const size_t total_length, + cudaStream_t stream); + template struct inputIdsEmbeddingLookupPosEncodingSoftPromptParam { T* from_tensor; diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc index 622105a3f..4d8f90fd4 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc @@ -63,6 +63,8 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, + const int important_kv_cache_size, + int* kv_indices, const float* qkv_scale_out, const float* attention_out_scale, const int int8_mode, @@ -131,6 +133,8 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.linear_bias_slopes = reinterpret_cast(linear_bias_slopes); } params.max_input_length = max_input_len; + params.important_kv_cache_size = important_kv_cache_size; + params.kv_indices = kv_indices; params.ia3_tasks = ia3_tasks; params.ia3_key_weights = reinterpret_cast(ia3_key_weights); @@ -178,6 +182,8 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int* ia3_tasks, \ const T* ia3_key_weights, \ const T* ia3_value_weights, \ + const int important_kv_cache_size, \ + int* kv_indices, \ const float* qkv_scale_out, \ const float* attention_out_scale, \ const int int8_mode, \ @@ -477,11 +483,13 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens // relative_attention_bias [1, head_num, step, step] or [1, head_num, max_seq_len, max_seq_len] (optional) // linear_bias_slopes [head_num] (optional) // ia3_tasks [batch_size] (optional) + // important_kv_cache_size [1] on cpu (optional) // output tensors: // attention_output [batch_size, d_model_], // key_cache [batch, local_head_num, size_per_head // x, memory_max_len, x] // value_cache [batch, local_head_num, memory_max_len, size_per_head] + // kv_indices [batch, local_head_num, memory_max_len] (optional) FT_LOG_DEBUG(__PRETTY_FUNCTION__); 
FT_CHECK(output_tensors->at("key_cache").shape.size() == 5 || output_tensors->at("key_cache").shape.size() == 3); @@ -499,14 +507,19 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens input_tensors->isExist("relative_attention_bias") ? input_tensors->at("relative_attention_bias").shape[3] : 0; const T* linear_bias_slopes = input_tensors->getPtr("linear_bias_slopes", nullptr); const bool has_ia3 = input_tensors->isExist("ia3_tasks"); + const int important_kv_cache_size = input_tensors->getVal("important_kv_cache_size", 0); T* attention_out = output_tensors->getPtr("hidden_features"); T* key_cache = output_tensors->getPtr("key_cache"); T* value_cache = output_tensors->getPtr("value_cache"); + int* kv_indices = output_tensors->getPtr("kv_indices", nullptr); const int batch_size = input_tensors->at("input_query").shape[0]; const int beam_width = cache_indir != nullptr ? input_tensors->at("cache_indirection").shape[1] : 1; const int memory_max_len = output_tensors->at("key_cache").shape[3]; + if (kv_indices != nullptr) { + FT_CHECK(output_tensors->at("kv_indices").shape[2] == memory_max_len); + } const int session_len = masked_tokens != nullptr ? input_tensors->at("masked_tokens").shape[1] : 0; const int* d_prefix_prompt_lengths = input_tensors->getPtr("d_prefix_prompt_lengths", nullptr); @@ -613,6 +626,8 @@ void DecoderSelfAttentionLayer::forward(TensorMap* output_tens input_tensors->getPtr("ia3_tasks", nullptr), has_ia3 ? attention_weights->ia3_key_weight.kernel : nullptr, has_ia3 ? attention_weights->ia3_value_weight.kernel : nullptr, + important_kv_cache_size, + kv_indices, int8_mode_ == 2 ? attention_weights->query_weight.scale_out : nullptr, int8_mode_ == 2 ? attention_weights->attention_output_weight.scale : nullptr, int8_mode_, diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h index 1b1644e64..92ef5d432 100644 --- a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h +++ b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.h @@ -180,6 +180,8 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int* ia3_tasks, const T* ia3_key_weights, const T* ia3_value_weights, + const int important_kv_cache_size, + int* kv_indices, const float* qkv_scale_out, const float* attention_out_scale, const int int8_mode, diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc index 2423df5a2..304599a62 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.cc @@ -98,6 +98,7 @@ void ParallelGpt::allocateBuffer(size_t batch_size, size_t max_session_len, size_t memory_len, size_t max_input_len, + size_t important_kv_cache_size, bool is_return_context_cum_log_probs) { FT_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -138,6 +139,11 @@ void ParallelGpt::allocateBuffer(size_t batch_size, (int*)(allocator_->reMalloc(cache_indirections_[0], sizeof(int) * batchxbeam * memory_len * 2, true)); cache_indirections_[1] = cache_indirections_[0] + batchxbeam * memory_len; } + if (important_kv_cache_size > 0) { + const size_t self_kv_index_size = + (num_layer_ / pipeline_para_.world_size_) * batchxbeam * local_head_num_ * memory_len; + kv_indices_ = (int*)(allocator_->reMalloc(kv_indices_, sizeof(int) * self_kv_index_size, true)); + } tiled_input_ids_buf_ = (int*)(allocator_->reMalloc(tiled_input_ids_buf_, 
sizeof(int) * batchxbeam * max_session_len, true)); @@ -224,6 +230,9 @@ void ParallelGpt::freeBuffer() if (cache_indirections_[0] != nullptr) { allocator_->free((void**)(&cache_indirections_)[0]); } + if (kv_indices_ != nullptr) { + allocator_->free((void**)(&kv_indices_)); + } allocator_->free((void**)(&tiled_input_ids_buf_)); allocator_->free((void**)(&tiled_input_lengths_buf_)); @@ -607,6 +616,7 @@ void ParallelGpt::forward(std::unordered_map* outp // top_p_decay [batch_size] on gpu, float, optional // top_p_min [batch_size] on gpu, float, optional // top_p_reset_ids [batch_size] on gpu, uint32, optional + // important_kv_cache_size [1] on cpu, uint32, optional // output_tensors: // output_ids [batch_size, beam_width, max_output_seq_len] @@ -730,6 +740,7 @@ void ParallelGpt::forward(std::unordered_map* outp const size_t gen_len = input_tensors->at("output_seq_len").max() + limit_len_offset; size_t session_len = 0; + if (continue_gen) { session_len = session_len_; // Record the size of allocated buffer in previous round. } @@ -741,8 +752,10 @@ void ParallelGpt::forward(std::unordered_map* outp } session_len_ = session_len; FT_CHECK_WITH_INFO( - gen_len + initial_step <= session_len, - fmtstr("Session size too low (%d) vs. total output size (%d)", session_len, gen_len + initial_step)); + gen_len <= session_len, + fmtstr("Session size too low (%d) vs. total output size (%d).", + session_len, gen_len) + ); size_t memory_len = 0; if (continue_gen) { memory_len = memory_len_; // Record the size of allocated buffer in previous round. @@ -755,9 +768,24 @@ void ParallelGpt::forward(std::unordered_map* outp } memory_len_ = memory_len; /* TODO: could remove this constraint by changing how context decoder operates */ - FT_CHECK_WITH_INFO(max_input_length <= memory_len, + FT_CHECK_WITH_INFO(continue_gen || max_input_length <= memory_len, fmtstr("Memory size too low (%d) vs. input length (%d)", memory_len, max_input_length)); + size_t important_kv_cache_size = 0; + if (continue_gen) { + important_kv_cache_size = important_kv_cache_size_; + } else if (input_tensors->find("important_kv_cache_size") != input_tensors->end()) { + important_kv_cache_size = input_tensors->at("important_kv_cache_size").getVal(); + } + if (important_kv_cache_size > memory_len) { + FT_LOG_WARNING("important_kv_cache_size (%d) is less than memory_len (%d). " + "Setting important_kv_cache_size to memory_len.", + important_kv_cache_size, + memory_len); + important_kv_cache_size = memory_len; + } + important_kv_cache_size_ = important_kv_cache_size; + if (memory_len < session_len) { FT_LOG_WARNING("memory_len (%d) is less than session_len (%d). 
" "Note that this reduces the memory cost of k/v cache, but may hurt the accuracy.", @@ -796,6 +824,7 @@ void ParallelGpt::forward(std::unordered_map* outp session_len, memory_len, max_input_length + max_prefix_soft_prompt_length, + important_kv_cache_size, is_return_context_cum_log_probs); sync_check_cuda_error(); } @@ -830,6 +859,14 @@ void ParallelGpt::forward(std::unordered_map* outp POP_RANGE; } + if (important_kv_cache_size > 0) { + PUSH_RANGE("initiate indices"); + const size_t indices_length = (num_layer_ / pipeline_para_.world_size_) * batch_size * beam_width + * local_head_num_ * memory_len; + invokeInitiateIndices(kv_indices_, memory_len, indices_length, stream_); + POP_RANGE; + } + if (continue_gen) { PUSH_RANGE("input tiling and init"); invokeTileGptInputs(tiled_input_ids_buf_, @@ -1317,7 +1354,8 @@ void ParallelGpt::forward(std::unordered_map* outp Tensor(MEMORY_GPU, TYPE_BOOL, {local_batch_size * beam_width, session_len}, - tiled_masked_tokens_ + id_offset * session_len)}}); + tiled_masked_tokens_ + id_offset * session_len)}, + {"important_kv_cache_size", Tensor(MEMORY_CPU, TYPE_UINT32, {1}, &important_kv_cache_size)}}); if (beam_width > 1) { decoder_input_tensors.insert({"cache_indirection", Tensor(MEMORY_GPU, @@ -1342,6 +1380,16 @@ void ParallelGpt::forward(std::unordered_map* outp decoder_output_buf_ + hidden_units_offset)}, {"key_cache", Tensor(MEMORY_GPU, data_type, self_k_cache_shape, key_cache_)}, {"value_cache", Tensor(MEMORY_GPU, data_type, self_v_cache_shape, value_cache_)}}); + if (important_kv_cache_size > 0) { + decoder_output_tensors.insert({"kv_indices", + Tensor(MEMORY_GPU, + TYPE_UINT32, + {num_layer_ / pipeline_para_.world_size_, + local_batch_size * beam_width, + local_head_num_, + memory_len}, + kv_indices_)}); + } gpt_decoder_->forward( &decoder_output_tensors, &decoder_input_tensors, &gpt_weights->decoder_layer_weights); diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h index ea24de2d3..4d6839d5e 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGpt.h @@ -92,6 +92,7 @@ class ParallelGpt: public BaseLayer { size_t max_seq_len, size_t memory_len, size_t max_input_len, + size_t important_kv_cache_size, bool is_return_context_cum_log_probs); void freeBuffer() override; @@ -111,6 +112,7 @@ class ParallelGpt: public BaseLayer { int step_; size_t session_len_; size_t memory_len_; + size_t important_kv_cache_size_; int* tiled_total_padding_count_ = nullptr; T* padded_embedding_kernel_; @@ -139,6 +141,7 @@ class ParallelGpt: public BaseLayer { T* key_cache_; T* value_cache_; int* cache_indirections_[2] = {nullptr, nullptr}; + int* kv_indices_ = nullptr; int* start_ids_buf_; int* end_ids_buf_; diff --git a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc index 0fe65c94d..bdbcdc1eb 100644 --- a/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc +++ b/src/fastertransformer/models/multi_gpu_gpt/ParallelGptDecoder.cc @@ -271,11 +271,13 @@ void ParallelGptDecoder::forward(std::unordered_map* // is real local_batch_size. (optional.) 
// masked_tokens [local_batch_size, session_len] // linear_bias_slopes [head_num], optional + // important_kv_cache_size [1] on cpu (optional) // output tensors: // decoder_output [local_batch_size, hidden_dimension], // key_cache [num_layer, batch_size, head_num, size_per_head // x, memory_len, x] // value_cache [num_layer, batch_size, head_num, memory_len, size_per_head] + // kv_indices [num_layer, batch_size, head_num, memory_len] (optional) FT_LOG_DEBUG(__PRETTY_FUNCTION__); @@ -298,9 +300,17 @@ void ParallelGptDecoder::forward(std::unordered_map* const int ite = input_tensors->at("ite").getVal(); + const auto important_kv_cache_size_tensor = input_tensors->find("important_kv_cache_size"); + const size_t important_kv_cache_size = + important_kv_cache_size_tensor != input_tensors->end() ? + important_kv_cache_size_tensor->second.getVal() : + 0; + Tensor k_cache = output_tensors->at("key_cache"); Tensor v_cache = output_tensors->at("value_cache"); + const auto kv_indices_tensor = output_tensors->find("kv_indices"); + // The resize of the key cache buffer by // (local_batch_size, local_head_num, size_per_head // x, max_seq_len, x) where x is constant. std::vector self_k_cache_size(k_cache.shape.begin() + 2, k_cache.shape.end()); @@ -310,6 +320,16 @@ void ParallelGptDecoder::forward(std::unordered_map* std::vector self_v_cache_size(v_cache.shape.begin() + 2, v_cache.shape.end()); self_v_cache_size.insert(self_v_cache_size.begin(), local_batch_size); + // The resize of the kv_indices buffer by (local_batch_size, local_head_num, important_kv_cache_size). + std::vector self_kv_indices_size; + if (kv_indices_tensor != output_tensors->end()){ + self_kv_indices_size.insert( + self_kv_indices_size.begin(), + kv_indices_tensor->second.shape.begin() + 2, + kv_indices_tensor->second.shape.end()); + self_kv_indices_size.insert(self_kv_indices_size.begin(), local_batch_size); + } + const auto activation_in_type = int8_mode_ == 2 ? 
TYPE_INT8 : data_type; const auto activation_out_type = data_type; @@ -368,7 +388,9 @@ void ParallelGptDecoder::forward(std::unordered_map* {"total_padding_tokens", input_tensors->at("total_padding_tokens")}, {"max_input_length", input_tensors->at("max_input_length")}, {"step", input_tensors->at("step")}, - {"masked_tokens", input_tensors->at("masked_tokens")}}; + {"masked_tokens", input_tensors->at("masked_tokens")}, + {"important_kv_cache_size", input_tensors->at("important_kv_cache_size")}, + }; if (input_tensors->count("cache_indirection")) { self_attention_input_tensors.insert("cache_indirection", input_tensors->at("cache_indirection")); } @@ -392,6 +414,27 @@ void ParallelGptDecoder::forward(std::unordered_map* {"key_cache", Tensor(MEMORY_GPU, data_type, self_k_cache_size, k_cache.getPtrWithOffset(cache_offset))}, {"value_cache", Tensor(MEMORY_GPU, data_type, self_v_cache_size, v_cache.getPtrWithOffset(cache_offset))}}; + if (kv_indices_tensor != output_tensors->end()){ + size_t kv_indices_offset = l - getFirstLayerParallelId(); + for (auto t = kv_indices_tensor->second.shape.begin() + 1; + t != kv_indices_tensor->second.shape.end(); + ++t) { + kv_indices_offset *= *t; + }; + size_t ite_indices_offset = ite * local_batch_size; + for (auto t = kv_indices_tensor->second.shape.begin() + 2; + t != kv_indices_tensor->second.shape.end(); + ++t) { + ite_indices_offset *= *t; + } + kv_indices_offset += ite_indices_offset; + self_attention_output_tensors.insert( + "kv_indices", + Tensor(MEMORY_GPU, + TYPE_INT32, + self_kv_indices_size, + kv_indices_tensor->second.getPtrWithOffset(kv_indices_offset))); + } self_attention_layer_->forward( &self_attention_output_tensors, &self_attention_input_tensors, &layer_weight->self_attention_weights); diff --git a/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc b/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc index 90f0ae257..97167bc45 100644 --- a/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc +++ b/src/fastertransformer/tf_op/decoder/FusedSelfAttentionOp.cc @@ -172,6 +172,8 @@ class FusedQkvMultiHeadAttentionOp: public BaseOp { (const int*)nullptr, // ia3 tasks (const DataType_*)nullptr, // ia3 key weights (const DataType_*)nullptr, // ia3 value weights + 0, // important_kv_cache_size + (int*)nullptr, // kv_indices (const float*)nullptr, // int8 scale in (const float*)nullptr, // int8 scale out 0, // int8 mode