Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 57 additions & 9 deletions sglang-LenVM/python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda, is_npu
from sglang.srt.lvm.lvm_guided_sampling import LvmGuidedSampler
from sglang.srt.lvm.timing import get_timer

if is_cuda():
from sgl_kernel import (
Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self, model_runner=None):
self.lvm_guided_sampler = LvmGuidedSampler.from_server_args(
get_global_server_args(), model_runner=model_runner
)
self._timer = get_timer()

if is_dp_attention_enabled():
self.tp_sync_group = get_attention_tp_group().device_group
Expand Down Expand Up @@ -134,16 +136,27 @@ def forward(
positions: The positions of the tokens in the sequence. Used for deterministic sampling
to get the unique seed for each position.
"""
timer = self._timer
t_forward = timer.section_start("t_sampler_total_ms")
guided_applied = False
batch_size_meta = int(logits_output.next_token_logits.shape[0])

logits = logits_output.next_token_logits

# Preprocess logits (custom processors and NaN handling)
t_preprocess = timer.section_start("t_preprocess_logits_ms")
logits = self._preprocess_logits(logits, sampling_info)
timer.section_end("t_preprocess_logits_ms", t_preprocess)

if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
t_sample = timer.section_start("t_sample_ms")
batch_next_token_ids = torch.argmax(logits, -1)
timer.section_end("t_sample_ms", t_sample)
if return_logprob:
t_logprob = timer.section_start("t_logprob_ms")
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
timer.section_end("t_logprob_ms", t_logprob)
else:
can_sample_directly_from_probs = (
not sampling_info.need_top_p_sampling
Expand All @@ -164,6 +177,7 @@ def forward(
)

# Post process logits
t_pre_lvm = timer.section_start("t_pre_lvm_ms")
logits.div_(sampling_info.temperatures)

# Per-token temperature scaling: for boosted tokens, divide their
Expand All @@ -178,34 +192,48 @@ def forward(
return_logprob and not SGLANG_RETURN_ORIGINAL_LOGPROB
):
logits[:] = torch.softmax(logits, dim=-1)
timer.section_end("t_pre_lvm_ms", t_pre_lvm)
probs = logits
del logits

guided_applied = False
guided_sample_result = None
if self.lvm_guided_sampler is not None and sampling_info.reqs is not None:
guided_probs = self.lvm_guided_sampler.apply(
t_lvm = timer.section_start("t_lvm_apply_outer_ms")
guided_sample_result = self.lvm_guided_sampler.sample_token_ids(
probs,
sampling_info.reqs,
sampling_info.temperatures,
sampling_info.top_ps,
sampling_info.top_ks,
sampling_info.min_ps,
sampling_info.sampling_seed,
positions,
)
if guided_probs is not None:
probs = guided_probs
guided_applied = True

if guided_applied:
can_sample_directly_from_probs = True
timer.section_end("t_lvm_apply_outer_ms", t_lvm)

if can_sample_directly_from_probs:
if (
guided_sample_result is not None
and guided_sample_result.row_indices.numel() == probs.shape[0]
):
batch_next_token_ids = torch.empty(
probs.shape[0], dtype=torch.int32, device=probs.device
)
batch_next_token_ids[guided_sample_result.row_indices] = (
guided_sample_result.token_ids
)
guided_sample_result = None
guided_applied = True
elif can_sample_directly_from_probs:
# when we don't need top-k, top-p, or min-p sampling, we can directly sample from the probs
t_sample = timer.section_start("t_sample_ms")
batch_next_token_ids = sampling_from_probs_torch(
probs,
sampling_seed=sampling_info.sampling_seed,
positions=positions,
)
timer.section_end("t_sample_ms", t_sample)
else:
t_sample = timer.section_start("t_sample_ms")
if get_global_server_args().sampling_backend == "flashinfer":
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
Expand Down Expand Up @@ -244,8 +272,18 @@ def forward(
raise ValueError(
f"Invalid sampling backend: {get_global_server_args().sampling_backend}"
)
timer.section_end("t_sample_ms", t_sample)

if guided_sample_result is not None:
t_override = timer.section_start("t_guided_override_ms")
batch_next_token_ids[guided_sample_result.row_indices] = (
guided_sample_result.token_ids
)
timer.section_end("t_guided_override_ms", t_override)
guided_applied = True

if return_logprob:
t_logprob = timer.section_start("t_logprob_ms")
if get_global_server_args().rl_on_policy_target is not None:
logprobs = logprobs_via_logsoftmax_kernel
del logprobs_via_logsoftmax_kernel
Expand All @@ -257,6 +295,7 @@ def forward(
del probs_without_temp_scaling
else:
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
timer.section_end("t_logprob_ms", t_logprob)

# Attach logprobs to logits_output (in-place modification)
if return_logprob:
Expand Down Expand Up @@ -285,12 +324,21 @@ def forward(
# In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
# When using xgrammar, this becomes more likely so we also do the sync when grammar is used.

t_tp_sync = timer.section_start("t_tp_sync_ms")
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=self.tp_sync_group,
)
timer.section_end("t_tp_sync_ms", t_tp_sync)

timer.section_end("t_sampler_total_ms", t_forward)
timer.set_meta(
lvm_active=bool(guided_applied),
batch_size=batch_size_meta,
is_greedy=bool(sampling_info.is_all_greedy),
)
timer.flush_step()
return batch_next_token_ids

def compute_logprobs_only(
Expand Down
174 changes: 174 additions & 0 deletions sglang-LenVM/python/sglang/srt/lvm/entropy_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Triton helpers for LenVM entropy-based guidance skips."""

from __future__ import annotations

import os
from typing import Optional

import torch


# Opt-out switch: setting LVM_DISABLE_TRITON_ENTROPY to 1/true/yes/on
# (case-insensitive) forces the pure-PyTorch entropy implementation.
_DISABLE_TRITON_ENTROPY = os.environ.get("LVM_DISABLE_TRITON_ENTROPY", "").lower() in (
    "1",
    "true",
    "yes",
    "on",
)

# Tri-state Triton status: None = not decided yet, True = usable,
# False = import or kernel launch failed, so callers use the fallback.
_TRITON_AVAILABLE: Optional[bool] = None


def _fallback_entropy_from_probs(
    probs: torch.Tensor,
    rows: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Pure-PyTorch entropy of each (optionally selected) row of ``probs``.

    Rows are gathered with ``index_select`` when ``rows`` is given, promoted
    to float32, renormalized by their own sum (clamped away from zero so an
    all-zero row yields entropy 0 instead of NaN), and reduced with
    ``torch.special.entr`` over the last dimension.
    """
    selected = probs if rows is None else probs.index_select(0, rows)
    selected = selected.float()
    row_mass = selected.sum(dim=-1).clamp(min=1e-20)
    normalized = selected / row_mass.view(-1, 1)
    return torch.special.entr(normalized).sum(dim=-1)


def _next_power_of_2(x: int) -> int:
    """Return the smallest power of two >= ``x``; inputs below 1 yield 1."""
    value = int(x)
    if value < 1:
        value = 1
    # (value - 1).bit_length() is the exponent of the next power of two.
    return 1 << (value - 1).bit_length()


def full_vocab_entropy_from_probs(
    probs: torch.Tensor,
    rows: Optional[torch.Tensor] = None,
    *,
    block_size: int = 4096,
) -> torch.Tensor:
    """Compute the entropy of each row's full-vocabulary distribution.

    The input is a probability tensor, not logits. Each row is renormalized
    by its full-vocab sum, so the value returned per row is

        H(p / sum(p)) = log(sum(p)) - sum(p * log(p)) / sum(p)

    On the Triton path accumulation is FP32 and neither a selected
    [n_rows, vocab] copy nor an intermediate `entr` tensor is materialized.

    Args:
        probs: Probability tensor. The Triton path is taken only for 2-D CUDA
            tensors whose last dimension is non-empty and has unit stride;
            every other input is handled by the PyTorch fallback.
        rows: Optional indices into dim 0 of ``probs``. When given, one
            entropy is produced per index, in index order; otherwise one per
            row of ``probs``.
        block_size: Number of vocabulary entries handled per Triton program.
            Non-positive values fall back to 4096, and the value is rounded
            up to a power of two.

    Returns:
        A float32 tensor with one entropy per selected row.
    """
    global _TRITON_AVAILABLE

    if (
        _DISABLE_TRITON_ENTROPY
        # Set to False when the module-level Triton import or an earlier
        # kernel launch failed.
        or _TRITON_AVAILABLE is False
        or not probs.is_cuda
        or probs.dim() != 2
        or probs.shape[-1] == 0
        # The kernel addresses elements as `row * stride(0) + col`, which is
        # only correct when the vocab dimension is densely packed.
        or probs.stride(-1) != 1
    ):
        return _fallback_entropy_from_probs(probs, rows)

    _TRITON_AVAILABLE = True

    vocab_size = int(probs.shape[1])
    has_rows = rows is not None
    if has_rows:
        # The kernel reads indices as a dense int64 array, so normalize
        # dtype/device and drop any striding from a sliced index tensor.
        rows_t = rows.to(device=probs.device, dtype=torch.int64).contiguous()
        n_rows = int(rows_t.numel())
    else:
        # Placeholder pointer; never dereferenced when HAS_ROWS is False.
        rows_t = torch.empty((0,), dtype=torch.int64, device=probs.device)
        n_rows = int(probs.shape[0])
    if n_rows == 0:
        return torch.empty((0,), dtype=torch.float32, device=probs.device)

    block_size = int(block_size)
    if block_size <= 0:
        block_size = 4096
    # tl.arange requires a power-of-two extent. Rounding up here keeps an odd
    # caller-supplied value from failing compilation and thereby disabling
    # the Triton path for the rest of the process.
    block_size = _next_power_of_2(block_size)
    n_blocks = (vocab_size + block_size - 1) // block_size
    reduce_block = _next_power_of_2(n_blocks)

    # Per-(row, block) partial results, combined by the second kernel.
    partial_sum = torch.empty(
        (n_rows, n_blocks), dtype=torch.float32, device=probs.device
    )
    partial_p_log_p = torch.empty_like(partial_sum)
    out = torch.empty((n_rows,), dtype=torch.float32, device=probs.device)

    try:
        _entropy_partial_kernel[(n_rows, n_blocks)](
            probs,
            rows_t,
            partial_sum,
            partial_p_log_p,
            vocab_size,
            int(probs.stride(0)),
            n_blocks,
            HAS_ROWS=has_rows,
            BLOCK_SIZE=block_size,
            num_warps=8,
        )
        _entropy_reduce_kernel[(n_rows,)](
            partial_sum,
            partial_p_log_p,
            out,
            n_blocks,
            BLOCK_N=reduce_block,
            num_warps=8,
        )
        return out
    except Exception:
        # Best effort: a failing kernel must not break sampling, so switch to
        # the PyTorch path for good. Log it, because the switch is permanent
        # and would otherwise be invisible. This fires at most once since the
        # flag short-circuits later calls.
        import logging

        logging.getLogger(__name__).warning(
            "Triton entropy kernel failed; falling back to the PyTorch "
            "implementation for the rest of this process.",
            exc_info=True,
        )
        _TRITON_AVAILABLE = False
        return _fallback_entropy_from_probs(probs, rows)


try:
    import triton
    import triton.language as tl

    # Stage 1: each program handles one (row, vocab block) pair and writes
    # the block's sum(p) and sum(p * log(p)) into the partial buffers at
    # [row_pos, block_id].
    @triton.jit
    def _entropy_partial_kernel(
        probs,
        rows,
        partial_sum,
        partial_p_log_p,
        vocab_size: tl.constexpr,
        row_stride: tl.constexpr,
        n_blocks: tl.constexpr,
        HAS_ROWS: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
    ):
        row_pos = tl.program_id(0)
        block_id = tl.program_id(1)
        if HAS_ROWS:
            # Indirect row selection; the caller passes int64 indices.
            row = tl.load(rows + row_pos)
        else:
            # program_id is 32-bit; widen before multiplying by the row
            # stride so the element offset cannot overflow on inputs with
            # more than 2**31 elements.
            row = row_pos.to(tl.int64)
        offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < vocab_size
        # Out-of-range lanes read 0.0, which contributes nothing to either sum.
        p = tl.load(probs + row * row_stride + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        # Define 0 * log(0) as 0 instead of letting it become NaN.
        p_log_p = tl.where(p > 0.0, p * tl.log(p), 0.0)
        out_offset = row_pos * n_blocks + block_id
        tl.store(partial_sum + out_offset, tl.sum(p, axis=0))
        tl.store(partial_p_log_p + out_offset, tl.sum(p_log_p, axis=0))

    # Stage 2: one program per row combines that row's partials into
    # log(sum(p)) - sum(p * log(p)) / sum(p); rows with zero mass yield 0.
    @triton.jit
    def _entropy_reduce_kernel(
        partial_sum,
        partial_p_log_p,
        out,
        n_blocks: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        row = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_N)
        # BLOCK_N may exceed n_blocks; mask off the padding lanes.
        mask = offsets < n_blocks
        s = tl.load(partial_sum + row * n_blocks + offsets, mask=mask, other=0.0)
        plogp = tl.load(
            partial_p_log_p + row * n_blocks + offsets, mask=mask, other=0.0
        )
        sum_p = tl.sum(s, axis=0)
        sum_p_log_p = tl.sum(plogp, axis=0)
        ent = tl.where(sum_p > 0.0, tl.log(sum_p) - sum_p_log_p / sum_p, 0.0)
        tl.store(out + row, ent)

except Exception:
    # Import-time fallback for CPU-only environments and Triton import issues.
    _TRITON_AVAILABLE = False
Loading