Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions astraflow/core/workflow/impl/rlvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def default_get_input_ids_fn(
tokenize=True,
add_generation_prompt=True,
enable_thinking=enable_thinking,
return_dict=False,
)
return list(input_ids)

Expand Down
1 change: 0 additions & 1 deletion astraflow/raas/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ class SGLangConfig:
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: int | None = None
Expand Down
7 changes: 4 additions & 3 deletions astraflow/raas/engine/remote_inf_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def check_health(self, base_url):
health_req = self.backend.get_health_check_request()
url = f"{base_url}{health_req.endpoint}"
response = requests.request(
health_req.method, url, json=health_req.payload, timeout=5
health_req.method, url, json=health_req.payload, timeout=20
)
return response.status_code == 200
except requests.exceptions.RequestException:
Expand Down Expand Up @@ -748,8 +748,9 @@ def load_weights_from_path(
For LoRA adapters (``use_lora=True``): unloads the old adapter,
loads the new one, then flushes the KV cache via ``/flush_cache``
to discard stale entries computed with the old LoRA weights.
Requires ``LoRAAbortReleasePatch`` so that aborted requests
properly release their ``lora_registry`` counter.
Relies on sglang releasing the ``lora_registry`` counter for
aborted requests (fixed upstream in
``TokenizerManager._handle_abort_finish_reason`` as of 0.5.12).
"""
import time as _time

Expand Down
2 changes: 0 additions & 2 deletions astraflow/raas/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,12 @@ def _validate_patch_results(results: Dict[str, bool], strict: bool) -> None:
def _run_sglang_patches(strict: bool) -> bool:
from astraflow.raas.patch.sglang import (
HttpServerPatch,
LoRAAbortReleasePatch,
ServerArgsPatch,
)

manager = PatchManager()
manager.register(ServerArgsPatch())
manager.register(HttpServerPatch())
manager.register(LoRAAbortReleasePatch())

results = manager.apply_all()
_log_patch_results(results)
Expand Down
54 changes: 0 additions & 54 deletions astraflow/raas/patch/sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
can register with RaaS at startup.
2. HttpServerPatch — register SGLang instance with the rollout manager
during ``launch_server``.
3. LoRAAbortReleasePatch — fix missing ``lora_registry.release()`` in the
abort path so that LoRA weight updates via abort+unload don't hang.
"""

import logging
Expand Down Expand Up @@ -96,55 +94,3 @@ def patched_launch_server(server_args, *args, **kwargs):

traceback.print_exc()
return False


class LoRAAbortReleasePatch(BasePatch):
    """Fix missing ``lora_registry.release()`` in the abort path.

    When a LoRA request is aborted from the waiting queue, SGLang's
    ``_handle_abort_req`` does NOT call ``lora_registry.release()``,
    leaking the ``ConcurrentCounter``. This causes
    ``wait_for_unload()`` to hang forever when we try to swap LoRA
    adapters via abort + unload.

    The normal completion path (``_handle_batch_output``) and the
    scheduler error path both release correctly — only the waiting-queue
    abort path is missing the call.

    This patch wraps ``_handle_abort_req`` to add the missing release,
    mirroring the pattern at ``tokenizer_manager.py:1679-1680``.
    """

    def apply(self) -> bool:
        """Install the wrapper on ``TokenizerManager._handle_abort_req``.

        Returns:
            ``True`` when the wrapper was installed, or when
            ``_is_patched`` reports that it already was; ``False`` when
            importing SGLang or installing the wrapper raised.
        """
        try:
            import asyncio

            from sglang.srt.managers.tokenizer_manager import TokenizerManager

            original = TokenizerManager._handle_abort_req

            if self._is_patched(original, "handle_abort_req"):
                return True

            # The event loop only keeps weak references to tasks, so a
            # fire-and-forget ``create_task`` result can be garbage-collected
            # before it runs. Hold a strong reference to each release task
            # until it completes; otherwise the counter could still leak.
            pending_releases = set()

            def patched_handle_abort_req(self_tm, recv_obj):
                original(self_tm, recv_obj)

                # Release LoRA counter for aborted requests — mirrors the
                # normal completion path at tokenizer_manager.py:1679-1680.
                if not self_tm.server_args.enable_lora:
                    return
                state = self_tm.rid_to_state.get(recv_obj.rid)
                if state is None or not getattr(state.obj, "lora_path", None):
                    return

                release_task = asyncio.create_task(
                    self_tm.lora_registry.release(state.obj.lora_id)
                )
                pending_releases.add(release_task)
                # Drop the reference once the release has finished.
                release_task.add_done_callback(pending_releases.discard)

            self._mark_as_patched(patched_handle_abort_req, "handle_abort_req")
            TokenizerManager._handle_abort_req = patched_handle_abort_req

            return True
        except Exception as e:
            # Best-effort patch: report the failure (with traceback) and let
            # the caller decide what to do with the ``False`` result.
            logger.exception("LoRAAbortReleasePatch failed: %s", e)
            return False
12 changes: 11 additions & 1 deletion astraflow/raas/server/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,17 @@ def get_status(self) -> dict[str, Any]:
# ------------------------------------------------------------------

_HEALTH_MONITOR_INTERVAL = 10.0 # seconds between checks
_HEALTH_MONITOR_MAX_FAILURES = 3 # consecutive failures before exit
# sglang 0.5.12's /health round-trips through the scheduler, which is
# saturated for ~30-40s during the initial unchunked prefill of ~2048
# reqs/engine. The old 3-strike (30s) watchdog therefore wrongly killed a
# busy-but-alive engine before the first rollout batch. A crashed engine
# refuses connections instantly, so dead-engine detection time is
# INTERVAL * MAX_FAILURES = ~50s here; the 20s probe timeout only extends
# cycles for an alive-but-slow engine (which we want to tolerate, up to
# ~100s worst case). 5 strikes covers the ~35-40s prefill ramp (a slow but
# eventually-200 /health resets the counter) while still catching a real
# death in ~50s.
_HEALTH_MONITOR_MAX_FAILURES = 5 # consecutive failures before exit
# Maximum time a weight update is allowed to legitimately stall the
# engine before the monitor force-probes anyway. A normal full pull +
# apply + load runs ~60-70s end-to-end, deltas ~30-40s; 90s is a
Expand Down
5 changes: 4 additions & 1 deletion astraflow/raas/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ def gethostname():
def gethostip():
    """Return the IPv4 address that the local hostname resolves to."""
    local_name = socket.gethostname()
    return socket.gethostbyname(local_name)

_MAX_FREE_PORT = 55535

def find_free_ports(
count: int, port_range: tuple = (1024, 65535), exclude_ports: set[int] | None = None
count: int,
port_range: tuple = (1024, _MAX_FREE_PORT),
exclude_ports: set[int] | None = None,
) -> list[int]:
"""
Find multiple free ports within a specified range.
Expand Down
1 change: 0 additions & 1 deletion astraflow/train_worker/api/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,6 @@ class SGLangConfig:
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_ep_moe: bool = False
enable_torch_compile: bool = False
torch_compile_max_bs: int = 32
cuda_graph_max_bs: int | None = None
Expand Down
2 changes: 1 addition & 1 deletion astraflow/train_worker/tools/validation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class BaseInstallationValidator:
# Subclasses can override or extend this
CUDA_SUBMODULES = {
"torch": ["torch.cuda"],
"sglang": ["sgl_kernel", "sgl_kernel.flash_attn"],
"sglang": ["sglang_kernel", "sglang_kernel.flash_attn"],
"vllm": ["vllm._C"],
"flash-attn": ["flash_attn_2_cuda"],
"megatron-core": [
Expand Down
2 changes: 2 additions & 0 deletions astraflow/train_worker/utils/fsdp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def apply_fsdp2(model, fsdp_kwargs, wrap_policy):

if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
else:
fsdp_transformer_layer_cls_to_wrap = list(fsdp_transformer_layer_cls_to_wrap)

assert (
len(fsdp_transformer_layer_cls_to_wrap) > 0
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile.sglang
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Cache chain: basic → sglang
# Layers up to "uv pip install -e ." are identical to Dockerfile.basic.
FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04
FROM nvidia/cuda:13.0.0-cudnn-devel-ubuntu24.04

SHELL ["/bin/bash", "-lc"]

Expand Down
2 changes: 1 addition & 1 deletion docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
| ------------------- | ------------------------------- | ---------------- |
| `Dockerfile.sglang` | astraflow + SGLang + flash-attn | `-e ".[sglang]"` |

The image is based on `nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04` with Python 3.12
The image is based on `nvidia/cuda:13.0.0-cudnn-devel-ubuntu24.04` with Python 3.12
managed by [uv](https://docs.astral.sh/uv/).

## Pull pre-built image
Expand Down
40 changes: 36 additions & 4 deletions docs/en/get-started/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,63 @@ conda activate astraflow
### Step 2: Install uv (fast pip replacement)

```bash
pip install uv
pip install -U "uv>=0.10"
```

> **uv ≥ 0.10 is required.** `pyproject.toml` uses `[tool.uv]` settings
> (`extra-build-dependencies`, `override-dependencies`) that older uv
> releases don't recognize. When uv hits an unknown `[tool.uv]` key it
> silently ignores the *entire* `[tool.uv]` table, so the
> `transformers==5.6.1` override (which must beat sglang's `==5.6.0` pin)
> is dropped and the install fails with an unsolvable
> `transformers` conflict. The Docker images install the latest uv via the
> official installer and are unaffected.

### Step 3: Install AstraFlow (core + dev tools)

```bash
uv pip install -e ".[dev]"
```

This installs all core dependencies (~260 packages) including PyTorch 2.8.0,
Transformers 4.57.1, Megatron-Core 0.13.1, Ray, W&B, and dev tools (pytest, ruff,
This installs all core dependencies (~260 packages) including PyTorch 2.11.0,
Transformers 5.6.1, Megatron-Core 0.13.1, Ray, W&B, and dev tools (pytest, ruff,
ipython).

### Step 4: Install Flash Attention and SGLang

#### Flash Attention

This is FlashAttention-**2** (`import flash_attn`), used by the FSDP trainer. It
is excluded from uv resolution (see `pyproject.toml` `[tool.uv]`) and built from
source, so it needs the CUDA 13 toolchain and a roomy build-temp directory:

```bash
# nvcc must be on PATH and match torch's CUDA (13.0 for torch 2.11+cu130)
export CUDA_HOME=/usr/local/cuda-13.0
export PATH="$CUDA_HOME/bin:$PATH"

# nvcc writes GBs of intermediate files to $TMPDIR. Point it at local scratch
# with plenty of space — NOT a small/NFS-quota'd home, or the build fails with
# "nvFatbin error: empty input" or "Disk quota exceeded" from truncated temps.
export TMPDIR=/tmp/fa-build && mkdir -p "$TMPDIR"

uv pip install "flash-attn==2.8.3" --no-build-isolation
```

> On a single-GPU-arch box you can speed up the build and shrink its footprint
> with `FLASH_ATTN_CUDA_ARCHS=<arch> NVCC_THREADS=1` (e.g. `90` for H100, `80`
> for A100, `89` for L40/4090). These are optional — the real requirement is a
> roomy `TMPDIR`.

#### SGLang (inference backend)

Install via the project extra so uv applies the `[tool.uv]` overrides (the
`transformers==5.6.1` pin and the `flash-attn-4` pre-release allowance). SGLang
pulls in FlashAttention-**4** (`flash-attn-4`, a pre-release wheel) automatically
for its own attention backend — you do not install that one yourself.

```bash
uv pip install "sglang==0.5.5.post1"
uv pip install -e ".[sglang]"
```

### Step 5: Verify installation
Expand Down
28 changes: 21 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
# 4 - Beta
# 5 - Production/Stable
"Development Status :: 3 - Alpha",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.9",
"Environment :: GPU :: NVIDIA CUDA :: 13",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
Expand All @@ -41,17 +41,17 @@ classifiers = [

dependencies = [
# Core ML/AI libraries
"torch==2.8.0",
"torch==2.11.0",
"torchaudio",
"torchvision",
"torchdata",
"huggingface_hub",
"datasets>=3.0.0",
"transformers==4.57.1",
"transformers==5.6.1",
"megatron-core==0.13.1",
"mbridge==0.13.0",
"torch_memory_saver==0.0.9",
"peft",
"torch_memory_saver==0.0.9.post1",
"peft>=0.18.0",
"qwen_agent",
"openai-agents",

Expand Down Expand Up @@ -140,7 +140,7 @@ te = [
]

sglang = [
"sglang==0.5.5.post1",
"sglang==0.5.12.post1",
]

vllm = [
Expand Down Expand Up @@ -192,13 +192,27 @@ include = ["astraflow*"]
exclude = ["tests*", "docs*", "examples*"]

[tool.uv]
# sglang 0.5.12 depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, pulled
# in automatically as a sglang dependency). Without this, `uv pip install
# -e ".[sglang]"` fails to resolve with "pre-releases weren't enabled".
prerelease = "allow"
exclude-dependencies = ["flash-attn"]
override-dependencies=[
"outlines-core==0.1.26",
# sglang 0.5.12 pins transformers==5.6.0, which has a flash-attention bug
# (unconditional s_aux.to() crashes non-sink models like Qwen3). 5.6.1 is
# a patch release that fixes it; override sglang's exact pin to pick it up.
"transformers==5.6.1",
# sglang requires an unbounded "kernels", so uv resolves the latest (0.15+),
# but transformers 5.6.1 only supports kernels<0.13 (its hub_kernels module
# calls LayerRepository() without a revision/version, which 0.15 rejects ->
# `import sglang` crashes with "Either a revision or a version must be
# specified."). Pin to the range transformers 5.6.1 was built against.
"kernels>=0.12.0,<0.13",
]

[tool.uv.extra-build-dependencies]
flash-attn = ["torch==2.8.0"]
flash-attn = ["torch==2.11.0"]

[tool.pytest.ini_options]
pythonpath = ["."]
Expand Down
Loading