diff --git a/astraflow/core/workflow/impl/rlvr.py b/astraflow/core/workflow/impl/rlvr.py index 7fd78c0..68a04f0 100644 --- a/astraflow/core/workflow/impl/rlvr.py +++ b/astraflow/core/workflow/impl/rlvr.py @@ -39,6 +39,7 @@ def default_get_input_ids_fn( tokenize=True, add_generation_prompt=True, enable_thinking=enable_thinking, + return_dict=False, ) return list(input_ids) diff --git a/astraflow/raas/api/cli_args.py b/astraflow/raas/api/cli_args.py index b79af81..97a020f 100644 --- a/astraflow/raas/api/cli_args.py +++ b/astraflow/raas/api/cli_args.py @@ -343,7 +343,6 @@ class SGLangConfig: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False - enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int | None = None diff --git a/astraflow/raas/engine/remote_inf_engine.py b/astraflow/raas/engine/remote_inf_engine.py index 3f3d6ea..0342e43 100644 --- a/astraflow/raas/engine/remote_inf_engine.py +++ b/astraflow/raas/engine/remote_inf_engine.py @@ -380,7 +380,7 @@ def check_health(self, base_url): health_req = self.backend.get_health_check_request() url = f"{base_url}{health_req.endpoint}" response = requests.request( - health_req.method, url, json=health_req.payload, timeout=5 + health_req.method, url, json=health_req.payload, timeout=20 ) return response.status_code == 200 except requests.exceptions.RequestException: @@ -748,8 +748,9 @@ def load_weights_from_path( For LoRA adapters (``use_lora=True``): unloads the old adapter, loads the new one, then flushes the KV cache via ``/flush_cache`` to discard stale entries computed with the old LoRA weights. - Requires ``LoRAAbortReleasePatch`` so that aborted requests - properly release their ``lora_registry`` counter. + Relies on sglang releasing the ``lora_registry`` counter for + aborted requests (fixed upstream in + ``TokenizerManager._handle_abort_finish_reason`` as of 0.5.12). 
""" import time as _time diff --git a/astraflow/raas/patch/__init__.py b/astraflow/raas/patch/__init__.py index 5d093ce..e291e04 100644 --- a/astraflow/raas/patch/__init__.py +++ b/astraflow/raas/patch/__init__.py @@ -83,14 +83,12 @@ def _validate_patch_results(results: Dict[str, bool], strict: bool) -> None: def _run_sglang_patches(strict: bool) -> bool: from astraflow.raas.patch.sglang import ( HttpServerPatch, - LoRAAbortReleasePatch, ServerArgsPatch, ) manager = PatchManager() manager.register(ServerArgsPatch()) manager.register(HttpServerPatch()) - manager.register(LoRAAbortReleasePatch()) results = manager.apply_all() _log_patch_results(results) diff --git a/astraflow/raas/patch/sglang.py b/astraflow/raas/patch/sglang.py index 0fd3c00..03481ec 100644 --- a/astraflow/raas/patch/sglang.py +++ b/astraflow/raas/patch/sglang.py @@ -7,8 +7,6 @@ can register with RaaS at startup. 2. HttpServerPatch — register SGLang instance with the rollout manager during ``launch_server``. -3. LoRAAbortReleasePatch — fix missing ``lora_registry.release()`` in the - abort path so that LoRA weight updates via abort+unload don't hang. """ import logging @@ -96,55 +94,3 @@ def patched_launch_server(server_args, *args, **kwargs): traceback.print_exc() return False - - -class LoRAAbortReleasePatch(BasePatch): - """Fix missing ``lora_registry.release()`` in the abort path. - - When a LoRA request is aborted from the waiting queue, SGLang's - ``_handle_abort_req`` does NOT call ``lora_registry.release()``, - leaking the ``ConcurrentCounter``. This causes - ``wait_for_unload()`` to hang forever when we try to swap LoRA - adapters via abort + unload. - - The normal completion path (``_handle_batch_output``) and the - scheduler error path both release correctly — only the waiting-queue - abort path is missing the call. - - This patch wraps ``_handle_abort_req`` to add the missing release, - mirroring the pattern at ``tokenizer_manager.py:1679-1680``. 
- """ - - def apply(self) -> bool: - try: - import asyncio - - from sglang.srt.managers.tokenizer_manager import TokenizerManager - - original = TokenizerManager._handle_abort_req - - if self._is_patched(original, "handle_abort_req"): - return True - - def patched_handle_abort_req(self_tm, recv_obj): - original(self_tm, recv_obj) - - # Release LoRA counter for aborted requests — mirrors the - # normal completion path at tokenizer_manager.py:1679-1680. - if self_tm.server_args.enable_lora: - state = self_tm.rid_to_state.get(recv_obj.rid) - if ( - state is not None - and getattr(state.obj, "lora_path", None) - ): - asyncio.create_task( - self_tm.lora_registry.release(state.obj.lora_id) - ) - - self._mark_as_patched(patched_handle_abort_req, "handle_abort_req") - TokenizerManager._handle_abort_req = patched_handle_abort_req - - return True - except Exception as e: - logger.error(f"LoRAAbortReleasePatch failed: {e}") - return False diff --git a/astraflow/raas/server/manager.py b/astraflow/raas/server/manager.py index e985f2c..6dcf672 100644 --- a/astraflow/raas/server/manager.py +++ b/astraflow/raas/server/manager.py @@ -1293,7 +1293,17 @@ def get_status(self) -> dict[str, Any]: # ------------------------------------------------------------------ _HEALTH_MONITOR_INTERVAL = 10.0 # seconds between checks - _HEALTH_MONITOR_MAX_FAILURES = 3 # consecutive failures before exit + # sglang 0.5.12's /health round-trips through the scheduler, which is + # saturated for ~30-40s during the initial unchunked prefill of ~2048 + # reqs/engine, so the old 3-strike (30s) watchdog false-positive-killed a + # busy-but-alive engine before the first rollout batch. A crashed engine + # refuses connections instantly, so dead-engine detection time is + # INTERVAL * MAX_FAILURES = ~50s here; the 20s probe timeout only extends + # cycles for an alive-but-slow engine (which we want to tolerate, up to + # ~100s worst case). 
5 strikes covers the ~35-40s prefill ramp (a slow but + # eventually-200 /health resets the counter) while catching a real death + # in ~50s. + _HEALTH_MONITOR_MAX_FAILURES = 5 # consecutive failures before exit # Maximum time a weight update is allowed to legitimately stall the # engine before the monitor force-probes anyway. A normal full pull + # apply + load runs ~60-70s end-to-end, deltas ~30-40s; 90s is a diff --git a/astraflow/raas/utils/network.py b/astraflow/raas/utils/network.py index 28a1a14..f88bb68 100644 --- a/astraflow/raas/utils/network.py +++ b/astraflow/raas/utils/network.py @@ -9,9 +9,12 @@ def gethostname(): def gethostip(): return socket.gethostbyname(socket.gethostname()) +_MAX_FREE_PORT = 55535 def find_free_ports( - count: int, port_range: tuple = (1024, 65535), exclude_ports: set[int] | None = None + count: int, + port_range: tuple = (1024, _MAX_FREE_PORT), + exclude_ports: set[int] | None = None, ) -> list[int]: """ Find multiple free ports within a specified range. 
diff --git a/astraflow/train_worker/api/cli_args.py b/astraflow/train_worker/api/cli_args.py index 8ff1be0..096324e 100644 --- a/astraflow/train_worker/api/cli_args.py +++ b/astraflow/train_worker/api/cli_args.py @@ -839,7 +839,6 @@ class SGLangConfig: disable_overlap_schedule: bool = False enable_mixed_chunk: bool = False enable_dp_attention: bool = False - enable_ep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: int | None = None diff --git a/astraflow/train_worker/tools/validation_base.py b/astraflow/train_worker/tools/validation_base.py index 79bf783..b5ea404 100644 --- a/astraflow/train_worker/tools/validation_base.py +++ b/astraflow/train_worker/tools/validation_base.py @@ -60,7 +60,7 @@ class BaseInstallationValidator: # Subclasses can override or extend this CUDA_SUBMODULES = { "torch": ["torch.cuda"], - "sglang": ["sgl_kernel", "sgl_kernel.flash_attn"], + "sglang": ["sglang_kernel", "sglang_kernel.flash_attn"], "vllm": ["vllm._C"], "flash-attn": ["flash_attn_2_cuda"], "megatron-core": [ diff --git a/astraflow/train_worker/utils/fsdp/__init__.py b/astraflow/train_worker/utils/fsdp/__init__.py index d0dfbe6..211755c 100644 --- a/astraflow/train_worker/utils/fsdp/__init__.py +++ b/astraflow/train_worker/utils/fsdp/__init__.py @@ -64,6 +64,8 @@ def apply_fsdp2(model, fsdp_kwargs, wrap_policy): if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] + else: + fsdp_transformer_layer_cls_to_wrap = list(fsdp_transformer_layer_cls_to_wrap) assert ( len(fsdp_transformer_layer_cls_to_wrap) > 0 diff --git a/docker/Dockerfile.sglang b/docker/Dockerfile.sglang index b2a1ef9..bbc821f 100644 --- a/docker/Dockerfile.sglang +++ b/docker/Dockerfile.sglang @@ -1,6 +1,6 @@ # Cache chain: basic → sglang # Layers up to "uv pip install -e ." are identical to Dockerfile.basic. 
-FROM nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04 +FROM nvidia/cuda:13.0.0-cudnn-devel-ubuntu24.04 SHELL ["/bin/bash", "-lc"] diff --git a/docker/README.md b/docker/README.md index f04c6dd..82e4eb3 100644 --- a/docker/README.md +++ b/docker/README.md @@ -13,7 +13,7 @@ | ------------------- | ------------------------------- | ---------------- | | `Dockerfile.sglang` | astraflow + SGLang + flash-attn | `-e ".[sglang]"` | -The image is based on `nvidia/cuda:12.9.1-cudnn-devel-ubuntu24.04` with Python 3.12 +The image is based on `nvidia/cuda:13.0.0-cudnn-devel-ubuntu24.04` with Python 3.12 managed by [uv](https://docs.astral.sh/uv/). ## Pull pre-built image diff --git a/docs/en/get-started/installation.md b/docs/en/get-started/installation.md index 621bfda..3ee45b9 100644 --- a/docs/en/get-started/installation.md +++ b/docs/en/get-started/installation.md @@ -19,31 +19,63 @@ conda activate astraflow ### Step 2: Install uv (fast pip replacement) ```bash -pip install uv +pip install -U "uv>=0.10" ``` +> **uv ≥ 0.10 is required.** `pyproject.toml` uses `[tool.uv]` settings +> (`extra-build-dependencies`, `override-dependencies`) that older uv +> releases don't recognize. When uv hits an unknown `[tool.uv]` key it +> silently ignores the *entire* `[tool.uv]` table, so the +> `transformers==5.6.1` override (which must beat sglang's `==5.6.0` pin) +> is dropped and the install fails with an unsolvable +> `transformers` conflict. The Docker images install the latest uv via the +> official installer and are unaffected. + ### Step 3: Install AstraFlow (core + dev tools) ```bash uv pip install -e ".[dev]" ``` -This installs all core dependencies (~260 packages) including PyTorch 2.8.0, -Transformers 4.57.1, Megatron-Core 0.13.1, Ray, W&B, and dev tools (pytest, ruff, +This installs all core dependencies (~260 packages) including PyTorch 2.11.0, +Transformers 5.6.1, Megatron-Core 0.13.1, Ray, W&B, and dev tools (pytest, ruff, ipython). 
### Step 4: Install Flash Attention and SGLang #### Flash Attention +This is FlashAttention-**2** (`import flash_attn`), used by the FSDP trainer. It +is excluded from uv resolution (see `pyproject.toml` `[tool.uv]`) and built from +source, so it needs the CUDA 13 toolchain and a roomy build-temp directory: + ```bash +# nvcc must be on PATH and match torch's CUDA (13.0 for torch 2.11+cu130) +export CUDA_HOME=/usr/local/cuda-13.0 +export PATH="$CUDA_HOME/bin:$PATH" + +# nvcc writes GBs of intermediate files to $TMPDIR. Point it at local scratch +# with plenty of space — NOT a small/NFS-quota'd home, or the build fails with +# "nvFatbin error: empty input" or "Disk quota exceeded" from truncated temps. +export TMPDIR=/tmp/fa-build && mkdir -p "$TMPDIR" + uv pip install "flash-attn==2.8.3" --no-build-isolation ``` +> On a single-GPU-arch box you can speed up the build and shrink its footprint +> with `FLASH_ATTN_CUDA_ARCHS=<arch> NVCC_THREADS=1` (e.g. `90` for H100, `80` +> for A100, `89` for L40/4090). These are optional — the real requirement is a +> roomy `TMPDIR`. + #### SGLang (inference backend) +Install via the project extra so uv applies the `[tool.uv]` overrides (the +`transformers==5.6.1` pin and the `flash-attn-4` pre-release allowance). SGLang +pulls in FlashAttention-**4** (`flash-attn-4`, a pre-release wheel) automatically +for its own attention backend — you do not install that one yourself. 
+ ```bash -uv pip install "sglang==0.5.5.post1" +uv pip install -e ".[sglang]" ``` ### Step 5: Verify installation diff --git a/pyproject.toml b/pyproject.toml index 80108de..1f6b4b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ # 4 - Beta # 5 - Production/Stable "Development Status :: 3 - Alpha", - "Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.9", + "Environment :: GPU :: NVIDIA CUDA :: 13", "Intended Audience :: Developers", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", @@ -41,17 +41,17 @@ classifiers = [ dependencies = [ # Core ML/AI libraries - "torch==2.8.0", + "torch==2.11.0", "torchaudio", "torchvision", "torchdata", "huggingface_hub", "datasets>=3.0.0", - "transformers==4.57.1", + "transformers==5.6.1", "megatron-core==0.13.1", "mbridge==0.13.0", - "torch_memory_saver==0.0.9", - "peft", + "torch_memory_saver==0.0.9.post1", + "peft>=0.18.0", "qwen_agent", "openai-agents", @@ -140,7 +140,7 @@ te = [ ] sglang = [ - "sglang==0.5.5.post1", + "sglang==0.5.12.post1", ] vllm = [ @@ -192,13 +192,27 @@ include = ["astraflow*"] exclude = ["tests*", "docs*", "examples*"] [tool.uv] +# sglang 0.5.12 depends on flash-attn-4>=4.0.0b9 (a pre-release wheel, pulled +# in automatically as a sglang dependency). Without this, `uv pip install +# -e ".[sglang]"` fails to resolve with "pre-releases weren't enabled". +prerelease = "allow" exclude-dependencies = ["flash-attn"] override-dependencies=[ "outlines-core==0.1.26", + # sglang 0.5.12 pins transformers==5.6.0, which has a flash-attention bug + # (unconditional s_aux.to() crashes non-sink models like Qwen3). 5.6.1 is + # a patch release that fixes it; override sglang's exact pin to pick it up. 
+ "transformers==5.6.1", + # sglang requires an unbounded "kernels", so uv resolves the latest (0.15+), + # but transformers 5.6.1 only supports kernels<0.13 (its hub_kernels module + # calls LayerRepository() without a revision/version, which 0.15 rejects -> + # `import sglang` crashes with "Either a revision or a version must be + # specified."). Pin to the range transformers 5.6.1 was built against. + "kernels>=0.12.0,<0.13", ] [tool.uv.extra-build-dependencies] -flash-attn = ["torch==2.8.0"] +flash-attn = ["torch==2.11.0"] [tool.pytest.ini_options] pythonpath = ["."]