diff --git a/sglang-LenVM/python/sglang/srt/layers/sampler.py b/sglang-LenVM/python/sglang/srt/layers/sampler.py index ccd4bfd..4677266 100644 --- a/sglang-LenVM/python/sglang/srt/layers/sampler.py +++ b/sglang-LenVM/python/sglang/srt/layers/sampler.py @@ -17,6 +17,7 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda, is_npu from sglang.srt.lvm.lvm_guided_sampling import LvmGuidedSampler +from sglang.srt.lvm.timing import get_timer if is_cuda(): from sgl_kernel import ( @@ -45,6 +46,7 @@ def __init__(self, model_runner=None): self.lvm_guided_sampler = LvmGuidedSampler.from_server_args( get_global_server_args(), model_runner=model_runner ) + self._timer = get_timer() if is_dp_attention_enabled(): self.tp_sync_group = get_attention_tp_group().device_group @@ -134,16 +136,27 @@ def forward( positions: The positions of the tokens in the sequence. Used for deterministic sampling to get the unique seed for each position. 
""" + timer = self._timer + t_forward = timer.section_start("t_sampler_total_ms") + guided_applied = False + batch_size_meta = int(logits_output.next_token_logits.shape[0]) + logits = logits_output.next_token_logits # Preprocess logits (custom processors and NaN handling) + t_preprocess = timer.section_start("t_preprocess_logits_ms") logits = self._preprocess_logits(logits, sampling_info) + timer.section_end("t_preprocess_logits_ms", t_preprocess) if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling + t_sample = timer.section_start("t_sample_ms") batch_next_token_ids = torch.argmax(logits, -1) + timer.section_end("t_sample_ms", t_sample) if return_logprob: + t_logprob = timer.section_start("t_logprob_ms") logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + timer.section_end("t_logprob_ms", t_logprob) else: can_sample_directly_from_probs = ( not sampling_info.need_top_p_sampling @@ -164,6 +177,7 @@ def forward( ) # Post process logits + t_pre_lvm = timer.section_start("t_pre_lvm_ms") logits.div_(sampling_info.temperatures) # Per-token temperature scaling: for boosted tokens, divide their @@ -178,34 +192,48 @@ def forward( return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB ): logits[:] = torch.softmax(logits, dim=-1) + timer.section_end("t_pre_lvm_ms", t_pre_lvm) probs = logits del logits - guided_applied = False + guided_sample_result = None if self.lvm_guided_sampler is not None and sampling_info.reqs is not None: - guided_probs = self.lvm_guided_sampler.apply( + t_lvm = timer.section_start("t_lvm_apply_outer_ms") + guided_sample_result = self.lvm_guided_sampler.sample_token_ids( probs, sampling_info.reqs, sampling_info.temperatures, sampling_info.top_ps, sampling_info.top_ks, sampling_info.min_ps, + sampling_info.sampling_seed, + positions, ) - if guided_probs is not None: - probs = guided_probs - guided_applied = True - - if guided_applied: - can_sample_directly_from_probs = True + 
timer.section_end("t_lvm_apply_outer_ms", t_lvm) - if can_sample_directly_from_probs: + if ( + guided_sample_result is not None + and guided_sample_result.row_indices.numel() == probs.shape[0] + ): + batch_next_token_ids = torch.empty( + probs.shape[0], dtype=torch.int32, device=probs.device + ) + batch_next_token_ids[guided_sample_result.row_indices] = ( + guided_sample_result.token_ids + ) + guided_sample_result = None + guided_applied = True + elif can_sample_directly_from_probs: # when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs + t_sample = timer.section_start("t_sample_ms") batch_next_token_ids = sampling_from_probs_torch( probs, sampling_seed=sampling_info.sampling_seed, positions=positions, ) + timer.section_end("t_sample_ms", t_sample) else: + t_sample = timer.section_start("t_sample_ms") if get_global_server_args().sampling_backend == "flashinfer": if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) @@ -244,8 +272,18 @@ def forward( raise ValueError( f"Invalid sampling backend: {get_global_server_args().sampling_backend}" ) + timer.section_end("t_sample_ms", t_sample) + + if guided_sample_result is not None: + t_override = timer.section_start("t_guided_override_ms") + batch_next_token_ids[guided_sample_result.row_indices] = ( + guided_sample_result.token_ids + ) + timer.section_end("t_guided_override_ms", t_override) + guided_applied = True if return_logprob: + t_logprob = timer.section_start("t_logprob_ms") if get_global_server_args().rl_on_policy_target is not None: logprobs = logprobs_via_logsoftmax_kernel del logprobs_via_logsoftmax_kernel @@ -257,6 +295,7 @@ def forward( del probs_without_temp_scaling else: logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min) + timer.section_end("t_logprob_ms", t_logprob) # Attach logprobs to logits_output (in-place modification) if return_logprob: @@ -285,12 +324,21 @@ def forward( # In such cases, enable this env 
variable to prevent hanging due to TP ranks becoming desynchronized. # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + t_tp_sync = timer.section_start("t_tp_sync_ms") torch.distributed.all_reduce( batch_next_token_ids, op=dist.ReduceOp.MIN, group=self.tp_sync_group, ) + timer.section_end("t_tp_sync_ms", t_tp_sync) + timer.section_end("t_sampler_total_ms", t_forward) + timer.set_meta( + lvm_active=bool(guided_applied), + batch_size=batch_size_meta, + is_greedy=bool(sampling_info.is_all_greedy), + ) + timer.flush_step() return batch_next_token_ids def compute_logprobs_only( diff --git a/sglang-LenVM/python/sglang/srt/lvm/entropy_kernel.py b/sglang-LenVM/python/sglang/srt/lvm/entropy_kernel.py new file mode 100644 index 0000000..74b3ac1 --- /dev/null +++ b/sglang-LenVM/python/sglang/srt/lvm/entropy_kernel.py @@ -0,0 +1,174 @@ +"""Triton helpers for LenVM entropy-based guidance skips.""" + +from __future__ import annotations + +import os +from typing import Optional + +import torch + + +_DISABLE_TRITON_ENTROPY = os.getenv("LVM_DISABLE_TRITON_ENTROPY", "").lower() in { + "1", + "true", + "yes", + "on", +} +_TRITON_AVAILABLE: Optional[bool] = None + + +def _fallback_entropy_from_probs( + probs: torch.Tensor, + rows: Optional[torch.Tensor] = None, +) -> torch.Tensor: + p = probs.index_select(0, rows) if rows is not None else probs + p = p.float() + s = p.sum(dim=-1) + p = p / s.clamp(min=1e-20).view(-1, 1) + return torch.special.entr(p).sum(dim=-1) + + +def _next_power_of_2(x: int) -> int: + return 1 << (max(int(x), 1) - 1).bit_length() + + +def full_vocab_entropy_from_probs( + probs: torch.Tensor, + rows: Optional[torch.Tensor] = None, + *, + block_size: int = 4096, +) -> torch.Tensor: + """Compute entropy over the full vocabulary distribution. + + The input is a probability tensor, not logits. 
To match the existing code + exactly enough for thresholding, rows are renormalized by their full-vocab + sum: + + H(p / sum(p)) = log(sum(p)) - sum(p * log(p)) / sum(p) + + Accumulation is FP32. This avoids materializing an extra selected + [n_rows, vocab] tensor and the intermediate `entr` tensor. + """ + + if ( + _DISABLE_TRITON_ENTROPY + or not probs.is_cuda + or probs.dim() != 2 + or probs.shape[-1] == 0 + ): + return _fallback_entropy_from_probs(probs, rows) + + global _TRITON_AVAILABLE + if _TRITON_AVAILABLE is False: + return _fallback_entropy_from_probs(probs, rows) + + try: + import triton + import triton.language as tl + except Exception: + _TRITON_AVAILABLE = False + return _fallback_entropy_from_probs(probs, rows) + + _TRITON_AVAILABLE = True + + n_total_rows = int(probs.shape[0]) + vocab_size = int(probs.shape[1]) + if rows is None: + n_rows = n_total_rows + rows_t = torch.empty((0,), dtype=torch.int64, device=probs.device) + has_rows = False + else: + rows_t = rows.to(device=probs.device, dtype=torch.int64) + n_rows = int(rows_t.numel()) + has_rows = True + if n_rows == 0: + return torch.empty((0,), dtype=torch.float32, device=probs.device) + + block_size = int(block_size) + if block_size <= 0: + block_size = 4096 + n_blocks = triton.cdiv(vocab_size, block_size) + reduce_block = _next_power_of_2(n_blocks) + + partial_sum = torch.empty((n_rows, n_blocks), dtype=torch.float32, device=probs.device) + partial_p_log_p = torch.empty_like(partial_sum) + out = torch.empty((n_rows,), dtype=torch.float32, device=probs.device) + + try: + _entropy_partial_kernel[(n_rows, n_blocks)]( + probs, + rows_t, + partial_sum, + partial_p_log_p, + vocab_size, + int(probs.stride(0)), + n_blocks, + HAS_ROWS=has_rows, + BLOCK_SIZE=block_size, + num_warps=8, + ) + _entropy_reduce_kernel[(n_rows,)]( + partial_sum, + partial_p_log_p, + out, + n_blocks, + BLOCK_N=reduce_block, + num_warps=8, + ) + return out + except Exception: + _TRITON_AVAILABLE = False + return 
_fallback_entropy_from_probs(probs, rows) + + +try: + import triton + import triton.language as tl + + @triton.jit + def _entropy_partial_kernel( + probs, + rows, + partial_sum, + partial_p_log_p, + vocab_size: tl.constexpr, + row_stride: tl.constexpr, + n_blocks: tl.constexpr, + HAS_ROWS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ): + row_pos = tl.program_id(0) + block_id = tl.program_id(1) + if HAS_ROWS: + row = tl.load(rows + row_pos) + else: + row = row_pos + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < vocab_size + p = tl.load(probs + row * row_stride + offsets, mask=mask, other=0.0).to(tl.float32) + p_log_p = tl.where(p > 0.0, p * tl.log(p), 0.0) + out_offset = row_pos * n_blocks + block_id + tl.store(partial_sum + out_offset, tl.sum(p, axis=0)) + tl.store(partial_p_log_p + out_offset, tl.sum(p_log_p, axis=0)) + + @triton.jit + def _entropy_reduce_kernel( + partial_sum, + partial_p_log_p, + out, + n_blocks: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + row = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + mask = offsets < n_blocks + s = tl.load(partial_sum + row * n_blocks + offsets, mask=mask, other=0.0) + plogp = tl.load(partial_p_log_p + row * n_blocks + offsets, mask=mask, other=0.0) + sum_p = tl.sum(s, axis=0) + sum_p_log_p = tl.sum(plogp, axis=0) + ent = tl.where(sum_p > 0.0, tl.log(sum_p) - sum_p_log_p / sum_p, 0.0) + tl.store(out + row, ent) + +except Exception: + # Import-time fallback for CPU-only environments and Triton import issues. 
+ _TRITON_AVAILABLE = False diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py index ea51495..50b616a 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_guided_sampling.py @@ -11,10 +11,12 @@ import requests import torch +from sglang.srt.lvm.entropy_kernel import full_vocab_entropy_from_probs from sglang.srt.sampling.sampling_params import TOP_K_ALL from sglang.srt.utils.common import dynamic_import from sglang.srt.server_args import get_global_server_args from sglang.srt.lvm.lvm_value_utils import force_eos_value_zero, get_eos_token_ids +from sglang.srt.lvm.timing import get_timer from sglang.srt.configs.model_config import ModelConfig from sglang.srt.managers.schedule_batch import Req @@ -758,9 +760,16 @@ class PendingLvmResult: # None means there is nothing to do (all rows were deterministic or skipped). guided: Optional[torch.Tensor] send_batch_indices: List[int] + deterministic_rows: List[tuple[int, int]] prefix_ids_send: List[List[int]] candidate_ids_send: List[List[int]] candidate_probs_send: List[List[float]] + deterministic_row_indices: Optional[torch.Tensor] = None + deterministic_token_ids: Optional[torch.Tensor] = None + fallback_sample_rows: Optional[torch.Tensor] = None + fallback_sample_probs: Optional[torch.Tensor] = None + fallback_sample_ids: Optional[torch.Tensor] = None + fallback_sample_mask: Optional[torch.Tensor] = None # GPU tensors for the fast path (only set when the guidance function can use the # expectation-guidance GPU path and all send indices come from the top-k path, # not top-k-all). 
@@ -769,6 +778,12 @@ class PendingLvmResult: gpu_candidates: Optional[tuple] = None +@dataclass +class LvmGuidedSampleResult: + row_indices: torch.Tensor + token_ids: torch.Tensor + + class LvmGuidedSampler: def __init__(self, config: LvmGuidedConfig, *, model_runner=None): self.config = config @@ -776,6 +791,17 @@ def __init__(self, config: LvmGuidedConfig, *, model_runner=None): self._fn = _load_guidance_fn(config.fn_spec) self._decode_model_runner = model_runner self._inproc = None + self._index_tensor_cache: dict[tuple[str, str, int], torch.Tensor] = {} + + def _cached_arange( + self, n: int, *, device: torch.device, dtype: torch.dtype = torch.long + ) -> torch.Tensor: + key = (str(device), str(dtype), int(n)) + cached = self._index_tensor_cache.get(key) + if cached is None or cached.device != device or cached.dtype != dtype or cached.numel() != n: + cached = torch.arange(n, device=device, dtype=dtype) + self._index_tensor_cache[key] = cached + return cached @staticmethod def from_server_args(server_args, model_runner=None) -> Optional["LvmGuidedSampler"]: @@ -1044,6 +1070,37 @@ def tree_value_launch_gpu( return embeddings # GPU tensor(s), not yet safe from default stream + def tree_value_launch_gpu_fused( + self, + rids: List[str], + prefix_ids: List[List[int]], + candidate_ids: List[List[int]], + reqs: List[Req], + gpu_candidates: Optional[tuple] = None, + ): + """Launch exact suffix+candidate tree-value evaluation in one LenVM forward.""" + if self.is_vlm: + # VLM prefix extension needs multimodal input handling. Keep the + # existing two-phase path there; the Qwen3 timing path is text-only. 
+ self.tree_value_extend(rids, prefix_ids, reqs) + return self.tree_value_launch_gpu( + rids, + candidate_ids, + gpu_candidates=gpu_candidates, + ) + + self.lvm_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.lvm_stream): + embeddings = self.incremental_runner.eval_prefix_and_candidates_batch_gpu( + rids, + prefix_ids, + candidate_ids, + gpu_candidates=gpu_candidates, + ) + self.embed_ready.record(self.lvm_stream) + + return embeddings + def tree_value_collect_gpu(self, gpu_embeddings): """Insert a stream dependency so the default stream waits for lvm_stream.""" torch.cuda.current_stream().wait_event(self.embed_ready) @@ -1152,9 +1209,17 @@ def _req_wants_value_guidance(req: Any) -> bool: We gate tree_value calls on this to avoid unnecessary network overhead when the user did not request value guidance (default scale/mode). """ + cached = getattr(req, "_lvm_wants_value_guidance", None) + if cached is not None: + return bool(cached) + sampling_params = getattr(req, "sampling_params", None) custom_params = getattr(sampling_params, "custom_params", None) if not isinstance(custom_params, dict): + try: + setattr(req, "_lvm_wants_value_guidance", False) + except Exception: + pass return False keys = ( "value_scale", @@ -1169,7 +1234,12 @@ def _req_wants_value_guidance(req: Any) -> bool: "cmp", "op", ) - return any(k in custom_params for k in keys) + result = any(k in custom_params for k in keys) + try: + setattr(req, "_lvm_wants_value_guidance", result) + except Exception: + pass + return result @staticmethod def _extract_entropy_threshold(req: Any) -> Optional[float]: @@ -1284,14 +1354,18 @@ def _build_pending( top_ps: torch.Tensor, top_ks: torch.Tensor, min_ps: torch.Tensor, + *, + build_guided: bool = True, ) -> Optional["PendingLvmResult"]: """Filter candidates and build a PendingLvmResult without contacting the LVM. Returns None when there is nothing to do (no rows want guidance). 
""" + timer = get_timer() device = probs.device vocab_size = int(probs.shape[-1]) + t_guided_rows = timer.section_start("t_lvm_build_guided_rows_ms") # `reqs` is typically already a list, but may be any iterable. req_list = reqs if isinstance(reqs, list) else list(reqs) @@ -1300,78 +1374,332 @@ def _build_pending( for i, req in enumerate(req_list): if self._req_wants_value_guidance(req): guided_rows.append(i) + timer.section_end("t_lvm_build_guided_rows_ms", t_guided_rows) + timer.set_meta(lvm_guided_rows=len(guided_rows)) # If nobody requested value guidance, do nothing and let normal sampling proceed. if not guided_rows: return None - guided_rows_t = torch.tensor(guided_rows, device=device, dtype=torch.long) - - # Gather per-row sampling params for the guided subset (single gather, no per-row .item()). - top_ks_sel = top_ks.index_select(0, guided_rows_t).to(torch.int64) - top_ps_sel = top_ps.index_select(0, guided_rows_t).to(torch.float32) - min_ps_sel = min_ps.index_select(0, guided_rows_t).to(torch.float32) + t_guided_tensor = timer.section_start("t_lvm_build_guided_tensor_ms") + all_rows_guided = len(guided_rows) == len(req_list) + if all_rows_guided: + guided_rows_t = self._cached_arange( + len(req_list), device=device, dtype=torch.long + ) + else: + guided_rows_t = torch.tensor(guided_rows, device=device, dtype=torch.long) + timer.section_end("t_lvm_build_guided_tensor_ms", t_guided_tensor) + + t_params = timer.section_start("t_lvm_build_param_gather_ms") + # Prefer CPU request metadata for sampling params. Pulling tiny values from + # GPU tensors with .item() would synchronize the preceding full-vocab softmax. 
+ top_ks_cpu: List[int] = [] + top_ps_cpu: List[float] = [] + min_ps_cpu: List[float] = [] + use_req_sampling_params = True + for ridx in guided_rows: + sp = getattr(req_list[ridx], "sampling_params", None) + if ( + sp is None + or not hasattr(sp, "top_k") + or not hasattr(sp, "top_p") + or not hasattr(sp, "min_p") + ): + use_req_sampling_params = False + break + top_k_i = int(getattr(sp, "top_k")) + if top_k_i == -1: + top_k_i = TOP_K_ALL + top_ks_cpu.append(top_k_i) + top_ps_cpu.append(float(getattr(sp, "top_p"))) + min_ps_cpu.append(float(getattr(sp, "min_p"))) + + top_ks_sel: Optional[torch.Tensor] = None + top_ps_sel: Optional[torch.Tensor] = None + min_ps_sel: Optional[torch.Tensor] = None + mask_all: Optional[torch.Tensor] = None + mask_topk: Optional[torch.Tensor] = None + mask_all_cpu: List[bool] = [] + mask_topk_cpu: List[bool] = [] + if use_req_sampling_params: + mask_all_cpu = [k == TOP_K_ALL for k in top_ks_cpu] + mask_topk_cpu = [not x for x in mask_all_cpu] + has_topk = any(mask_topk_cpu) + has_all = any(mask_all_cpu) + k_max_cpu = 0 + if has_topk: + k_max_cpu = min( + max(k for k, is_topk in zip(top_ks_cpu, mask_topk_cpu) if is_topk), + vocab_size, + ) + else: + # Fallback for non-Req callers: gather per-row sampling params from tensors. 
+ if all_rows_guided: + top_ks_sel = top_ks.to(torch.int64) + top_ps_sel = top_ps.to(torch.float32) + min_ps_sel = min_ps.to(torch.float32) + else: + top_ks_sel = top_ks.index_select(0, guided_rows_t).to(torch.int64) + top_ps_sel = top_ps.index_select(0, guided_rows_t).to(torch.float32) + min_ps_sel = min_ps.index_select(0, guided_rows_t).to(torch.float32) + mask_all = top_ks_sel == TOP_K_ALL + mask_topk = ~mask_all + has_topk = bool(mask_topk.any().item()) + has_all = bool(mask_all.any().item()) + k_max_cpu = 0 prefix_ids_send: List[List[int]] = [] candidate_ids_send: List[List[int]] = [] send_batch_indices: List[int] = [] candidate_probs_send: List[List[float]] = [] deterministic_rows: List[tuple[int, int]] = [] - - # Split rare slow-path (top_k == ALL) from the common (top_k is small). - mask_all = top_ks_sel == TOP_K_ALL - mask_topk = ~mask_all - + deterministic_row_indices_t: Optional[torch.Tensor] = None + deterministic_token_ids_t: Optional[torch.Tensor] = None # GPU fast path is active when we are doing expectation-style guidance # without any hard target_value/target_length constraints, and all sequences # use top-k (not top-k-all), so we can keep tensors on GPU. has_hard_target = any( _extract_target_value({"req": req_list[ridx]}) is not None - for ridx in range(len(req_list)) + for ridx in guided_rows ) _use_gpu_path = self._fn in ( lvm_expectation_guidance, lvm_combined_guidance, - ) and not bool(mask_all.any().item()) and not has_hard_target + ) and not has_all and not has_hard_target + timer.set_meta( + lvm_has_topk_path=has_topk, + lvm_has_topk_all_path=has_all, + lvm_gpu_guidance_path=_use_gpu_path, + ) + timer.section_end("t_lvm_build_param_gather_ms", t_params) # Will hold (vals_send_gpu, idx_send_gpu) for the top-k send rows, on GPU. 
_gpu_vals_chunks: List[torch.Tensor] = [] _gpu_idx_chunks: List[torch.Tensor] = [] + _fallback_rows_chunks: List[torch.Tensor] = [] + _fallback_vals_chunks: List[torch.Tensor] = [] + _fallback_idx_chunks: List[torch.Tensor] = [] + # Fast path: batched full-vocab entropy -> top-k -> top-p/min-p. # --------------------------- - # Fast path: batched top-k -> top-p/min-p/entropy on the top-k subset. - # --------------------------- - if bool(mask_topk.any().item()): - rows_topk_t = guided_rows_t[mask_topk] - top_ks_k = top_ks_sel[mask_topk].clamp(min=1, max=vocab_size) - top_ps_k = top_ps_sel[mask_topk] - min_ps_k = min_ps_sel[mask_topk] - - # Single synchronization point to get k_max. - k_max = int(top_ks_k.max().item()) + if has_topk: + rows_topk_list_cpu: Optional[List[int]] = None + pre_entropy_handled_cpu: Optional[List[bool]] = None + if use_req_sampling_params: + topk_positions_cpu = [ + j for j, is_topk in enumerate(mask_topk_cpu) if is_topk + ] + rows_topk_list_cpu = [guided_rows[j] for j in topk_positions_cpu] + if has_all: + topk_positions_t = torch.tensor( + topk_positions_cpu, device=device, dtype=torch.long + ) + rows_topk_t = guided_rows_t.index_select(0, topk_positions_t) + else: + rows_topk_t = guided_rows_t + top_ks_topk_cpu = [ + min(max(int(top_ks_cpu[j]), 1), vocab_size) + for j in topk_positions_cpu + ] + top_ps_topk_cpu = [float(top_ps_cpu[j]) for j in topk_positions_cpu] + min_ps_topk_cpu = [float(min_ps_cpu[j]) for j in topk_positions_cpu] + else: + rows_topk_t = guided_rows_t[mask_topk] + top_ks_k = top_ks_sel[mask_topk].clamp(min=1, max=vocab_size) + top_ps_k = top_ps_sel[mask_topk] + min_ps_k = min_ps_sel[mask_topk] + + # Sparse sampling path: entropy skip does not need candidate filtering. + # Evaluate it before top-k so rows that will fall back to the native + # sampler avoid a full-vocab top-k scan. 
Keep this path conservative: + # only pre-skip rows where top-k/top-p/min-p cannot make the row + # deterministic; other rows use the existing post-filter check below. + if ( + (not build_guided) + and use_req_sampling_params + and rows_topk_list_cpu + ): + t_entropy_params = timer.section_start("t_lvm_build_entropy_params_ms") + pre_thr_list: List[float] = [] + pre_has_thr_list: List[bool] = [] + for ridx in rows_topk_list_cpu: + thr = self._extract_entropy_threshold(req_list[ridx]) + if thr is None: + pre_thr_list.append(float("nan")) + pre_has_thr_list.append(False) + else: + pre_thr_list.append(float(thr)) + pre_has_thr_list.append(True) + pre_has_any_thr = any(pre_has_thr_list) + pre_valid_thrs = [ + thr for thr, has_thr in zip(pre_thr_list, pre_has_thr_list) if has_thr + ] + use_scalar_pre_entropy = ( + pre_has_any_thr + and all(pre_has_thr_list) + and len(set(pre_valid_thrs)) == 1 + and len(set(top_ks_topk_cpu)) == 1 + and len(set(top_ps_topk_cpu)) == 1 + and len(set(min_ps_topk_cpu)) == 1 + ) + pre_thr_t = None + pre_has_thr = None + pre_top_ks_t = None + pre_top_ps_t = None + pre_min_ps_t = None + pre_thr_scalar = float(pre_valid_thrs[0]) if use_scalar_pre_entropy else None + if not use_scalar_pre_entropy and pre_has_any_thr: + pre_thr_t = torch.tensor(pre_thr_list, device=device, dtype=torch.float32) + pre_has_thr = torch.tensor(pre_has_thr_list, device=device, dtype=torch.bool) + pre_top_ks_t = torch.tensor(top_ks_topk_cpu, device=device, dtype=torch.int64) + pre_top_ps_t = torch.tensor(top_ps_topk_cpu, device=device, dtype=torch.float32) + pre_min_ps_t = torch.tensor(min_ps_topk_cpu, device=device, dtype=torch.float32) + timer.section_end("t_lvm_build_entropy_params_ms", t_entropy_params) + + t_entropy_eval = timer.section_start("t_lvm_build_entropy_eval_ms") + pre_skip = torch.zeros( + (len(rows_topk_list_cpu),), device=device, dtype=torch.bool + ) + pre_handled = torch.zeros_like(pre_skip) + if pre_has_any_thr: + if all_rows_guided and not has_all: + 
max_prob = probs.max(dim=-1).values.float().clamp(min=1e-20) + else: + max_prob = probs.index_select(0, rows_topk_t).max(dim=-1).values.float().clamp(min=1e-20) + # With min_p disabled and max_prob <= top_p, top-p keeps at + # least two candidates when top_k > 1, so the row cannot be + # deterministic under the later candidate filter. + entropy_lb = -torch.log(max_prob) + if use_scalar_pre_entropy: + if int(top_ks_topk_cpu[0]) > 1 and float(min_ps_topk_cpu[0]) <= 0.0: + can_precheck = max_prob <= float(top_ps_topk_cpu[0]) + else: + can_precheck = torch.zeros_like(max_prob, dtype=torch.bool) + exact_mask = can_precheck & (entropy_lb <= pre_thr_scalar) + else: + can_precheck = ( + pre_has_thr + & (pre_top_ks_t > 1) + & (pre_min_ps_t <= 0.0) + & (max_prob <= pre_top_ps_t) + ) + exact_mask = can_precheck & (entropy_lb <= pre_thr_t) + entropy_rows = torch.nonzero(exact_mask, as_tuple=False).view(-1) + if entropy_rows.numel() > 0: + ent = full_vocab_entropy_from_probs( + probs, rows_topk_t.index_select(0, entropy_rows) + ) + if use_scalar_pre_entropy: + pre_skip[entropy_rows] = ent <= pre_thr_scalar + else: + pre_skip[entropy_rows] = ent <= pre_thr_t.index_select(0, entropy_rows) + pre_handled = can_precheck + timer.section_end("t_lvm_build_entropy_eval_ms", t_entropy_eval) + + pre_handled_cpu = pre_handled.detach().cpu().tolist() + pre_skip_cpu = pre_skip.detach().cpu().tolist() + + if any(pre_skip_cpu): + keep_pre = ~pre_skip + keep_pre_cpu = [not x for x in pre_skip_cpu] + if not any(keep_pre_cpu): + # Keep one row on the normal post-filter path so the + # surrounding top-k code remains simple. That row will + # still get the exact entropy decision below. 
+ keep_pre_cpu[-1] = True + pre_skip[-1] = False + pre_handled[-1] = False + pre_skip_cpu[-1] = False + pre_handled_cpu[-1] = False + keep_pre = ~pre_skip + rows_topk_t = rows_topk_t[keep_pre] + rows_topk_list_cpu = [ + ridx for ridx, keep in zip(rows_topk_list_cpu, keep_pre_cpu) if keep + ] + top_ks_topk_cpu = [ + val for val, keep in zip(top_ks_topk_cpu, keep_pre_cpu) if keep + ] + top_ps_topk_cpu = [ + val for val, keep in zip(top_ps_topk_cpu, keep_pre_cpu) if keep + ] + min_ps_topk_cpu = [ + val for val, keep in zip(min_ps_topk_cpu, keep_pre_cpu) if keep + ] + pre_entropy_handled_cpu = [ + handled + for handled, keep in zip(pre_handled_cpu, keep_pre_cpu) + if keep + ] + k_max_cpu = min(max(top_ks_topk_cpu), vocab_size) + elif any(pre_handled_cpu): + pre_entropy_handled_cpu = pre_handled_cpu + if getattr(timer, "enabled", False) and pre_skip_cpu: + timer.set_meta(lvm_pre_entropy_skipped_rows=sum(int(x) for x in pre_skip_cpu)) + + t_topk = timer.section_start("t_lvm_build_topk_filter_ms") + if k_max_cpu > 0: + k_max = k_max_cpu + else: + # Fallback synchronization only when CPU sampling params are unavailable. + k_max = int(top_ks_k.max().item()) if k_max <= 0: + timer.section_end("t_lvm_build_topk_filter_ms", t_topk) return None - probs_k = probs.index_select(0, rows_topk_t).float() - topk_vals, topk_idx = torch.topk(probs_k, k_max, dim=-1) # sorted desc + if all_rows_guided and not has_all and rows_topk_t.numel() == probs.shape[0]: + topk_source = probs + topk_vals, topk_idx = torch.topk(topk_source, k=k_max, dim=-1) + else: + topk_source = probs.index_select(0, rows_topk_t) + topk_vals, topk_idx = torch.topk(topk_source, k=k_max, dim=-1) # Apply per-row top-k (mask out positions >= top_k_i). 
- ar = torch.arange(k_max, device=device, dtype=torch.int64).view(1, -1) - keep_k = ar < top_ks_k.view(-1, 1) - vals = torch.where(keep_k, topk_vals, torch.zeros_like(topk_vals)) + if use_req_sampling_params and all(k == k_max for k in top_ks_topk_cpu): + vals = topk_vals + else: + if use_req_sampling_params: + top_ks_k = torch.tensor( + top_ks_topk_cpu, device=device, dtype=torch.int64 + ) + ar = self._cached_arange(k_max, device=device, dtype=torch.int64).view(1, -1) + keep_k = ar < top_ks_k.view(-1, 1) + vals = torch.where(keep_k, topk_vals, torch.zeros_like(topk_vals)) # Apply per-row top-p within the (masked) top-k list. # Keep token j if sum(vals[:j]) <= top_p (equivalently (cum - val) <= top_p). # Note: vals are already sorted descending before masking. keep_p = torch.ones_like(vals, dtype=torch.bool) - if torch.any(top_ps_k < 1.0): + if use_req_sampling_params: + apply_top_p = any(p < 1.0 for p in top_ps_topk_cpu) + else: + apply_top_p = bool(torch.any(top_ps_k < 1.0).item()) + if apply_top_p: cum = torch.cumsum(vals, dim=-1) - keep_p = (cum - vals) <= top_ps_k.view(-1, 1) + if use_req_sampling_params and len(set(top_ps_topk_cpu)) == 1: + keep_p = (cum - vals) <= float(top_ps_topk_cpu[0]) + else: + if use_req_sampling_params: + top_ps_k = torch.tensor( + top_ps_topk_cpu, device=device, dtype=torch.float32 + ) + keep_p = (cum - vals) <= top_ps_k.view(-1, 1) vals = torch.where(keep_p, vals, torch.zeros_like(vals)) # Apply per-row min-p: keep tokens with prob >= max_prob * min_p. 
- if torch.any(min_ps_k > 0.0): + if use_req_sampling_params: + apply_min_p = any(p > 0.0 for p in min_ps_topk_cpu) + else: + apply_min_p = bool(torch.any(min_ps_k > 0.0).item()) + if apply_min_p: max_prob = vals.max(dim=-1).values - thresh = max_prob * min_ps_k + if use_req_sampling_params and len(set(min_ps_topk_cpu)) == 1: + thresh = max_prob * float(min_ps_topk_cpu[0]) + else: + if use_req_sampling_params: + min_ps_k = torch.tensor( + min_ps_topk_cpu, device=device, dtype=torch.float32 + ) + thresh = max_prob * min_ps_k keep_min = vals >= thresh.view(-1, 1) vals = torch.where(keep_min, vals, torch.zeros_like(vals)) @@ -1387,57 +1715,120 @@ def _build_pending( # Deterministic rows: exactly 1 candidate. det_mask = counts == 1 - det_pos = torch.argmax(vals, dim=-1) - det_token_ids = torch.gather( - topk_idx, dim=1, index=det_pos.view(-1, 1) - ).view(-1) + det_token_ids = None + if build_guided: + det_pos = torch.argmax(vals, dim=-1) + det_token_ids = torch.gather( + topk_idx, dim=1, index=det_pos.view(-1, 1) + ).view(-1) + timer.section_end("t_lvm_build_topk_filter_ms", t_topk) # Optional entropy-based skip (per request, Python-sourced thresholds). 
- rows_topk_list: List[int] = rows_topk_t.detach().cpu().tolist() + t_entropy_params = timer.section_start("t_lvm_build_entropy_params_ms") + if use_req_sampling_params: + rows_topk_list = rows_topk_list_cpu or [] + elif not has_all: + rows_topk_list = guided_rows + else: + rows_topk_list = rows_topk_t.detach().cpu().tolist() thr_list: List[float] = [] - has_thr = torch.zeros(len(rows_topk_list), device=device, dtype=torch.bool) + has_thr_list: List[bool] = [] for j, ridx in enumerate(rows_topk_list): - thr = self._extract_entropy_threshold(req_list[ridx]) + if ( + pre_entropy_handled_cpu is not None + and j < len(pre_entropy_handled_cpu) + and pre_entropy_handled_cpu[j] + ): + thr = None + else: + thr = self._extract_entropy_threshold(req_list[ridx]) if thr is None: thr_list.append(float("nan")) + has_thr_list.append(False) else: thr_list.append(float(thr)) - has_thr[j] = True - - # Use float64 for value-guidance gating to avoid precision loss in entropy comparisons. - thr_t = torch.tensor(thr_list, device=device, dtype=torch.float64) - skip_entropy = torch.zeros_like(has_thr, dtype=torch.bool) - if bool(has_thr.any().item()): - p = vals.to(torch.float64) - s = p.sum(dim=-1) - # Avoid division by 0; counts==0 already fixed. - p = p / s.clamp(min=1e-20).view(-1, 1) - ent = -(p * torch.log(p + 1e-20)).sum(dim=-1) - skip_entropy = has_thr & (ent <= thr_t) + has_thr_list.append(True) + + # Compute entropy on the full vocabulary distribution, before top-k/top-p/min-p + # candidate filtering. The skip decision should reflect whether the base + # model is globally confident, while LenVM still only scores the filtered + # candidate set below. 
+ has_thr_any = any(has_thr_list) + valid_thrs = [ + thr for thr, has_thr_i in zip(thr_list, has_thr_list) if has_thr_i + ] + post_thr_scalar = ( + float(valid_thrs[0]) + if has_thr_any and all(has_thr_list) and len(set(valid_thrs)) == 1 + else None + ) + thr_t = None + has_thr = None + if has_thr_any and post_thr_scalar is None: + thr_t = torch.tensor(thr_list, device=device, dtype=torch.float32) + has_thr = torch.tensor(has_thr_list, device=device, dtype=torch.bool) + timer.section_end("t_lvm_build_entropy_params_ms", t_entropy_params) + + t_entropy_eval = timer.section_start("t_lvm_build_entropy_eval_ms") + skip_entropy = torch.zeros_like(det_mask, dtype=torch.bool) + entropy_rows = torch.empty((0,), device=device, dtype=torch.long) + if has_thr_any: + # Exact safe shortcut: H(p) >= -log(max_i p_i). If that lower + # bound already exceeds the threshold, the row cannot be skipped + # and we avoid a full-vocab entropy pass for that row. + max_prob = topk_vals[:, 0].float().clamp(min=1e-20) + entropy_lb = -torch.log(max_prob) + if post_thr_scalar is not None: + entropy_mask = ~det_mask + exact_entropy_mask = entropy_mask & (entropy_lb <= post_thr_scalar) + entropy_rows = torch.nonzero(exact_entropy_mask, as_tuple=False).view(-1) + else: + entropy_mask = has_thr & (~det_mask) + exact_entropy_mask = entropy_mask & (entropy_lb <= thr_t) + entropy_rows = torch.nonzero(exact_entropy_mask, as_tuple=False).view(-1) + if entropy_rows.numel() > 0: + ent = full_vocab_entropy_from_probs(probs, rows_topk_t.index_select(0, entropy_rows)) + if post_thr_scalar is not None: + skip_entropy[entropy_rows] = ent <= post_thr_scalar + else: + skip_entropy[entropy_rows] = ent <= thr_t.index_select(0, entropy_rows) + timer.section_end("t_lvm_build_entropy_eval_ms", t_entropy_eval) # Rows to send to LVM: non-deterministic and not skipped by entropy. send_mask = (~det_mask) & (~skip_entropy) - - # Materialize deterministic rows in Python list. 
- if bool(det_mask.any().item()): - det_token_ids_cpu = det_token_ids.detach().cpu().tolist() - det_mask_cpu = det_mask.detach().cpu().tolist() - for j, is_det in enumerate(det_mask_cpu): - if is_det: - deterministic_rows.append((rows_topk_list[j], int(det_token_ids_cpu[j]))) + fallback_sample_mask = skip_entropy + + t_det = timer.section_start("t_lvm_build_deterministic_ms") + if build_guided and bool(det_mask.any().item()): + if det_token_ids is None: + raise RuntimeError("Missing deterministic token ids for LenVM deterministic rows") + deterministic_row_indices_t = rows_topk_t[det_mask].to(torch.long) + deterministic_token_ids_t = det_token_ids[det_mask].to(torch.int32) + timer.section_end("t_lvm_build_deterministic_ms", t_det) + + # Rows skipped by entropy still need to be sampled from the same + # top-k/top-p/min-p distribution. Keep those tensors so the caller can + # avoid running the normal sampler a second time for fully guided batches. + t_fallback = timer.section_start("t_lvm_build_fallback_cache_ms") + if build_guided and bool(fallback_sample_mask.any().item()): + _fallback_rows_chunks.append(rows_topk_t[fallback_sample_mask]) + _fallback_vals_chunks.append(vals[fallback_sample_mask]) + _fallback_idx_chunks.append(topk_idx[fallback_sample_mask]) + timer.section_end("t_lvm_build_fallback_cache_ms", t_fallback) # Prepare candidate lists for rows that we will actually send. + t_send = timer.section_start("t_lvm_build_send_cpu_ms") if bool(send_mask.any().item()): rows_send_t = rows_topk_t[send_mask] # Keep GPU slices before moving to CPU (used by GPU guidance fast path). vals_send_gpu = vals[send_mask] # [B_topk_send, K_max], GPU idx_send_gpu = topk_idx[send_mask] # [B_topk_send, K_max], GPU - idx_send = idx_send_gpu.detach().cpu() # GPU fast path: only need bool mask (4x smaller than float32 transfer). # CPU path: need full float values for candidate_probs_send. 
if _use_gpu_path: valid_mask_send = (vals_send_gpu > 0).detach().cpu() else: + idx_send = idx_send_gpu.detach().cpu() vals_send = vals_send_gpu.detach().cpu() rows_send_list = rows_send_t.detach().cpu().tolist() @@ -1445,13 +1836,20 @@ def _build_pending( # In practice (sorted desc + thresholding), non-zeros are a prefix. Still, use mask for safety. if _use_gpu_path: m = valid_mask_send[j] + n_cand = int(m.sum().item()) + cand_ids = [0] * n_cand else: m = vals_send[j] > 0 - cand_ids = idx_send[j][m].tolist() - if len(cand_ids) <= 1: + cand_ids = idx_send[j][m].tolist() + n_cand = len(cand_ids) + if n_cand <= 1: # Should have been caught by det_mask, but keep a safe fallback. - if len(cand_ids) == 1: - deterministic_rows.append((ridx, int(cand_ids[0]))) + if n_cand == 1: + if _use_gpu_path: + tok = int(idx_send_gpu[j][m.to(device=idx_send_gpu.device)].detach().cpu()[0].item()) + else: + tok = int(cand_ids[0]) + deterministic_rows.append((ridx, tok)) continue prefix = self._get_prefix_ids_incremental(req_list[ridx]) @@ -1464,14 +1862,27 @@ def _build_pending( # Capture GPU row j for later gpu_candidates assembly. _gpu_vals_chunks.append(vals_send_gpu[j].unsqueeze(0)) _gpu_idx_chunks.append(idx_send_gpu[j].unsqueeze(0)) + timer.section_end("t_lvm_build_send_cpu_ms", t_send) # --------------------------- # Slow path: top_k == ALL (full vocab filtering). Rare; keep correctness-oriented CPU behavior. 
# --------------------------- - if bool(mask_all.any().item()): - rows_all_list = guided_rows_t[mask_all].detach().cpu().tolist() - top_ps_all = top_ps_sel[mask_all].detach().cpu().tolist() - min_ps_all = min_ps_sel[mask_all].detach().cpu().tolist() + if has_all: + t_slow = timer.section_start("t_lvm_build_slow_all_ms") + if use_req_sampling_params: + rows_all_list = [ + ridx for ridx, is_all in zip(guided_rows, mask_all_cpu) if is_all + ] + top_ps_all = [ + float(p) for p, is_all in zip(top_ps_cpu, mask_all_cpu) if is_all + ] + min_ps_all = [ + float(p) for p, is_all in zip(min_ps_cpu, mask_all_cpu) if is_all + ] + else: + rows_all_list = guided_rows_t[mask_all].detach().cpu().tolist() + top_ps_all = top_ps_sel[mask_all].detach().cpu().tolist() + min_ps_all = min_ps_sel[mask_all].detach().cpu().tolist() for j, i in enumerate(rows_all_list): row = probs[i] top_p_i = float(top_ps_all[j]) @@ -1492,12 +1903,13 @@ def _build_pending( thr = self._extract_entropy_threshold(req_list[i]) if thr is not None: - # float64 for stable entropy computation/comparison - p = torch.tensor(cand_probs, dtype=torch.float64) + # Match the fast path: entropy is measured on the full vocab + # distribution, not on the post-filtered candidate list. 
+ p = row_cpu.to(torch.float64) s = float(p.sum().item()) if s > 0: p = p / s - ent = float(-(p * torch.log(p + 1e-20)).sum().item()) + ent = float(torch.special.entr(p).sum().item()) else: ent = 0.0 if ent <= float(thr): @@ -1508,17 +1920,29 @@ def _build_pending( candidate_ids_send.append(cand_idx) candidate_probs_send.append(cand_probs) send_batch_indices.append(i) + timer.section_end("t_lvm_build_slow_all_ms", t_slow) - if not send_batch_indices and not deterministic_rows: + t_assemble = timer.section_start("t_lvm_build_assemble_ms") + has_tensor_deterministic = ( + deterministic_row_indices_t is not None + and deterministic_row_indices_t.numel() > 0 + ) + if not send_batch_indices and not deterministic_rows and not has_tensor_deterministic: + timer.section_end("t_lvm_build_assemble_ms", t_assemble) return None - # Build guided tensor and fill deterministic rows immediately. - guided = probs.clone() + guided = None + if build_guided: + # Build guided tensor and fill deterministic rows immediately. + guided = probs.clone() - # Fill deterministic rows (single candidate) without contacting LVM. - for i, tok in deterministic_rows: - guided[i].zero_() - guided[i, tok] = 1.0 + # Fill deterministic rows (single candidate) without contacting LVM. + if has_tensor_deterministic: + guided.index_fill_(0, deterministic_row_indices_t, 0.0) + guided[deterministic_row_indices_t, deterministic_token_ids_t.to(torch.long)] = 1.0 + for i, tok in deterministic_rows: + guided[i].zero_() + guided[i, tok] = 1.0 # Assemble GPU candidate tensors for the fast guidance path. 
gpu_candidates = None @@ -1528,23 +1952,41 @@ def _build_pending( gm = gp > 0 # [B_send, K_max] bool gpu_candidates = (gp, gi, gm) + fallback_sample_rows = None + fallback_sample_probs = None + fallback_sample_ids = None + fallback_sample_mask = None + if _fallback_rows_chunks: + fallback_sample_rows = torch.cat(_fallback_rows_chunks, dim=0).to(torch.long) + fallback_sample_probs = torch.cat(_fallback_vals_chunks, dim=0).float() + fallback_sample_ids = torch.cat(_fallback_idx_chunks, dim=0) + fallback_sample_mask = fallback_sample_probs > 0 + + timer.section_end("t_lvm_build_assemble_ms", t_assemble) return PendingLvmResult( req_list=req_list, device=device, guided=guided, send_batch_indices=send_batch_indices, + deterministic_rows=deterministic_rows, prefix_ids_send=prefix_ids_send, candidate_ids_send=candidate_ids_send, candidate_probs_send=candidate_probs_send, + deterministic_row_indices=deterministic_row_indices_t, + deterministic_token_ids=deterministic_token_ids_t, + fallback_sample_rows=fallback_sample_rows, + fallback_sample_probs=fallback_sample_probs, + fallback_sample_ids=fallback_sample_ids, + fallback_sample_mask=fallback_sample_mask, gpu_candidates=gpu_candidates, ) - def _apply_guidance_gpu(self, pending: "PendingLvmResult", gpu_embeddings) -> None: - """GPU-native guidance path for lvm_expectation_guidance. + def _guided_candidate_probs_gpu(self, pending: "PendingLvmResult", gpu_embeddings): + """Compute GPU-native LenVM-reweighted probabilities over top-k candidates. - Runs sigmoid + batched Newton's method entirely on GPU. No CPU↔GPU - data movement after this point — results are scattered directly into - pending.guided. + Runs sigmoid + batched guidance entirely on GPU and returns + (final_probs, padded_ids, valid_mask), each shaped [B_send, K_max]. + No full-vocab probability tensor is constructed. gpu_embeddings: GPU tensor(s) of raw LVM scalar outputs. - If a list of 1D tensors (variable-length per sequence): padded internally. 
@staticmethod
def _sample_rows_from_candidate_probs(
    probs: torch.Tensor,
    token_ids: torch.Tensor,
    valid_mask: torch.Tensor,
    row_indices: torch.Tensor,
    sampling_seed: Optional[torch.Tensor],
    positions: Optional[torch.Tensor],
) -> torch.Tensor:
    """Pick one token id per row from a padded candidate table.

    ``probs``, ``token_ids`` and ``valid_mask`` are indexed the same way:
    entry ``[r, c]`` describes candidate ``c`` of row ``r``. Entries whose
    mask is False get zero weight, and every row is renormalised before
    sampling.

    When either ``sampling_seed`` or ``positions`` is None the draw uses
    ``torch.multinomial``. Otherwise the draw is reproducible: the per-row
    seed and position (looked up through ``row_indices``) are hashed
    together with each candidate token id into a pseudo-uniform value,
    converted to Gumbel noise, and the arg-max of ``log p + noise`` wins.

    Returns:
        1-D int32 tensor with the chosen token id of each row.
    """
    masked_out = ~valid_mask
    weights = probs.float().masked_fill(masked_out, 0.0)
    # The clamp keeps an all-zero row from dividing by zero.
    row_mass = weights.sum(dim=-1, keepdim=True).clamp(min=1e-20)
    weights = weights / row_mass

    if sampling_seed is not None and positions is not None:
        row_seed = sampling_seed.index_select(0, row_indices).to(torch.int64)
        row_pos = positions.index_select(0, row_indices).to(torch.int64)
        # Integer hash of (seed, position), then mixed per candidate token id.
        mixed = (row_seed * 19349663) ^ (row_pos * 73856093)
        per_token = (mixed.unsqueeze(-1) * 8589934591) ^ (
            token_ids.to(torch.int64) * 479001599
        )
        # Keep the low 24 bits and scale them into [0, 1).
        unit = (per_token % (2**24)).float() / (2**24)
        unit = unit.clamp(1e-10, 1.0 - 1e-10)
        noise = -torch.log(-torch.log(unit))
        scores = torch.log(weights.clamp(min=1e-20)).masked_fill(masked_out, -1e20)
        chosen_col = torch.argmax(scores + noise, dim=1)
    else:
        chosen_col = torch.multinomial(weights, num_samples=1).view(-1)

    picked = torch.gather(token_ids, 1, chosen_col.view(-1, 1))
    return picked.view(-1).to(torch.int32)
+ emit_non_lvm_rows = pending.guided is not None + + if emit_non_lvm_rows and pending.deterministic_rows: + t_det = timer.section_start("t_lvm_deterministic_sample_ms") + det_rows, det_toks = zip(*pending.deterministic_rows) + row_parts.append(torch.tensor(det_rows, device=device, dtype=torch.long)) + token_parts.append(torch.tensor(det_toks, device=device, dtype=torch.int32)) + timer.section_end("t_lvm_deterministic_sample_ms", t_det) + if emit_non_lvm_rows and pending.deterministic_row_indices is not None: + t_det = timer.section_start("t_lvm_deterministic_sample_ms") + row_parts.append(pending.deterministic_row_indices.to(device=device, dtype=torch.long)) + if pending.deterministic_token_ids is None: + raise RuntimeError("Missing deterministic token ids for LenVM deterministic rows") + token_parts.append(pending.deterministic_token_ids.to(device=device, dtype=torch.int32)) + timer.section_end("t_lvm_deterministic_sample_ms", t_det) + + if emit_non_lvm_rows and pending.fallback_sample_rows is not None: + t_fallback = timer.section_start("t_lvm_fallback_sample_ms") + fallback_token_ids = self._sample_rows_from_candidate_probs( + pending.fallback_sample_probs, + pending.fallback_sample_ids, + pending.fallback_sample_mask, + pending.fallback_sample_rows, + sampling_seed, + positions, + ) + row_parts.append(pending.fallback_sample_rows) + token_parts.append(fallback_token_ids) + timer.section_end("t_lvm_fallback_sample_ms", t_fallback) + + if pending.send_batch_indices: + reqs_send = [pending.req_list[i] for i in pending.send_batch_indices] + if pending.gpu_candidates is not None and inproc not in (None, False): + try: + rids_send = [req.rid for req in reqs_send] + timer.set_meta(lvm_forward_rows=len(rids_send)) + t_forward = timer.section_start("t_lvm_forward_ms") + gpu_emb = inproc.tree_value_launch_gpu_fused( + rids_send, + pending.prefix_ids_send, + pending.candidate_ids_send, + reqs_send, + gpu_candidates=pending.gpu_candidates, + ) + gpu_embeddings = 
def sample_token_ids(
    self,
    probs: torch.Tensor,
    reqs: Iterable[Any],
    temperatures: torch.Tensor,
    top_ps: torch.Tensor,
    top_ks: torch.Tensor,
    min_ps: torch.Tensor,
    sampling_seed: Optional[torch.Tensor],
    positions: Optional[torch.Tensor],
) -> Optional[LvmGuidedSampleResult]:
    """Return only the token ids that LenVM must override.

    Rows skipped by entropy fall back to the normal SGLang sampler, preserving the
    native flashinfer top-k/top-p path instead of sampling the whole batch from a
    full-vocab PyTorch probability tensor.

    Args:
        probs: Probability tensor handed to ``_build_pending``.
        reqs: Requests of the current batch. Any iterable is accepted; it is
            materialised once because it is consumed twice below.
        temperatures, top_ps, top_ks, min_ps: Per-request sampling
            parameters, forwarded unchanged to ``_build_pending``.
        sampling_seed, positions: Forwarded to ``_sample_from_pending``.

    Returns:
        The result of ``_sample_from_pending``, or None when
        ``_build_pending`` reports that there is nothing to do.
    """
    # `reqs` is iterated by `_build_pending` and again by
    # `_sample_from_pending`, which derives the set of live request ids passed
    # to `clean_stale_requests`. A one-shot iterable would be empty on the
    # second pass, so snapshot it up front.
    if not isinstance(reqs, (list, tuple)):
        reqs = list(reqs)

    timer = get_timer()
    t_build = timer.section_start("t_lvm_build_pending_ms")
    pending = self._build_pending(
        probs,
        reqs,
        temperatures,
        top_ps,
        top_ks,
        min_ps,
        build_guided=False,
    )
    timer.section_end("t_lvm_build_pending_ms", t_build)

    if pending is None:
        # Still record the step so timing logs show "no guidance" explicitly.
        timer.set_meta(
            lvm_n_reqs_with_guidance=0,
            lvm_n_candidates=0,
            lvm_has_deterministic=False,
            lvm_has_fallback_sample=False,
        )
        return None
    return self._sample_from_pending(pending, reqs, sampling_seed, positions)
+2441,29 @@ def apply( if inproc not in (None, False): try: rids_send = [req.rid for req in reqs_send] + t_forward = timer.section_start("t_lvm_forward_ms") inproc.tree_value_extend(rids_send, pending.prefix_ids_send, reqs_send) gpu_emb = inproc.tree_value_launch_gpu( rids_send, pending.candidate_ids_send, gpu_candidates=pending.gpu_candidates ) gpu_embeddings = inproc.tree_value_collect_gpu(gpu_emb) + timer.section_end("t_lvm_forward_ms", t_forward) + t_guidance = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance_gpu(pending, gpu_embeddings) + timer.section_end("t_lvm_apply_guidance_ms", t_guidance) return pending.guided except Exception as exc: raise RuntimeError("LenVM GPU guidance path failed") from exc + t_forward = timer.section_start("t_lvm_forward_ms") lvm_values = self._post_tree_value( [req.rid for req in reqs_send], pending.prefix_ids_send, pending.candidate_ids_send, reqs_send ) + timer.section_end("t_lvm_forward_ms", t_forward) if lvm_values is None: return None + t_guidance = timer.section_start("t_lvm_apply_guidance_ms") self._apply_guidance(pending, lvm_values) + timer.section_end("t_lvm_apply_guidance_ms", t_guidance) return pending.guided diff --git a/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py b/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py index fe5f080..2b0287e 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py +++ b/sglang-LenVM/python/sglang/srt/lvm/lvm_inproc_runner.py @@ -13,6 +13,7 @@ from __future__ import annotations import logging +import os from typing import Dict, List, Optional, TYPE_CHECKING import torch @@ -21,6 +22,7 @@ TreeValueSpecInput, build_tree_value_custom_mask_and_positions, ) +from sglang.srt.lvm.timing import get_timer from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, @@ -120,7 +122,6 @@ def __init__(self, runner: "ModelRunner"): # the LenVM may have different hidden sizes, so sharing the global cache can # return wrong-shape 
def eval_prefix_and_candidates_batch_gpu(
    self,
    rids: List[str],
    prefix_ids_per_req: List[List[int]],
    candidate_ids_per_req: List[List[int]],
    gpu_candidates: Optional[tuple] = None,
) -> List[torch.Tensor]:
    """Evaluate uncached prefix suffixes and candidate tokens in one forward.

    This is exact tree-value evaluation, not an approximation:
    - uncached prefix suffix tokens are forwarded causally and kept in KV cache;
    - candidate tokens attend to the full prefix and only themselves;
    - candidate KV slots are freed after the forward.

    It replaces the old hot path of `extend_prefix_batch` followed by
    `eval_candidates_batch_gpu`, reducing one LenVM model launch per guided step.

    Args:
        rids: Request ids, one per row of the batch.
        prefix_ids_per_req: Full prefix token ids per request; only the part
            beyond the cached KV length is forwarded.
        candidate_ids_per_req: Candidate token ids per request. When
            ``gpu_candidates`` is given only the lengths are used.
        gpu_candidates: Optional 3-tuple whose second and third items are a
            padded id tensor and a boolean mask; the masked ids are used as
            the candidate inputs in flattened order.

    Returns:
        ``logits_output.embeddings`` from the forward pass, or ``[]`` when
        there is nothing to forward.

    Raises:
        RuntimeError: If the GPU candidate count disagrees with
            ``candidate_ids_per_req`` or the KV pool cannot supply the slots.
    """
    if not rids:
        return []

    runner = self.runner
    device = self.device

    pool_indices = [self.kv_mgr.get_or_alloc(rid) for rid in rids]
    cached_prefix_lens: List[int] = []
    full_prefix_lens: List[int] = []
    suffix_ids_per_req: List[List[int]] = []

    for rid, prefix_ids in zip(rids, prefix_ids_per_req):
        cached_len = self.kv_mgr.kv_len(rid)
        target_len = len(prefix_ids)
        if target_len < cached_len:
            # The cached KV is longer than the requested prefix: roll it back.
            self.kv_mgr.retract(rid, target_len)
            cached_len = target_len
        cached_prefix_lens.append(cached_len)
        full_prefix_lens.append(target_len)
        suffix_ids_per_req.append(prefix_ids[cached_len:])

    suffix_lens = [len(x) for x in suffix_ids_per_req]
    cand_lens = [len(cands) for cands in candidate_ids_per_req]
    extend_lens = [s + n for s, n in zip(suffix_lens, cand_lens)]
    seq_lens_list = [
        full_prefix_lens[i] + cand_lens[i] for i in range(len(rids))
    ]
    total_extend = sum(extend_lens)
    total_suffix = sum(suffix_lens)
    total_cands = sum(cand_lens)
    get_timer().set_meta(
        lvm_fused_suffix_tokens=total_suffix,
        lvm_fused_candidate_tokens=total_cands,
        lvm_fused_total_extend_tokens=total_extend,
        lvm_fused_max_suffix_len=max(suffix_lens) if suffix_lens else 0,
        lvm_fused_max_candidate_len=max(cand_lens) if cand_lens else 0,
    )

    if total_extend == 0:
        return []

    # Validate the GPU candidate ids BEFORE touching the KV allocator so a
    # mismatch cannot strand freshly allocated slots.
    gpu_candidate_ids = None
    if gpu_candidates is not None:
        _, gpu_ids, gpu_mask = gpu_candidates
        gpu_candidate_ids = gpu_ids[gpu_mask].to(torch.int64)
        if int(gpu_candidate_ids.numel()) != total_cands:
            raise RuntimeError(
                f"LenVM GPU candidate id count mismatch: "
                f"{gpu_candidate_ids.numel()} ids for {total_cands} candidates"
            )

    allocator = runner.token_to_kv_pool_allocator
    if getattr(allocator, "page_size", 1) == 1:
        out_cache_loc = allocator.alloc(total_extend)
    else:
        last_loc = torch.empty(len(pool_indices), dtype=torch.int64, device=device)
        for i, (pool_idx, p_len) in enumerate(zip(pool_indices, cached_prefix_lens)):
            last_loc[i] = (
                runner.req_to_token_pool.req_to_token[pool_idx, p_len - 1]
                if p_len > 0
                else -1
            )
        out_cache_loc = allocator.alloc_extend(
            torch.tensor(cached_prefix_lens, dtype=torch.int64, device=device),
            torch.tensor(cached_prefix_lens, dtype=torch.int64),
            torch.tensor(seq_lens_list, dtype=torch.int64, device=device),
            torch.tensor(seq_lens_list, dtype=torch.int64),
            last_loc,
            total_extend,
        )
    if out_cache_loc is None:
        raise RuntimeError(
            f"LVM KV pool OOM: cannot allocate {total_extend} slots for "
            "fused prefix/candidate evaluation. Consider "
            "--lvm-guided-inproc-mem-fraction-static."
        )
    out_cache_loc = out_cache_loc.to(torch.int64)

    candidate_cache_locs: List[torch.Tensor] = []
    forward_ok = False
    try:
        input_ids_flat: List[int] = []
        suffix_input_positions: List[int] = []
        suffix_input_ids: List[int] = []
        candidate_input_positions: List[int] = []
        pt = 0  # running offset into the flattened [suffix, candidates] layout
        for pool_idx, p_len, s_len, n_cand, suffix, cands in zip(
            pool_indices,
            cached_prefix_lens,
            suffix_lens,
            cand_lens,
            suffix_ids_per_req,
            candidate_ids_per_req,
        ):
            if s_len:
                runner.req_to_token_pool.write(
                    (pool_idx, slice(p_len, p_len + s_len)),
                    out_cache_loc[pt : pt + s_len],
                )
                input_ids_flat.extend(suffix)
                suffix_input_positions.extend(range(pt, pt + s_len))
                suffix_input_ids.extend(suffix)
            if n_cand:
                cand_start = pt + s_len
                cand_end = cand_start + n_cand
                cand_locs = out_cache_loc[cand_start:cand_end]
                runner.req_to_token_pool.write(
                    (pool_idx, slice(p_len + s_len, p_len + s_len + n_cand)),
                    cand_locs,
                )
                candidate_cache_locs.append(cand_locs)
                candidate_input_positions.extend(range(cand_start, cand_end))
                if gpu_candidates is None:
                    input_ids_flat.extend(cands)
            pt += s_len + n_cand

        if gpu_candidate_ids is not None:
            if total_suffix == 0:
                input_ids_t = gpu_candidate_ids
            else:
                # Suffix ids are host-side request history, but candidate ids are
                # already on GPU from top-k filtering. Assemble the flattened
                # [suffix, candidates] input without copying candidate ids through
                # Python lists.
                input_ids_t = torch.empty(
                    (total_extend,),
                    dtype=torch.int64,
                    device=device,
                )
                if suffix_input_ids:
                    suffix_pos_t = torch.tensor(
                        suffix_input_positions,
                        dtype=torch.long,
                        device=device,
                    )
                    suffix_ids_t = torch.tensor(
                        suffix_input_ids,
                        dtype=torch.int64,
                        device=device,
                    )
                    input_ids_t.index_copy_(0, suffix_pos_t, suffix_ids_t)
                if total_cands:
                    cand_pos_t = torch.tensor(
                        candidate_input_positions,
                        dtype=torch.long,
                        device=device,
                    )
                    input_ids_t.index_copy_(0, cand_pos_t, gpu_candidate_ids)
        else:
            input_ids_t = torch.tensor(
                input_ids_flat,
                dtype=torch.int64,
                device=device,
            )

        req_pool_indices_t = torch.tensor(pool_indices, dtype=torch.int64, device=device)
        seq_lens_t = torch.tensor(seq_lens_list, dtype=torch.int32, device=device)
        seq_lens_cpu_t = torch.tensor(seq_lens_list, dtype=torch.int32)
        extend_prefix_lens_t = torch.tensor(
            cached_prefix_lens, dtype=torch.int32, device=device
        )
        extend_seq_lens_t = torch.tensor(extend_lens, dtype=torch.int32, device=device)

        custom_mask, positions = build_tree_value_custom_mask_and_positions(
            prefix_lens=full_prefix_lens,
            candidate_lens=cand_lens,
            cached_prefix_lens=cached_prefix_lens,
            device=device,
        )
        spec_info = TreeValueSpecInput(
            custom_mask=custom_mask,
            positions=positions,
            tree_value_prefix_lens=list(full_prefix_lens),
            tree_value_candidate_lens=list(cand_lens),
            tree_value_cached_prefix_lens=list(cached_prefix_lens),
        )

        extend_start_loc = torch.zeros(len(rids), dtype=torch.int32, device=device)
        if len(rids) > 1:
            extend_start_loc[1:] = torch.cumsum(extend_seq_lens_t[:-1], dim=0)

        forward_batch = ForwardBatch(
            forward_mode=ForwardMode.EXTEND,
            batch_size=len(rids),
            input_ids=input_ids_t,
            req_pool_indices=req_pool_indices_t,
            seq_lens=seq_lens_t,
            seq_lens_cpu=seq_lens_cpu_t,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=sum(seq_lens_list),
            positions=positions,
            extend_num_tokens=total_extend,
            extend_seq_lens=extend_seq_lens_t,
            extend_prefix_lens=extend_prefix_lens_t,
            extend_start_loc=extend_start_loc,
            extend_prefix_lens_cpu=list(cached_prefix_lens),
            extend_seq_lens_cpu=list(extend_lens),
            req_to_token_pool=runner.req_to_token_pool,
            token_to_kv_pool=runner.token_to_kv_pool,
            attn_backend=runner.attn_backend,
            return_logprob=False,
            is_extend_in_batch=True,
            is_prefill_only=True,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            spec_info=spec_info,
            global_forward_mode=ForwardMode.EXTEND,
        )
        forward_batch.num_token_non_padded_cpu = total_extend

        with torch.inference_mode():
            logits_output = runner.forward_extend(forward_batch)
        forward_ok = True
    finally:
        if forward_ok:
            # Success: suffix KV stays cached, candidate KV must not be kept.
            if candidate_cache_locs:
                allocator.free(torch.cat(candidate_cache_locs))
        else:
            # Failure: `kv_lens` is never advanced, so the suffix slots would be
            # unreachable as well. Give back everything allocated above.
            allocator.free(out_cache_loc)

    for rid, s_len in zip(rids, suffix_lens):
        self.kv_mgr.kv_lens[rid] += s_len

    return logits_output.embeddings
class _Timer:
    """Collect named wall-clock sections and metadata for one sampler step.

    The timer is active only when the ``SGLANG_LVM_TIMING_LOG`` environment
    variable names a file. Each flushed step becomes one JSON line in that
    file, and running aggregates are written to ``<log>.summary.json``. When
    the variable is unset every public method returns immediately.

    ``SGLANG_LVM_TIMING_SUMMARY_INTERVAL`` (default 500, minimum 1) sets how
    many written steps pass between summary rewrites.
    ``SGLANG_LVM_TIMING_SKIP_STEPS`` (default 0) sets how many leading steps
    are discarded before recording starts. Non-integer values fall back to
    the defaults.

    All mutation of the per-step dictionary happens under ``self._lock`` so
    a flush cannot interleave with a section or metadata update.
    """

    def __init__(self) -> None:
        log_path = os.environ.get("SGLANG_LVM_TIMING_LOG")
        self.enabled = bool(log_path)
        self._log_path: Optional[str] = log_path
        self._fh = None
        self._step_id = 0  # number of records written so far
        self._seen_steps = 0  # number of steps flushed, including skipped ones
        self._lock = threading.Lock()
        self._current_step: dict[str, object] = {}
        self._totals: dict[str, float] = {}
        self._counts: dict[str, int] = {}
        self._meta_counts: dict[str, int] = {}
        self._summary_interval = max(
            self._env_int("SGLANG_LVM_TIMING_SUMMARY_INTERVAL", 500), 1
        )
        self._skip_steps = max(self._env_int("SGLANG_LVM_TIMING_SKIP_STEPS", 0), 0)
        if self.enabled:
            os.makedirs(os.path.dirname(self._log_path) or ".", exist_ok=True)
            # Line-buffered append; explicit encoding keeps the log identical
            # across platforms with different locale defaults.
            self._fh = open(self._log_path, "a", buffering=1, encoding="utf-8")
            atexit.register(self.close)

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer environment variable, falling back on bad input."""
        try:
            return int(os.environ.get(name, str(default)))
        except ValueError:
            return default

    def section_start(self, name: str) -> Optional[float]:
        """Return a start timestamp for ``name``, or None when disabled."""
        if not self.enabled:
            return None
        return time.perf_counter()

    def section_end(self, name: str, start: Optional[float]) -> None:
        """Add the milliseconds elapsed since ``start`` to section ``name``.

        Repeated sections with the same name inside one step accumulate.
        """
        if not self.enabled or start is None:
            return
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        with self._lock:
            prev = self._current_step.get(name, 0.0)
            if isinstance(prev, (int, float)):
                self._current_step[name] = float(prev) + elapsed_ms
            else:
                # A non-numeric value under this key is replaced, not summed.
                self._current_step[name] = elapsed_ms

    def set_meta(self, **kwargs) -> None:
        """Attach arbitrary key/value metadata to the current step."""
        if not self.enabled:
            return
        with self._lock:
            self._current_step.update(kwargs)

    def flush_step(self) -> None:
        """Write the current step as one JSONL record and reset it.

        Does nothing when disabled, closed, or when the step is empty.
        """
        if not self.enabled:
            return
        with self._lock:
            if self._fh is None or not self._current_step:
                return
            written = self._emit_step_locked()
            if written and self._step_id % self._summary_interval == 0:
                self._write_summary_locked()

    def _emit_step_locked(self) -> bool:
        """Consume the current step; return True if a record was written.

        Caller must hold ``self._lock``. Steps inside the skip window are
        counted and dropped.
        """
        self._seen_steps += 1
        written = self._seen_steps > self._skip_steps
        if written:
            self._step_id += 1
            record = {"step": self._step_id, **self._current_step}
            if self._fh is not None:
                self._fh.write(json.dumps(record) + "\n")
            self._update_summary_from_record_locked(record)
        self._current_step.clear()
        return written

    def _update_summary_from_record_locked(self, record: dict[str, object]) -> None:
        """Fold one record into the running aggregates (lock must be held)."""
        for key, value in record.items():
            if key == "step":
                continue
            # bool is tested first because it is a subclass of int.
            if isinstance(value, bool):
                self._meta_counts[key] = self._meta_counts.get(key, 0) + int(value)
            elif isinstance(value, (int, float)):
                self._totals[key] = self._totals.get(key, 0.0) + float(value)
                self._counts[key] = self._counts.get(key, 0) + 1

    def _write_summary_locked(self) -> None:
        """Rewrite ``<log>.summary.json`` from the aggregates (lock held)."""
        if not self._log_path:
            return
        timing_keys = sorted(key for key in self._totals if key.endswith("_ms"))
        numeric_keys = sorted(key for key in self._totals if not key.endswith("_ms"))
        summary = {
            "steps": self._step_id,
            "means_ms": {
                key: self._totals[key] / max(self._counts.get(key, 1), 1)
                for key in timing_keys
            },
            "means": {
                key: self._totals[key] / max(self._counts.get(key, 1), 1)
                for key in numeric_keys
            },
            "totals": {key: self._totals[key] for key in sorted(self._totals)},
            "counts": dict(sorted(self._meta_counts.items())),
        }
        summary_path = f"{self._log_path}.summary.json"
        with open(summary_path, "w", encoding="utf-8") as fh:
            json.dump(summary, fh, indent=2, sort_keys=True)

    def close(self) -> None:
        """Flush any pending step, write the summary and disable the timer.

        Safe to call more than once; later calls return immediately.
        """
        with self._lock:
            if not self.enabled:
                return
            if self._current_step:
                self._emit_step_locked()

            self._write_summary_locked()

            if self._fh is not None:
                self._fh.close()
                self._fh = None
            self.enabled = False
_INSTANCE = _Timer() + return _INSTANCE diff --git a/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py b/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py index c84101a..84a560e 100644 --- a/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py +++ b/sglang-LenVM/python/sglang/srt/lvm/tree_value_spec.py @@ -41,6 +41,46 @@ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: # No token multiplier for DP buffers. return 1, 1 + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + """Build FlashInfer prefill metadata for LenVM tree-value masks.""" + from sglang.srt.layers.attention.utils import ( + create_flashinfer_kv_indices_triton, + ) + + device = req_pool_indices.device + bs = len(req_pool_indices) + q_lens, _k_lens, _mask_offsets, _pos_offsets = self._per_req_qk_and_offsets() + + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum( + torch.tensor(q_lens[:bs], dtype=torch.int32, device=device), dim=0 + ) + + kv_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + kv_indptr[1:] = torch.cumsum(paged_kernel_lens[:bs], dim=0) + + kv_indices = torch.empty( + int(paged_kernel_lens_sum) + 256, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, kv_indptr, qo_indptr, self.custom_mask + def _per_req_qk_and_offsets(self) -> Tuple[List[int], List[int], List[int], List[int]]: """ Compute per-request: @@ -265,4 +305,3 @@ def build_tree_value_custom_mask_and_positions( custom_mask = torch.from_numpy(mask_buf).to(device=device, non_blocking=True) positions = torch.from_numpy(pos_buf).to(device=device, non_blocking=True) return custom_mask, positions - diff --git 
a/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py index 536042e..8c0ef67 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen2_5_vl_lvm.py @@ -67,7 +67,6 @@ def forward( return hidden_states token_values = self.v_head(hidden_states).squeeze(-1) - spec = getattr(forward_batch, "spec_info", None) prefix_lens = ( getattr(spec, "tree_value_prefix_lens", None) if spec is not None else None @@ -87,14 +86,14 @@ def forward( out: list[torch.Tensor] = [] offset = 0 for i, ext_len in enumerate(extend_lens): - vals_i = token_values[offset : offset + ext_len] - offset += ext_len - prefix_len = int(prefix_lens[i]) cand_len = int(cand_lens[i]) cached_prefix_len = int(cached_prefix_lens[i]) cand_offset = max(prefix_len - cached_prefix_len, 0) - out.append(vals_i[cand_offset : cand_offset + cand_len]) + out.append( + token_values[offset + cand_offset : offset + cand_offset + cand_len] + ) + offset += ext_len return EmbeddingPoolerOutput(embeddings=out) if forward_batch.extend_seq_lens is None: diff --git a/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py index 482f0a6..84b67d9 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen2_lvm.py @@ -96,14 +96,12 @@ def forward( out: list[torch.Tensor] = [] offset = 0 for i, ext_len in enumerate(extend_lens): - vals_i = token_values[offset : offset + ext_len] - offset += ext_len - L = int(prefix_lens[i]) N = int(cand_lens[i]) P = int(cached_prefix_lens[i]) cand_offset = max(L - P, 0) - out.append(vals_i[cand_offset : cand_offset + N]) + out.append(token_values[offset + cand_offset : offset + cand_offset + N]) + offset += ext_len return EmbeddingPoolerOutput(embeddings=out) # Otherwise (e.g., /encode), return tokenwise values for the forwarded tokens. 
@@ -200,4 +198,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = [ Qwen2ForLengthValueModel, ] - diff --git a/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py b/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py index f9b41ba..b0272a4 100644 --- a/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py +++ b/sglang-LenVM/python/sglang/srt/models/qwen3_lvm.py @@ -93,14 +93,12 @@ def forward( out: list[torch.Tensor] = [] offset = 0 for i, ext_len in enumerate(extend_lens): - vals_i = token_values[offset : offset + ext_len] - offset += ext_len - L = int(prefix_lens[i]) N = int(cand_lens[i]) P = int(cached_prefix_lens[i]) cand_offset = max(L - P, 0) - out.append(vals_i[cand_offset : cand_offset + N]) + out.append(token_values[offset + cand_offset : offset + cand_offset + N]) + offset += ext_len return EmbeddingPoolerOutput(embeddings=out) # Otherwise (e.g., /encode), return tokenwise values for the forwarded tokens.