diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index c0e689735d4..c145542dab5 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -56,6 +56,7 @@ from fastdeploy.spec_decode import SpecMethod from fastdeploy.utils import print_gpu_memory_use from fastdeploy.worker.input_batch import InputBatch, reorder_split_prefill_and_decode +from fastdeploy.worker.tbo import GLOBAL_ATTN_BUFFERS if current_platform.is_iluvatar(): from fastdeploy.model_executor.ops.iluvatar import ( @@ -1530,7 +1531,7 @@ def _initialize_attn_backend(self) -> None: if envs.FD_DETERMINISTIC_MODE: decoder_block_shape_q = envs.FD_DETERMINISTIC_SPLIT_KV_SIZE - res_buffer = allocate_launch_related_buffer( + buffer_kwargs = dict( max_batch_size=self.scheduler_config.max_num_seqs, max_model_len=self.model_config.max_model_len, encoder_block_shape_q=encoder_block_shape_q, @@ -1540,8 +1541,13 @@ def _initialize_attn_backend(self) -> None: kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, ) + res_buffer = allocate_launch_related_buffer(**buffer_kwargs) self.share_inputs.update(res_buffer) + if int(os.getenv("USE_TBO", "0")) == 1: + for j in range(2): + GLOBAL_ATTN_BUFFERS[j] = allocate_launch_related_buffer(**buffer_kwargs) + # Get the attention backend attn_cls = get_attention_backend() attn_backend = attn_cls( diff --git a/fastdeploy/worker/tbo.py b/fastdeploy/worker/tbo.py index bcfec83353f..9854b009367 100644 --- a/fastdeploy/worker/tbo.py +++ b/fastdeploy/worker/tbo.py @@ -114,8 +114,6 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config): end_bs += 1 if len(forward_meta.rotary_embs.shape) == 6: - max_bs = forward_meta.rotary_embs.shape[0] - assert max_bs == forward_meta.block_tables.shape[0] assert forward_meta.rotary_embs.shape[1:3] == [2, 1] assert forward_meta.rotary_embs.shape[4] == 1 res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]