From 6f8197a88c87cfb1665e5080b011e792d7cb0880 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 18 Mar 2026 04:29:19 +0000 Subject: [PATCH 1/6] fix only audio_server --- lightllm/server/audioserver/manager.py | 1 + lightllm/server/audioserver/model_infer/model_rpc.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index bb0a745302..ab0f76c63f 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -66,6 +66,7 @@ async def wait_to_model_ready(self): "rank_id": rank_id, "cache_port": self.cache_port, "data_type": self.args.data_type, + "init_shm_data": self.args.disable_vision, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index ad18ffd7f5..7c176ba090 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -43,7 +43,7 @@ def exposed_init_model(self, kvargs): set_current_device_id(torch.cuda.current_device()) - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=kvargs["init_shm_data"]) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) From 1992ef37fdacd3acedda3d1dd044557a8574941c Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 18 Mar 2026 04:30:35 +0000 Subject: [PATCH 2/6] 5090 not support tma --- lightllm/utils/device_utils.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index a1ed6ed950..bd3f1b1ed5 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -88,6 +88,16 @@ def is_musa(): return hasattr(torch.version, "musa") and torch.version.musa is not None +@lru_cache(maxsize=1) +def is_nvidia(): + return ( + torch.cuda.is_available() + and getattr(torch.version, "cuda", None) is not None + and getattr(torch.version, "hip", None) is None + and not is_musa() + ) + + @lru_cache(maxsize=None) def get_current_device_name(): if torch.cuda.is_available() or is_musa(): @@ -262,18 +272,23 @@ def set_sm_limit(percent: int, gpu_index=0): @lru_cache(maxsize=None) def triton_support_tensor_descriptor() -> bool: + if not is_nvidia(): + logger.info("triton tensor_descriptor requires NVIDIA CUDA GPU") + return False + try: from triton.tools.tensor_descriptor import TensorDescriptor + _ = TensorDescriptor - support_tma = torch.cuda.get_device_capability() >= (9, 0) + support_tma = torch.cuda.get_device_capability() >= (9, 0) and not is_5090_gpu() if support_tma: logger.info("triton support tensor_descriptor") return True - else: - assert False - except: - logger.info("triton not support tensor_descriptor") - return False + except Exception: + pass + + logger.info("triton not support tensor_descriptor") + return False @lru_cache(maxsize=None) From 951fcb4f84bd795861cf0b73c369fe57c3e0254a Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 18 Mar 2026 04:37:33 +0000 Subject: [PATCH 3/6] format --- lightllm/server/audioserver/model_infer/model_rpc.py | 4 +++- lightllm/utils/device_utils.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index 7c176ba090..5364ba8458 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -43,7 +43,9 @@ def exposed_init_model(self, kvargs): set_current_device_id(torch.cuda.current_device()) - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=kvargs["init_shm_data"]) + self.cpu_embed_cache_client = CpuEmbedCacheClient( + create_meta_data=False, init_shm_data=kvargs["init_shm_data"] + ) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index bd3f1b1ed5..4585e383e5 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -278,6 +278,7 @@ def triton_support_tensor_descriptor() -> bool: try: from triton.tools.tensor_descriptor import TensorDescriptor + _ = TensorDescriptor support_tma = torch.cuda.get_device_capability() >= (9, 0) and not is_5090_gpu() From 0edadd96f65f2a70e548b6889770736b745827b2 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Wed, 18 Mar 2026 05:51:04 +0000 Subject: [PATCH 4/6] add support_tma() --- lightllm/utils/device_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 4585e383e5..d1c9915aad 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -270,10 +270,18 @@ def set_sm_limit(percent: int, gpu_index=0): return True +@lru_cache(maxsize=None) +def support_tma() -> bool: + return is_nvidia() and torch.cuda.get_device_capability() >= (9, 0) and not is_5090_gpu() + + @lru_cache(maxsize=None) def triton_support_tensor_descriptor() -> bool: - if not is_nvidia(): - logger.info("triton tensor_descriptor requires NVIDIA CUDA GPU") + if not support_tma(): + logger.info( + "triton tensor_descriptor requires NVIDIA Hopper or newer GPU (compute capability >= 9.0) " + "and is not supported on 5090" + ) return False try: @@ -281,10 +289,8 @@ def triton_support_tensor_descriptor() -> bool: _ = TensorDescriptor - support_tma = torch.cuda.get_device_capability() >= (9, 0) and not is_5090_gpu() - if support_tma: - logger.info("triton support tensor_descriptor") - return True + logger.info("triton support tensor_descriptor") + return True except Exception: pass From b15284e3981c90a98bad7500a4002e4c3cc7a6ac Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 18 Mar 2026 08:25:12 +0000 Subject: [PATCH 5/6] fix --- .../audioserver/model_infer/model_rpc.py | 3 ++- .../server/embed_cache/embed_cache_client.py | 24 ++++--------------- .../embed_cache/impl/naive_memory_cache.py | 2 +- .../visualserver/model_infer/model_rpc.py | 2 +- 4 files changed, 9 insertions(+), 22 deletions(-) diff --git a/lightllm/server/audioserver/model_infer/model_rpc.py b/lightllm/server/audioserver/model_infer/model_rpc.py index 5364ba8458..a8a2c39c3e 100644 --- a/lightllm/server/audioserver/model_infer/model_rpc.py +++ b/lightllm/server/audioserver/model_infer/model_rpc.py @@ -44,7 +44,8 @@ def exposed_init_model(self, kvargs): set_current_device_id(torch.cuda.current_device()) self.cpu_embed_cache_client = CpuEmbedCacheClient( - create_meta_data=False, init_shm_data=kvargs["init_shm_data"] + create_meta_data=False, + init_shm_data=False, ) except Exception as e: print("#" * 16) diff --git a/lightllm/server/embed_cache/embed_cache_client.py b/lightllm/server/embed_cache/embed_cache_client.py index 8c5b7f71ee..6fcc2d3783 100644 --- a/lightllm/server/embed_cache/embed_cache_client.py +++ b/lightllm/server/embed_cache/embed_cache_client.py @@ -24,11 +24,11 @@ def __init__(self, create_meta_data: bool, init_shm_data: bool): if create_meta_data: self.token_index_manager = MemoryManager(total_size=self.token_num) + + if init_shm_data: + self._create_shm_embed_kv_cache() else: - if init_shm_data: - self._create_shm_embed_kv_cache() - else: - self._attach_shm_cpu_embed_cache() + self._attach_shm_cpu_embed_cache() return def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]: @@ -64,21 +64,7 @@ def _create_shm_embed_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( key=self.args.multi_modal_cache_shm_id, size=self.embed_cache_tensor_meta.calcu_size() ) - handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.embed_cache_tensor_meta.calcu_size()) - handle.wait() - numpy_array = np.frombuffer( - memoryview((ctypes.c_uint8 * self.embed_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), - dtype=np.uint8, - ) - # 将 NumPy 数组转换为 PyTorch 张量 - shape = ( - self.embed_cache_tensor_meta.token_num, - self.embed_cache_tensor_meta.layer_num, - self.embed_cache_tensor_meta.hidden_size, - ) - self.cpu_embed_cache_tensor = ( - torch.from_numpy(numpy_array).view(dtype=self.embed_cache_tensor_meta.data_type).view(shape) - ) + logger.info(f"create embed cache shm ptr: {shm_ptr}, size: {self.embed_cache_tensor_meta.calcu_size()}") return def _attach_shm_cpu_embed_cache(self): diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 61ba46d7c6..fbce108762 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -45,7 +45,7 @@ def __init__(self, args) -> None: self.token_id_range_start = 0 self.token_id_range_end = 0 self.use_config_server = self.args.config_server_host and self.args.config_server_port - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=False) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=True) def _check_and_set_new_id_range(self, alloced_token_num): need_update_range = self.token_id_range_start + alloced_token_num >= self.token_id_range_end diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..6355ac2dbf 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -95,7 +95,7 @@ def exposed_init_model(self, kvargs): self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) From 3efc0374aacf5575d09ad3d73d491407750e4773 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 18 Mar 2026 08:28:28 +0000 Subject: [PATCH 6/6] fix --- lightllm/server/audioserver/manager.py | 1 - lightllm/utils/device_utils.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/server/audioserver/manager.py b/lightllm/server/audioserver/manager.py index ab0f76c63f..bb0a745302 100644 --- a/lightllm/server/audioserver/manager.py +++ b/lightllm/server/audioserver/manager.py @@ -66,7 +66,6 @@ async def wait_to_model_ready(self): "rank_id": rank_id, "cache_port": self.cache_port, "data_type": self.args.data_type, - "init_shm_data": self.args.disable_vision, } init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index d1c9915aad..43b10ec88b 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -90,12 +90,14 @@ def is_musa(): @lru_cache(maxsize=1) def is_nvidia(): - return ( + ans = ( torch.cuda.is_available() and getattr(torch.version, "cuda", None) is not None and getattr(torch.version, "hip", None) is None and not is_musa() ) + logger.info(f"device is_nvidia : {ans}") + return ans @lru_cache(maxsize=None) @@ -272,6 +274,7 @@ def set_sm_limit(percent: int, gpu_index=0): @lru_cache(maxsize=None) def support_tma() -> bool: + # 5090 关闭 tma feature,实际测试开了没啥用 return is_nvidia() and torch.cuda.get_device_capability() >= (9, 0) and not is_5090_gpu()