def _get_mtplx_info():
    """Summarise MTPLX availability for the system-info payload.

    Returns:
        A dict with two keys: ``available`` (the first element of the tuple
        returned by ``_detect_mtplx()``) and ``supportedModels`` (every key
        of ``MTP_MODEL_MAP`` followed by every key of ``_MTP_ALIASES``).
    """
    # Imports stay function-local, as in the original helper.
    from backend_service.inference._mtp import MTP_MODEL_MAP, _MTP_ALIASES
    from backend_service.inference.capabilities import _detect_mtplx

    # Only the availability flag is reported; the python path is not needed.
    is_available = _detect_mtplx()[0]
    return {
        "available": is_available,
        "supportedModels": [*MTP_MODEL_MAP, *_MTP_ALIASES],
    }
# ---------------------------------------------------------------------------
# MLX / transformers repos with baked-in MTP heads
# ---------------------------------------------------------------------------
#
# Key   → canonical HuggingFace repo id (case-sensitive)
# Value → recommended spec-draft-n-max (1–3); start conservatively at 1
#         and bump when acceptance rate benchmarks justify it.

MTP_MODEL_MAP: dict[str, int] = {
    # ----- Youssofal MTPLX-Optimized (upstream-verified for MTPLX v0.3.5) -----
    "Youssofal/Qwen3.6-27B-MTPLX-Optimized-Speed": 1,
    "Youssofal/Qwen3.6-27B-MTPLX-Optimized-Speed-FP16": 1,
    "Youssofal/Qwen3.6-27B-MTPLX-Optimized-Quality": 1,
    # ----- Qwen3.5 family -----
    "Qwen/Qwen3.5-4B": 1,
    "Qwen/Qwen3.5-7B": 1,
    "Qwen/Qwen3.5-9B": 1,
    "Qwen/Qwen3.5-14B": 1,
    "Qwen/Qwen3.5-27B": 1,
    "Qwen/Qwen3.5-35B-A3B": 1,
    "Qwen/Qwen3.5-122B-A10B": 1,
    # ----- Qwen3.6 family -----
    "Qwen/Qwen3.6-27B": 1,
    "Qwen/Qwen3.6-35B-A3B": 1,
    # ----- Qwen3-Coder-Next -----
    "Qwen/Qwen3-Coder-Next": 1,
    # ----- DeepSeek V3 / R1 -----
    "deepseek-ai/DeepSeek-V3": 1,
    "deepseek-ai/DeepSeek-V3-0324": 1,
    "deepseek-ai/DeepSeek-R1": 1,
}

# Community conversions (MLX quants, plus one GGUF repo) that preserve MTP
# heads. Maps community repo → canonical repo (for draft-n lookup). Every
# value must be a key of MTP_MODEL_MAP, otherwise the alias resolves to None.
_MTP_ALIASES: dict[str, str] = {
    # Qwen3.5
    "mlx-community/Qwen3.5-4B-4bit": "Qwen/Qwen3.5-4B",
    "mlx-community/Qwen3.5-4B-8bit": "Qwen/Qwen3.5-4B",
    "mlx-community/Qwen3.5-7B-4bit": "Qwen/Qwen3.5-7B",
    "mlx-community/Qwen3.5-7B-8bit": "Qwen/Qwen3.5-7B",
    "mlx-community/Qwen3.5-9B-4bit": "Qwen/Qwen3.5-9B",
    "mlx-community/Qwen3.5-9B-8bit": "Qwen/Qwen3.5-9B",
    "mlx-community/Qwen3.5-14B-4bit": "Qwen/Qwen3.5-14B",
    "mlx-community/Qwen3.5-14B-8bit": "Qwen/Qwen3.5-14B",
    "mlx-community/Qwen3.5-27B-4bit": "Qwen/Qwen3.5-27B",
    "mlx-community/Qwen3.5-27B-8bit": "Qwen/Qwen3.5-27B",
    # Qwen3.6
    "mlx-community/Qwen3.6-27B-4bit": "Qwen/Qwen3.6-27B",
    "mlx-community/Qwen3.6-27B-8bit": "Qwen/Qwen3.6-27B",
    "mlx-community/Qwen3.6-27B-bf16": "Qwen/Qwen3.6-27B",
    "mlx-community/Qwen3.6-35B-A3B-4bit": "Qwen/Qwen3.6-35B-A3B",
    "lmstudio-community/Qwen3.6-27B-GGUF": "Qwen/Qwen3.6-27B",
    # Qwen3-Coder-Next
    "lmstudio-community/Qwen3-Coder-Next-MLX-4bit": "Qwen/Qwen3-Coder-Next",
}


def get_mtp_draft_n(repo: str) -> int | None:
    """Return the recommended spec-draft-n-max for *repo*, or None.

    The lookup is case-sensitive. Leading/trailing whitespace is ignored so
    an id copied from a UI field or config file still resolves.

    Args:
        repo: A canonical repo id (key of ``MTP_MODEL_MAP``) or a community
            alias (key of ``_MTP_ALIASES``).

    Returns:
        The draft-token count registered for the repo, or ``None`` when the
        repo is not known to carry MTP heads (including empty or non-string
        input) — callers should not enable MTP speculative decoding then.
    """
    if not isinstance(repo, str):
        return None
    key = repo.strip()

    # Canonical ids win over aliases; a single .get avoids the
    # "membership test then index" double lookup.
    direct = MTP_MODEL_MAP.get(key)
    if direct is not None:
        return direct

    canonical = _MTP_ALIASES.get(key)
    return MTP_MODEL_MAP.get(canonical) if canonical else None


def has_mtp_heads(repo: str) -> bool:
    """True when *repo* (or a community alias of it) carries baked-in MTP heads."""
    return get_mtp_draft_n(repo) is not None
+ + Cheap file-existence check — no subprocess spawn. The version file is + written by install-mtplx.sh on clean install; its presence together with + the venv python binary is sufficient to confirm a usable install. + """ + python = _MTPLX_VENV / "bin" / "python" + if _MTPLX_VERSION_FILE.exists() and python.exists(): + return True, str(python) + return False, None + + def _initial_backend_capabilities() -> BackendCapabilities: """Cheap capability placeholder used while the real probe runs. @@ -40,6 +58,7 @@ def _initial_backend_capabilities() -> BackendCapabilities: llama_server_path = _resolve_llama_server() llama_server_turbo_path = _resolve_llama_server_turbo() llama_cli_path = _resolve_llama_cli() + mtplx_available, mtplx_python = _detect_mtplx() return BackendCapabilities( pythonExecutable=python_executable, mlxAvailable=False, @@ -53,6 +72,8 @@ def _initial_backend_capabilities() -> BackendCapabilities: converterAvailable=False, vllmAvailable=False, vllmVersion=None, + mtplxAvailable=mtplx_available, + mtplxPythonPath=mtplx_python, probing=True, ) @@ -80,6 +101,7 @@ def _probe_native_backends() -> BackendCapabilities: from backend_service.vllm_engine import _vllm_importable, _vllm_version + mtplx_available, mtplx_python = _detect_mtplx() return BackendCapabilities( pythonExecutable=python_executable, mlxAvailable=mlx_available, @@ -95,6 +117,8 @@ def _probe_native_backends() -> BackendCapabilities: converterAvailable=mlx_usable, vllmAvailable=_vllm_importable(), vllmVersion=_vllm_version(), + mtplxAvailable=mtplx_available, + mtplxPythonPath=mtplx_python, ) diff --git a/backend_service/inference/controller.py b/backend_service/inference/controller.py index 13f8843..fd95703 100644 --- a/backend_service/inference/controller.py +++ b/backend_service/inference/controller.py @@ -87,6 +87,7 @@ LlamaCppEngine, _CACHE_TYPE_CACHE, _LLAMA_HELP_CACHE, + _LLAMA_HELP_LOCK, _LLAMA_SAMPLER_KEYS, _STANDARD_CACHE_TYPES, _apply_llama_chat_template_fixes, @@ -99,10 +100,12 @@ 
_resolve_mmproj_path, ) from backend_service.inference.mlx_engine import MLXWorkerEngine +from backend_service.inference.mtplx_engine import MtplxEngine from backend_service.inference.simple_engines import ( MockInferenceEngine, RemoteOpenAIEngine, ) +from backend_service.inference._mtp import has_mtp_heads class RuntimeController: @@ -338,7 +341,8 @@ def prune_stale_backend_children(self) -> None: is_mlx_worker = "backend_service.mlx_worker" in cmdline or "mlx_worker" in cmdline is_llama = "llama-server" in name or "llama-server" in cmdline - if not (is_mlx_worker or is_llama): + is_mtplx = "mtplx" in name or "/mtplx-venv/" in cmdline or " mtplx " in f" {cmdline} " or cmdline.endswith(" mtplx") + if not (is_mlx_worker or is_llama or is_mtplx): continue if child.pid in tracked: continue @@ -350,10 +354,16 @@ def prune_stale_backend_children(self) -> None: if now_wall - create_time < self.ORPHAN_DETECTION_GRACE_SECONDS: continue + if is_mlx_worker: + kind, label = "mlx_worker", "MLX worker" + elif is_mtplx: + kind, label = "mtplx", "MTPLX server" + else: + kind, label = "llama_server", "llama-server" record = { "pid": int(child.pid), - "kind": "mlx_worker" if is_mlx_worker else "llama_server", - "label": "MLX worker" if is_mlx_worker else "llama-server", + "kind": kind, + "label": label, "action": "terminated", "detectedAt": _now_label(), # Internal monotonic stamp used for TTL; not serialized. 
@@ -475,6 +485,9 @@ def _select_engine( backend: str, runtime_target: str | None, path: str | None, + model_ref: str = "", + canonical_repo: str | None = None, + speculative_decoding: bool = False, ) -> BaseInferenceEngine: hint = (backend or "auto").lower() target = runtime_target or path @@ -483,6 +496,12 @@ def _select_engine( return RemoteOpenAIEngine(self.capabilities) if hint == "mlx": if self.capabilities.mlxUsable: + if ( + speculative_decoding + and self.capabilities.mtplxAvailable + and has_mtp_heads(canonical_repo or model_ref) + ): + return MtplxEngine(self.capabilities) return MLXWorkerEngine(self.capabilities) reason = self.capabilities.mlxMessage or "MLX is not available in this environment" raise RuntimeError( @@ -727,6 +746,9 @@ def _internal_progress(progress: dict[str, Any]) -> None: backend=backend, runtime_target=runtime_target, path=path, + model_ref=model_ref, + canonical_repo=canonical_repo, + speculative_decoding=speculative_decoding, ) # Never keep multiple warm copies of the same logical model under @@ -739,25 +761,52 @@ def _internal_progress(progress: dict[str, Any]) -> None: ) self.engine = selected_engine + _load_kwargs: dict[str, Any] = dict( + model_ref=model_ref, + model_name=resolved_name, + canonical_repo=canonical_repo, + source=source, + backend=self.engine.engine_name, + path=path, + runtime_target=runtime_target, + cache_strategy=cache_strategy, + cache_bits=cache_bits, + fp16_layers=fp16_layers, + fused_attention=fused_attention, + fit_model_in_memory=fit_model_in_memory, + context_tokens=context_tokens, + speculative_decoding=speculative_decoding, + tree_budget=tree_budget, + progress_callback=_internal_progress, + ) try: - loaded = self.engine.load_model( - model_ref=model_ref, - model_name=resolved_name, - canonical_repo=canonical_repo, - source=source, - backend=self.engine.engine_name, - path=path, - runtime_target=runtime_target, - cache_strategy=cache_strategy, - cache_bits=cache_bits, - fp16_layers=fp16_layers, - 
fused_attention=fused_attention, - fit_model_in_memory=fit_model_in_memory, - context_tokens=context_tokens, - speculative_decoding=speculative_decoding, - tree_budget=tree_budget, - progress_callback=_internal_progress, - ) + loaded = self.engine.load_model(**_load_kwargs) + except RuntimeError as _mtplx_exc: + # MtplxEngine startup failure → fall back to standard MLX worker. + if isinstance(self.engine, MtplxEngine): + fallback = MLXWorkerEngine(self.capabilities) + self.engine = fallback + _load_kwargs["backend"] = fallback.engine_name + try: + loaded = self.engine.load_model(**_load_kwargs) + loaded.runtimeNote = _append_runtime_note( + loaded.runtimeNote, + f"MTPLX startup failed ({_mtplx_exc}); using standard MLX.", + ) + except Exception: + self.loaded_model = None + self.runtime_note = None + self._loading_progress = None + self._loading_log_tail = [] + self.prune_stale_backend_children() + raise + else: + self.loaded_model = None + self.runtime_note = None + self._loading_progress = None + self._loading_log_tail = [] + self.prune_stale_backend_children() + raise except Exception: self.loaded_model = None self.runtime_note = None diff --git a/backend_service/inference/mtplx_engine.py b/backend_service/inference/mtplx_engine.py new file mode 100644 index 0000000..ba5e836 --- /dev/null +++ b/backend_service/inference/mtplx_engine.py @@ -0,0 +1,432 @@ +"""MTPLX inference engine. + +Spawns the MTPLX server (``mtplx start --model --port N``) from its +isolated venv at ``~/.chaosengine/mtplx-venv/`` as a subprocess, then proxies +``/v1/chat/completions`` through it — the same pattern used by +``LlamaCppEngine`` for llama-server. + +MTPLX provides native in-model MTP speculative decoding for Apple Silicon; +its forked mlx lives in the isolated venv so it never conflicts with the main +``.venv``'s upstream mlx. + +Fallback contract: ``load_model`` raises ``RuntimeError`` on any startup +failure. 
class MtplxEngine(BaseInferenceEngine):
    """Inference engine that proxies chat completions through an MTPLX server.

    ``load_model`` spawns ``mtplx start --model <model> --port <port>`` as a
    subprocess with stdout/stderr redirected to a temp log file, waits for
    ``/health`` to answer, and ``generate`` / ``stream_generate`` then POST to
    the server's ``/v1/chat/completions`` endpoint on 127.0.0.1.

    Startup failures surface as ``RuntimeError`` so a caller can catch that
    one exception type and fall back to another engine.
    """

    engine_name = "mtplx"
    engine_label = "MTPLX (MTP speculative decoding)"

    def __init__(self, capabilities: BackendCapabilities) -> None:
        self.capabilities = capabilities
        self.loaded_model: LoadedModelInfo | None = None
        self.process: subprocess.Popen[str] | None = None
        self.port: int | None = None
        self.log_path: Path | None = None
        self.log_handle: Any = None

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _server_url(self, path: str) -> str:
        """Return the loopback URL for *path*; raise if no port is assigned."""
        if self.port is None:
            raise RuntimeError("MTPLX server is not running.")
        return f"http://127.0.0.1:{self.port}{path}"

    def _mtplx_bin(self) -> str:
        """Path to the mtplx executable in the isolated venv.

        Raises:
            RuntimeError: neither the default venv nor the directory of
                ``capabilities.mtplxPythonPath`` contains an ``mtplx`` file.
        """
        candidate = _MTPLX_VENV / "bin" / "mtplx"
        if candidate.exists():
            return str(candidate)
        # Fall back to capabilities-resolved python path's sibling
        if self.capabilities.mtplxPythonPath:
            sibling = Path(self.capabilities.mtplxPythonPath).parent / "mtplx"
            if sibling.exists():
                return str(sibling)
        raise RuntimeError(
            "MTPLX is not installed. Install it from the Setup tab."
        )

    def _cleanup_process(self) -> None:
        """Stop the server (terminate, then kill) and release the log handle.

        Best-effort by design: errors from signalling an already-gone process
        or closing the log are ignored so cleanup never masks the original
        failure. The log *file* is left on disk; only the handle is closed.
        """
        proc = self.process
        if proc is not None and proc.poll() is None:
            try:
                proc.terminate()
            except (ProcessLookupError, OSError):
                pass
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Graceful stop timed out — escalate to SIGKILL.
                try:
                    proc.kill()
                except (ProcessLookupError, OSError):
                    pass
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    pass
        self.process = None
        self.port = None
        if self.log_handle is not None:
            try:
                self.log_handle.close()
            except OSError:
                pass
            self.log_handle = None

    def process_pid(self) -> int | None:
        """Return the PID of the live server process, or None if not running."""
        if self.process is None or self.process.poll() is not None:
            return None
        return int(self.process.pid)

    def _wait_for_server(self) -> None:
        """Block until ``/health`` answers or the startup timeout elapses.

        Raises:
            RuntimeError: the process exited, or the deadline passed. The
                message is the log tail when available, otherwise the last
                probe error.
        """
        # monotonic(): a wall-clock adjustment must not shorten or extend
        # the startup deadline.
        deadline = time.monotonic() + DEFAULT_LLAMA_TIMEOUT_SECONDS
        last_error = "MTPLX server did not become ready."
        while time.monotonic() < deadline:
            if self.process is not None and self.process.poll() is not None:
                logs = _read_text_tail(self.log_path)
                raise RuntimeError(logs or "MTPLX server exited during startup.")
            try:
                _http_json(self._server_url("/health"), timeout=2.0)
                return
            except Exception as exc:  # any probe failure just means "not ready yet"
                last_error = str(exc)
                time.sleep(1.0)
        logs = _read_text_tail(self.log_path)
        raise RuntimeError(logs if logs else last_error)

    def _ensure_running(self) -> LoadedModelInfo:
        """Return the loaded model info, raising if nothing is loaded or the server died."""
        if self.loaded_model is None:
            raise RuntimeError("No model is loaded.")
        if self.process is None or self.process.poll() is not None:
            logs = _read_text_tail(self.log_path)
            raise RuntimeError(logs or "The MTPLX server is not running.")
        return self.loaded_model

    def _build_payload(
        self,
        model: LoadedModelInfo,
        *,
        prompt: str,
        history: list[dict[str, Any]],
        system_prompt: str | None,
        max_tokens: int,
        temperature: float,
        tools: list[dict[str, Any]] | None,
        stream: bool,
    ) -> dict[str, Any]:
        """Assemble the ``/v1/chat/completions`` request body.

        History entries whose role is not system/user/assistant/tool are
        dropped; each kept entry's ``text`` field is normalised into
        ``content``. The current *prompt* is appended as the final user turn.
        """
        messages: list[dict[str, Any]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for message in history:
            role = message.get("role")
            if role not in {"system", "user", "assistant", "tool"}:
                continue
            messages.append({
                "role": role,
                "content": _normalize_message_content(message.get("text", "")),
            })
        messages.append({"role": "user", "content": prompt})

        payload: dict[str, Any] = {
            "model": model.ref,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": stream,
        }
        if tools:
            payload["tools"] = tools
        return payload

    # ------------------------------------------------------------------
    # BaseInferenceEngine interface
    # ------------------------------------------------------------------

    def load_model(
        self,
        *,
        model_ref: str,
        model_name: str,
        canonical_repo: str | None,
        source: str,
        backend: str,
        path: str | None,
        runtime_target: str | None,
        cache_strategy: str,
        cache_bits: int,
        fp16_layers: int,
        fused_attention: bool,
        fit_model_in_memory: bool,
        context_tokens: int,
        speculative_decoding: bool = True,
        tree_budget: int = 0,
        progress_callback: Callable[[dict[str, Any]], None] | None = None,
    ) -> LoadedModelInfo:
        """Start an MTPLX server for the model and wait until it is ready.

        ``cache_bits``, ``fp16_layers``, ``fused_attention``,
        ``speculative_decoding``, ``tree_budget`` and ``progress_callback``
        are accepted for interface compatibility but not forwarded to the
        server; the returned info records cacheBits=0, fp16Layers=0,
        fusedAttention=False and speculativeDecoding=True.

        Raises:
            RuntimeError: MTPLX is not installed, the process could not be
                spawned, or the server did not become ready.
        """
        if not self.capabilities.mtplxAvailable:
            raise RuntimeError("MTPLX is not installed. Install it from the Setup tab.")

        self.unload_model()

        mtplx_bin = self._mtplx_bin()
        self.port = _find_open_port()

        # Prefer local path; fall back to HF repo id (MTPLX will download).
        model_arg = path or runtime_target or model_ref

        command = [
            mtplx_bin,
            "start",
            "--model", model_arg,
            "--port", str(self.port),
        ]

        # Reserve a unique log file name, then reopen it in text append mode
        # so the child's stdout/stderr can be tailed for error messages.
        temp_log = tempfile.NamedTemporaryFile(
            prefix="chaosengine-mtplx-", suffix=".log", delete=False
        )
        temp_log.close()
        self.log_path = Path(temp_log.name)
        self.log_handle = self.log_path.open("a", encoding="utf-8")

        try:
            self.process = subprocess.Popen(
                command,
                cwd=str(WORKSPACE_ROOT),
                stdout=self.log_handle,
                stderr=self.log_handle,
                text=True,
            )
        except OSError as exc:
            # Spawn failures (missing/non-executable binary, bad cwd) must
            # release the log handle and honour the RuntimeError contract.
            self._cleanup_process()
            raise RuntimeError(f"Failed to start the MTPLX server: {exc}") from exc

        try:
            self._wait_for_server()
        except BaseException:
            # Never leave a half-started server behind, whatever interrupted
            # the wait; the original exception is re-raised unchanged.
            self._cleanup_process()
            raise

        from backend_service.inference._mtp import get_mtp_draft_n
        draft_n = get_mtp_draft_n(canonical_repo or model_ref) or 1

        runtime_note = (
            f"MTPLX MTP speculative decoding active "
            f"(draft tokens: {draft_n}, model: {model_name})."
        )

        self.loaded_model = LoadedModelInfo(
            ref=model_ref,
            name=model_name,
            canonicalRepo=canonical_repo,
            backend=backend,
            source=source,
            engine=self.engine_name,
            cacheStrategy=cache_strategy,
            cacheBits=0,
            fp16Layers=0,
            fusedAttention=False,
            fitModelInMemory=fit_model_in_memory,
            contextTokens=context_tokens,
            loadedAt=_now_label(),
            path=path,
            runtimeTarget=runtime_target or path,
            runtimeNote=runtime_note,
            speculativeDecoding=True,
        )
        return self.loaded_model

    def unload_model(self) -> None:
        """Stop the server process (if any) and forget the loaded model."""
        self._cleanup_process()
        self.loaded_model = None

    def generate(
        self,
        *,
        prompt: str,
        history: list[dict[str, Any]],
        system_prompt: str | None,
        max_tokens: int,
        temperature: float,
        images: list[str] | None = None,
        tools: list[dict[str, Any]] | None = None,
        samplers: dict[str, Any] | None = None,
        reasoning_effort: str | None = None,
        json_schema: dict[str, Any] | None = None,
    ) -> GenerationResult:
        """Run one non-streaming chat completion against the MTPLX server.

        ``images``, ``samplers``, ``reasoning_effort`` and ``json_schema``
        are accepted but not forwarded to the server.

        Raises:
            RuntimeError: no model is loaded, the server is not running, or
                the HTTP request failed.
        """
        model = self._ensure_running()

        started_at = time.perf_counter()
        payload = self._build_payload(
            model,
            prompt=prompt,
            history=history,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            stream=False,
        )

        try:
            response = _http_json(
                self._server_url("/v1/chat/completions"),
                payload=payload,
                timeout=DEFAULT_LLAMA_TIMEOUT_SECONDS,
            )
        except urllib.error.HTTPError as exc:
            detail = exc.read().decode("utf-8", errors="ignore")
            raise RuntimeError(detail or str(exc)) from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(str(exc.reason)) from exc

        # Floor the elapsed time so the tok/s division can never hit zero.
        elapsed = max(time.perf_counter() - started_at, 1e-6)
        choice = (response.get("choices") or [{}])[0]
        message = choice.get("message") or {}
        usage = response.get("usage") or {}
        completion_tokens = int(usage.get("completion_tokens") or 0)
        prompt_tokens = int(usage.get("prompt_tokens") or 0)
        text = _strip_thinking_tokens(str(message.get("content") or ""))

        return GenerationResult(
            text=text,
            finishReason=str(choice.get("finish_reason") or "stop"),
            promptTokens=prompt_tokens,
            completionTokens=completion_tokens,
            totalTokens=int(usage.get("total_tokens") or (prompt_tokens + completion_tokens)),
            tokS=round(completion_tokens / elapsed, 1) if completion_tokens else 0.0,
            responseSeconds=round(elapsed, 2),
            runtimeNote=model.runtimeNote,
        )

    def stream_generate(
        self,
        *,
        prompt: str,
        history: list[dict[str, Any]],
        system_prompt: str | None,
        max_tokens: int,
        temperature: float,
        images: list[str] | None = None,
        tools: list[dict[str, Any]] | None = None,
        thinking_mode: str | None = None,
        samplers: dict[str, Any] | None = None,
        reasoning_effort: str | None = None,
        json_schema: dict[str, Any] | None = None,
    ) -> Iterator[StreamChunk]:
        """Stream a chat completion as ``StreamChunk`` objects.

        Reads the server's ``data: ...`` event lines, routes content through
        ``ThinkingTokenFilter`` to separate reasoning from answer text, and
        always finishes with one ``done=True`` chunk carrying token counts.
        Because this is a generator, the precondition checks and the HTTP
        request only run once iteration starts.

        ``images``, ``samplers``, ``reasoning_effort`` and ``json_schema``
        are accepted but not forwarded to the server.

        Raises:
            RuntimeError: no model is loaded, the server is not running, or
                the HTTP request could not be opened.
        """
        model = self._ensure_running()

        payload = self._build_payload(
            model,
            prompt=prompt,
            history=history,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            stream=True,
        )

        url = self._server_url("/v1/chat/completions")
        data = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json", "Accept": "text/event-stream"}
        request = urllib.request.Request(url, data=data, headers=headers, method="POST")
        try:
            resp = urllib.request.urlopen(request, timeout=DEFAULT_LLAMA_TIMEOUT_SECONDS)
        except urllib.error.HTTPError as exc:
            detail = exc.read().decode("utf-8", errors="ignore")
            raise RuntimeError(detail or str(exc)) from exc
        except urllib.error.URLError as exc:
            raise RuntimeError(str(exc.reason)) from exc

        finish_reason = "stop"
        prompt_tokens = 0
        completion_tokens = 0
        stream_start = time.perf_counter()
        first_token_time: float | None = None
        runtime_note = model.runtimeNote
        think_filter = ThinkingTokenFilter(detect_raw_reasoning=(thinking_mode or "off") != "off")
        runaway_guard = RepeatedLineGuard()

        try:
            for raw_line in resp:
                line = raw_line.decode("utf-8", errors="ignore").strip()
                if not line or not line.startswith("data: "):
                    continue
                payload_str = line[len("data: "):]
                if payload_str == "[DONE]":
                    break
                try:
                    chunk = json.loads(payload_str)
                except json.JSONDecodeError:
                    # Skip malformed event lines rather than abort the stream.
                    continue
                choice = (chunk.get("choices") or [{}])[0]
                delta = choice.get("delta") or {}
                content = delta.get("content")
                if content:
                    split = think_filter.feed(str(content))
                    if split.reasoning:
                        yield StreamChunk(reasoning=split.reasoning)
                    if split.reasoning_done:
                        yield StreamChunk(reasoning_done=True)
                    if split.text:
                        runaway_guard.feed(split.text)
                        if first_token_time is None:
                            first_token_time = time.perf_counter()
                        # One count per text delta: an approximation that a
                        # server-reported usage block overrides below.
                        completion_tokens += 1
                        yield StreamChunk(text=split.text)
                fr = choice.get("finish_reason")
                if fr:
                    finish_reason = fr
                usage = chunk.get("usage")
                if usage:
                    prompt_tokens = int(usage.get("prompt_tokens") or 0)
                    completion_tokens = int(usage.get("completion_tokens") or completion_tokens)
            # Emit whatever the filter was still holding back.
            flushed = think_filter.flush()
            if flushed.reasoning:
                yield StreamChunk(reasoning=flushed.reasoning)
            if flushed.reasoning_done:
                yield StreamChunk(reasoning_done=True)
            if flushed.text:
                runaway_guard.feed(flushed.text)
                if first_token_time is None:
                    first_token_time = time.perf_counter()
                yield StreamChunk(text=flushed.text)
            runaway_guard.flush()
        except RuntimeError as exc:
            # NOTE(review): presumably raised by RepeatedLineGuard on runaway
            # output — confirm. The stream ends early and the reason is
            # surfaced via the runtime note instead of propagating.
            runtime_note = _append_runtime_note(runtime_note, str(exc))
            finish_reason = "stop"
        finally:
            resp.close()

        end_time = time.perf_counter()
        # Rate is measured from the first text chunk when there was one, so
        # time-to-first-token does not depress tok/s.
        gen_elapsed = max(end_time - (first_token_time or stream_start), 1e-6)
        tok_s = round(completion_tokens / gen_elapsed, 1) if completion_tokens > 0 else 0.0

        yield StreamChunk(
            done=True,
            finish_reason=finish_reason,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            tok_s=tok_s,
            runtime_note=runtime_note,
        )
@dataclass
class _MtplxJobState:
    """Mutable progress snapshot for the single in-memory MTPLX install job."""

    # Current phase name; a fresh instance starts at "idle".
    phase: str = "idle"
    message: str = ""
    # Label and 1-based position of the phase in progress.
    package_current: str | None = None
    package_index: int = 0
    package_total: int = _TOTAL_PHASES
    percent: float = 0.0
    target_dir: str | None = None
    error: str | None = None
    # Timestamps default to 0.0 until set by the code driving the job.
    started_at: float = 0.0
    finished_at: float = 0.0
    # One dict per recorded phase attempt.
    attempts: list[dict[str, Any]] = field(default_factory=list)
    done: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialise the state to a camelCase dict with a fixed ``id``.

        ``percent`` is rounded to one decimal place and ``attempts`` is
        copied so the caller never holds the live list.
        """
        rounded_percent = round(self.percent, 1)
        attempts_copy = list(self.attempts)
        pairs = (
            ("id", "mtplx-install"),
            ("phase", self.phase),
            ("message", self.message),
            ("packageCurrent", self.package_current),
            ("packageIndex", self.package_index),
            ("packageTotal", self.package_total),
            ("percent", rounded_percent),
            ("targetDir", self.target_dir),
            ("error", self.error),
            ("startedAt", self.started_at),
            ("finishedAt", self.finished_at),
            ("attempts", attempts_copy),
            ("done", self.done),
        )
        return dict(pairs)
def _job_worker() -> None:
    """Run install-mtplx.sh and stream output into job state.

    Reads the script's combined stdout/stderr line by line:

    - ``PHASE:<name>`` lines advance the job to the named phase, recording
      the previous phase as a successful attempt.
    - ``FAIL:<reason>`` lines set ``job.error`` (and are kept in the output).
    - every other line is buffered as output of the current phase; only the
      most recent 400 lines are retained.

    On exit code 0 with no recorded error the job ends in phase ``done``;
    otherwise — including any exception raised here — it ends in phase
    ``error``. ``job.finished_at`` is always stamped.
    """
    # Function-scope import: deque is only needed by this worker.
    from collections import deque

    job = _JOB
    # Bounded buffer: old lines fall off automatically once 400 are held.
    phase_buffer: deque[str] = deque(maxlen=400)
    phase_index = 0

    def push_attempt(phase: str, ok: bool) -> None:
        """Record the buffered output (last 8000 chars) as one attempt."""
        job.attempts.append({
            "phase": phase,
            "package": _PHASE_LABELS.get(phase, phase),
            "ok": ok,
            "output": "\n".join(phase_buffer)[-8000:],
        })
        phase_buffer.clear()

    def advance_phase(name: str) -> None:
        """Close out the current phase and switch the job to *name*."""
        nonlocal phase_index
        # Ignore markers once the job has reached a terminal phase.
        if job.phase not in ("idle", "preflight", "creating-venv", "installing", "verifying"):
            return
        # The first marker only confirms the phase the job was started in,
        # so there is no previous phase to record yet.
        if phase_index > 0:
            push_attempt(job.phase, ok=True)
        phase_index += 1
        job.phase = name
        job.package_current = _PHASE_LABELS.get(name, name)
        job.package_index = phase_index
        job.percent = round((phase_index - 1) / _TOTAL_PHASES * 100, 1)

    try:
        # The context manager closes the stdout pipe and reaps the child.
        # Decoding is pinned to UTF-8 with replacement so stray bytes in
        # installer output cannot abort the job with UnicodeDecodeError.
        with subprocess.Popen(
            ["bash", str(_INSTALL_SCRIPT)],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding="utf-8",
            errors="replace",
            bufsize=1,
        ) as proc:
            try:
                for raw_line in proc.stdout:  # type: ignore[union-attr]
                    line = raw_line.rstrip("\n")
                    if line.startswith("PHASE:"):
                        advance_phase(line[len("PHASE:"):].strip())
                        continue
                    if line.startswith("FAIL:"):
                        job.error = line[len("FAIL:"):].strip() or "Install failed"
                    phase_buffer.append(line)
            except BaseException:
                # Don't leave the installer running (or block in the context
                # manager's wait) when reading its output failed.
                proc.kill()
                raise

        if proc.returncode == 0 and not job.error:
            push_attempt(job.phase, ok=True)
            job.phase = "done"
            job.percent = 100.0
            version, _ = _read_version()
            job.message = f"MTPLX {version or 'installed'} ready in {_MTPLX_VENV_DIR}"
            job.done = True
        else:
            push_attempt(job.phase, ok=False)
            job.phase = "error"
            job.error = job.error or f"install-mtplx.sh exited with code {proc.returncode}"
            job.done = True

    except Exception as exc:  # noqa: BLE001
        # Thread boundary: surface any failure through the job state rather
        # than letting the daemon thread die silently.
        push_attempt(job.phase, ok=False)
        job.phase = "error"
        job.error = str(exc)
        job.done = True
    finally:
        job.finished_at = time.time()
#!/usr/bin/env bash
# Install MTPLX into an isolated venv at ~/.chaosengine/mtplx-venv/.
#
# Requires native arm64 Python 3.10+ (MTPLX's forked mlx won't build under
# Rosetta). The script prints structured progress lines so the backend job
# worker can parse phase transitions:
#
#   PHASE:NAME     — emitted before each phase starts
#   OK             — emitted on clean exit
#   FAIL:MESSAGE   — emitted before a non-zero exit
#
# The backend worker (routes/setup/mtplx.py) reads these markers to drive
# the InstallLogPanel phases without scraping pip output.
#
# Environment:
#   PYTHON  interpreter used for the preflight checks and to create the venv
#           (default: python3)

set -euo pipefail

VENV_DIR="${HOME}/.chaosengine/mtplx-venv"
BIN_DIR="${HOME}/.chaosengine/bin"
VERSION_FILE="${BIN_DIR}/mtplx.version"
MTPLX_PACKAGE="mtplx"

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

log()   { echo "$*"; }
phase() { echo "PHASE:$1"; }
fail()  { echo "FAIL:$*"; exit 1; }

# ---------------------------------------------------------------------------
# Preflight — verify native arm64 Python 3.10+
# ---------------------------------------------------------------------------

phase "preflight"

PYTHON="${PYTHON:-python3}"

# "${PYTHON}" is quoted everywhere so an interpreter path containing spaces
# is passed as a single word.
ARCH=$("${PYTHON}" -c "import platform; print(platform.machine())" 2>/dev/null || true)
if [[ "${ARCH}" != "arm64" ]]; then
  fail "MTPLX requires native arm64 Python (got: ${ARCH:-unknown}). Make sure you are not running under Rosetta."
fi

PY_VER=$("${PYTHON}" -c "import sys; print('%d.%d' % sys.version_info[:2])" 2>/dev/null || true)
PY_MAJ="${PY_VER%%.*}"
PY_MIN="${PY_VER#*.}"
if [[ "${PY_MAJ}" -lt 3 ]] || { [[ "${PY_MAJ}" -eq 3 ]] && [[ "${PY_MIN}" -lt 10 ]]; }; then
  fail "MTPLX requires Python 3.10+ (got: ${PY_VER})"
fi

log "Python ${PY_VER} (arm64) — OK"
mkdir -p "${BIN_DIR}"

# ---------------------------------------------------------------------------
# Create isolated venv
# ---------------------------------------------------------------------------

phase "creating-venv"

# Drop the previous install's version file first: if this run fails part-way,
# the backend must not keep reporting the old version for a venv that is gone.
rm -f "${VERSION_FILE}"

if [[ -d "${VENV_DIR}" ]]; then
  log "Removing existing venv at ${VENV_DIR}"
  rm -rf "${VENV_DIR}"
fi

log "Creating venv at ${VENV_DIR}"
"${PYTHON}" -m venv "${VENV_DIR}"
log "Upgrading pip"
"${VENV_DIR}/bin/pip" install --quiet --upgrade pip

# ---------------------------------------------------------------------------
# Install MTPLX (pulls in its mlx fork automatically)
# ---------------------------------------------------------------------------

phase "installing"

log "Installing ${MTPLX_PACKAGE}"
"${VENV_DIR}/bin/pip" install --upgrade "${MTPLX_PACKAGE}"

# ---------------------------------------------------------------------------
# Verify: import check + extract version
# ---------------------------------------------------------------------------

phase "verifying"

# A failed lookup must not abort the script (pipefail + set -e), so the
# pipeline is allowed to fail and the version falls back to "unknown".
MTPLX_VERSION=$("${VENV_DIR}/bin/pip" show "${MTPLX_PACKAGE}" 2>/dev/null \
  | grep -i "^Version:" | awk '{print $2}' || true)
MTPLX_VERSION="${MTPLX_VERSION:-unknown}"

# stderr is left attached so the Python traceback lands in the install log
# and the user can see why the import failed.
if ! "${VENV_DIR}/bin/python" -c "import mtplx"; then
  fail "MTPLX import check failed — installation may be incomplete"
fi

log "MTPLX ${MTPLX_VERSION} import verified"

# ---------------------------------------------------------------------------
# Write version file (line 1: version, line 2: UTC install timestamp)
# ---------------------------------------------------------------------------

{
  echo "${MTPLX_VERSION}"
  date -u +%Y-%m-%dT%H:%M:%SZ
} > "${VERSION_FILE}"

log "Version file written to ${VERSION_FILE}"
echo "OK"
index 35f977d..e28a5da 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -105,6 +105,7 @@ import { useDetailsWindowResize, useFileActions, } from "./hooks"; +import { useMtplxInstall } from "./hooks/useMtplxInstall"; export default function App() { // FU-042: i18n hook — used for the workspace header tab label / @@ -149,6 +150,11 @@ export default function App() { // ── Settings / Server / Preview ──────────────────────────── const imgState = useImageState(backendOnline, setError, setActiveTab); const videoState = useVideoState(backendOnline, setError, setActiveTab); + const { + mtplxJob, + installingMtplx, + handleInstallMtplx, + } = useMtplxInstall(); const { installingCudaTorch, @@ -237,6 +243,7 @@ export default function App() { activeDownloads, discoverCapFilter, setDiscoverCapFilter, discoverFormatFilter, setDiscoverFormatFilter, + discoverAccelFilter, setDiscoverAccelFilter, handleDownloadModel, handleCancelModelDownload, handleDeleteModelDownload, @@ -1127,6 +1134,13 @@ export default function App() { onDiscoverCapFilterChange={setDiscoverCapFilter} discoverFormatFilter={discoverFormatFilter} onDiscoverFormatFilterChange={setDiscoverFormatFilter} + discoverAccelFilter={discoverAccelFilter} + onDiscoverAccelFilterChange={setDiscoverAccelFilter} + accelCompat={{ + dflashModels: workspace.system.dflash?.supportedModels ?? [], + mtplxModels: workspace.system.mtplx?.supportedModels ?? [], + turboInstalled: Boolean(workspace.system.llamaServerTurboPath), + }} expandedFamilyId={expandedFamilyId} onExpandedFamilyIdChange={setExpandedFamilyId} expandedVariantId={expandedVariantId} @@ -1167,6 +1181,8 @@ export default function App() { turboInstalled: !!workspace.system.llamaServerTurboPath, turboquantMlxAvailable: workspace.system.availableCacheStrategies?.some((s) => s.id === "turboquant" && s.available) ?? false, dflashSupportedModels: workspace.system.dflash?.supportedModels ?? [], + mtplxInstalled: workspace.system.mtplx?.available ?? 
false, + mtplxSupportedModels: workspace.system.mtplx?.supportedModels ?? [], }} activeDownloads={activeDownloads} expandedLibraryPath={expandedLibraryPath} @@ -1569,7 +1585,6 @@ export default function App() { onSetError={setError} enableTools={chat.enableTools} onToggleTools={chat.setEnableTools} - onCompareMode={() => setActiveTab("chat-compare")} onCancelGeneration={chat.cancelGeneration} oneTurnOverride={chat.oneTurnOverride} onOneTurnOverrideChange={chat.setOneTurnOverride} @@ -1588,6 +1603,10 @@ export default function App() { availableCacheStrategies={workspace.system.availableCacheStrategies} dflashInfo={workspace.system.dflash} turboInstalled={Boolean(workspace.system.llamaServerTurboPath)} + mtplxSystemInfo={workspace.system.mtplx} + onInstallMtplx={() => void handleInstallMtplx()} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} onInstallPackage={handleInstallPackage} installingPackage={installingPackage} installLogs={installLogs} @@ -1604,6 +1623,10 @@ export default function App() { availableCacheStrategies={workspace.system.availableCacheStrategies} dflashInfo={workspace.system.dflash} turboInstalled={Boolean(workspace.system.llamaServerTurboPath)} + mtplxSystemInfo={workspace.system.mtplx} + onInstallMtplx={() => void handleInstallMtplx()} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} onInstallPackage={handleInstallPackage} installingPackage={installingPackage} installLogs={installLogs} @@ -1665,6 +1688,7 @@ export default function App() { availableCacheStrategies: workspace.system.availableCacheStrategies, llamaServerTurboPath: workspace.system.llamaServerTurboPath, dflash: workspace.system.dflash, + mtplx: workspace.system.mtplx, }, }} threadModelOptions={threadModelOptions} @@ -1680,6 +1704,9 @@ export default function App() { showBenchmarkModal={showBenchmarkModal} installingPackage={installingPackage} installLogs={installLogs} + onInstallMtplx={() => void handleInstallMtplx()} + installingMtplx={installingMtplx} + 
mtplxJob={mtplxJob} onBenchmarkDraftChange={updateBenchmarkDraft} onBenchmarkPromptIdChange={setBenchmarkPromptId} onBenchmarkModelKeyChange={setBenchmarkModelKey} @@ -1929,6 +1956,10 @@ export default function App() { installingPackage={installingPackage} installLogs={installLogs} turboInstalled={Boolean(workspace.system.llamaServerTurboPath)} + mtplxSystemInfo={workspace.system.mtplx} + onInstallMtplx={() => void handleInstallMtplx()} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} onPendingLaunchChange={setPendingLaunch} onLaunchModelSearchChange={setLaunchModelSearch} onLaunchSettingChange={updateLaunchSetting} diff --git a/src/api/index.ts b/src/api/index.ts index 9e61ee0..c8d28af 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -508,6 +508,9 @@ export { startGpuBundleInstall, startLongLiveInstall, startWanInstall, + getMtplxStatus, + startMtplxInstall, + getMtplxInstallStatus, } from "./setup"; export type { CudaTorchInstallAttempt, @@ -519,6 +522,9 @@ export type { InstallResult, LongLiveAttempt, LongLiveJobState, + MtplxAttempt, + MtplxJobState, + MtplxStatus, PromptEnhanceResult, TurboUpdateInfo, WanConvertStatusFields, diff --git a/src/api/setup.ts b/src/api/setup.ts index 4f7da27..b2ed9f4 100644 --- a/src/api/setup.ts +++ b/src/api/setup.ts @@ -272,6 +272,56 @@ export async function getWanInventory(): Promise { ); } +// --------------------------------------------------------------------------- +// MTPLX install (feature/mtplx) — isolated venv + forked mlx +// --------------------------------------------------------------------------- +// +// Same background-job shape as LongLiveJobState so the existing +// InstallLogPanel renders it without modification. 
+ +export interface MtplxAttempt { + phase?: string; + package?: string; + indexUrl?: string; + ok: boolean; + output: string; +} + +export interface MtplxJobState { + id: string; + phase: "idle" | "preflight" | "creating-venv" | "installing" | "verifying" | "done" | "error"; + message: string; + packageCurrent: string | null; + packageIndex: number; + packageTotal: number; + percent: number; + targetDir: string | null; + error: string | null; + startedAt: number; + finishedAt: number; + attempts: MtplxAttempt[]; + done: boolean; +} + +export interface MtplxStatus { + installed: boolean; + version: string | null; + installedAt: string | null; + venvPath: string | null; +} + +export async function getMtplxStatus(): Promise { + return await fetchJson("/api/setup/mtplx-status", 8000); +} + +export async function startMtplxInstall(): Promise { + return await postJson("/api/setup/install-mtplx", {}, 15000); +} + +export async function getMtplxInstallStatus(): Promise { + return await fetchJson("/api/setup/install-mtplx/status", 10000); +} + // --------------------------------------------------------------------------- // llama-server-turbo update probe + capability refresh // --------------------------------------------------------------------------- diff --git a/src/components/InstallLogPanel.tsx b/src/components/InstallLogPanel.tsx index 8c2f5d7..b1a508f 100644 --- a/src/components/InstallLogPanel.tsx +++ b/src/components/InstallLogPanel.tsx @@ -1,15 +1,14 @@ import { useEffect, useRef } from "react"; import { useTranslation } from "react-i18next"; import type { TFunction } from "i18next"; -import type { GpuBundleJobState, LongLiveJobState } from "../api"; +import type { GpuBundleJobState, LongLiveJobState, MtplxJobState } from "../api"; -// The panel renders either kind of install job — GPU bundle or LongLive. -// They share the core fields (phase / message / attempts / progress -// counters / targetDir) and differ only in optional metadata. 
Treating -// the prop as a union keeps both Studio surfaces using one component -// instead of duplicating the auto-scroll, pip-noise filter, and +// The panel renders any background install job — GPU bundle, LongLive, or +// MTPLX. All share the core fields (phase / message / attempts / progress +// counters / targetDir). Treating the prop as a union keeps all surfaces +// using one component without duplicating auto-scroll, pip-noise filter, and // terminal layout. -export type InstallJobState = GpuBundleJobState | LongLiveJobState; +export type InstallJobState = GpuBundleJobState | LongLiveJobState | MtplxJobState; // Optional fields read by the meta line. ``GpuBundleJobState`` has these; // ``LongLiveJobState`` doesn't. Centralised here so the meta renderer @@ -26,7 +25,7 @@ interface InstallLogPanelProps { job: InstallJobState | null; // Title shown in the collapsed summary. Defaults to the GPU bundle // wording so existing call sites don't need to pass it. - variant?: "gpu-bundle" | "longlive"; + variant?: "gpu-bundle" | "longlive" | "mtplx"; } // Single scrollable terminal rendering the GPU bundle install progress. @@ -101,9 +100,11 @@ function InstallLogMeta({ job, t }: { job: InstallJobState; t: TFunction }) { return
{fragments.join(" · ")}
; } -function formatStatusLabel(job: InstallJobState, variant: "gpu-bundle" | "longlive", t: TFunction): string { +function formatStatusLabel(job: InstallJobState, variant: "gpu-bundle" | "longlive" | "mtplx", t: TFunction): string { const noun = variant === "longlive" ? t("installLog.statusNoun.longlive", { defaultValue: "LongLive install" }) + : variant === "mtplx" + ? t("installLog.statusNoun.mtplx", { defaultValue: "MTPLX install" }) : t("installLog.statusNoun.gpuBundle", { defaultValue: "Install" }); if (job.phase === "error" || job.error) return t("installLog.status.failed", { noun, defaultValue: `${noun} failed — see log` }); if (job.phase === "done") return t("installLog.status.complete", { noun, defaultValue: `${noun} complete — see log` }); diff --git a/src/components/LaunchModal.tsx b/src/components/LaunchModal.tsx index 49e4fb1..131201b 100644 --- a/src/components/LaunchModal.tsx +++ b/src/components/LaunchModal.tsx @@ -2,6 +2,7 @@ import { useTranslation } from "react-i18next"; import { ModelLaunchModal } from "./ModelLaunchModal"; import type { LaunchPreferences, PreviewMetrics, StrategyInstallLog, SystemStats } from "../types"; import type { ChatModelOption } from "../types/chat"; +import type { MtplxJobState } from "../api"; export interface PendingLaunch { action: "chat" | "server" | "thread"; @@ -23,6 +24,10 @@ export interface LaunchModalProps { installingPackage: string | null; installLogs?: Record; turboInstalled?: boolean; + mtplxSystemInfo?: SystemStats["mtplx"]; + onInstallMtplx?: () => void; + installingMtplx?: boolean; + mtplxJob?: MtplxJobState | null; onPendingLaunchChange: (value: PendingLaunch | null | ((prev: PendingLaunch | null) => PendingLaunch | null)) => void; onLaunchModelSearchChange: (value: string) => void; onLaunchSettingChange: (key: K, value: LaunchPreferences[K]) => void; @@ -45,6 +50,10 @@ export function LaunchModal({ installingPackage, installLogs, turboInstalled, + mtplxSystemInfo, + onInstallMtplx, + installingMtplx, 
+ mtplxJob, onPendingLaunchChange, onLaunchModelSearchChange, onLaunchSettingChange, @@ -89,6 +98,10 @@ export function LaunchModal({ installingPackage={installingPackage} installLogs={installLogs} turboInstalled={turboInstalled} + mtplxSystemInfo={mtplxSystemInfo} + onInstallMtplx={onInstallMtplx} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} onSelectedKeyChange={setSelectedLaunchKey} onSearchChange={onLaunchModelSearchChange} onSettingChange={onLaunchSettingChange} diff --git a/src/components/ModelLaunchModal.tsx b/src/components/ModelLaunchModal.tsx index 9612ca8..7da1c69 100644 --- a/src/components/ModelLaunchModal.tsx +++ b/src/components/ModelLaunchModal.tsx @@ -4,6 +4,8 @@ import { RuntimeControls } from "./RuntimeControls"; import { number, sizeLabel } from "../utils"; import type { LaunchPreferences, ModelCapabilities, PreviewMetrics, StrategyInstallLog, SystemStats } from "../types"; import type { ChatModelOption } from "../types/chat"; +import type { MtplxJobState } from "../api"; +import { candidateKeys } from "./runtimeSupport"; /** * Phase 2.11: typed capability badges for the picker. Mirrors the @@ -64,6 +66,10 @@ export interface ModelLaunchModalProps { installingPackage: string | null; installLogs?: Record; turboInstalled?: boolean; + mtplxSystemInfo?: SystemStats["mtplx"]; + onInstallMtplx?: () => void; + installingMtplx?: boolean; + mtplxJob?: MtplxJobState | null; onSelectedKeyChange: (key: string) => void; onSearchChange: (value: string) => void; onSettingChange: (key: K, value: LaunchPreferences[K]) => void; @@ -90,6 +96,10 @@ export function ModelLaunchModal({ installingPackage, installLogs, turboInstalled, + mtplxSystemInfo, + onInstallMtplx, + installingMtplx, + mtplxJob, onSelectedKeyChange, onSearchChange, onSettingChange, @@ -122,6 +132,14 @@ export function ModelLaunchModal({ const resolvedSelectedKey = selectedOption?.key ?? 
""; const listVisible = showList || !selectedOption || search.length > 0; + const mtplxModelSupported = (() => { + if (!mtplxSystemInfo?.supportedModels?.length) return false; + const modelKeys = candidateKeys([selectedOption?.canonicalRepo, selectedOption?.modelRef]); + return mtplxSystemInfo.supportedModels.some((ref) => + candidateKeys([ref]).some((k) => modelKeys.includes(k)) + ); + })(); + return (
event.stopPropagation()}> @@ -232,6 +250,10 @@ export function ModelLaunchModal({ selectedCanonicalRepo={selectedOption?.canonicalRepo} selectedModelName={selectedOption?.model} turboInstalled={turboInstalled} + mtplxInfo={mtplxSystemInfo ? { available: mtplxSystemInfo.available, modelSupported: mtplxModelSupported } : undefined} + onInstallMtplx={onInstallMtplx} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} compact />
diff --git a/src/components/RuntimeControls.tsx b/src/components/RuntimeControls.tsx index 288f15d..f4948b0 100644 --- a/src/components/RuntimeControls.tsx +++ b/src/components/RuntimeControls.tsx @@ -1,6 +1,8 @@ import { useEffect, useState } from "react"; import { useTranslation } from "react-i18next"; import type { LaunchPreferences, PreviewMetrics, StrategyInstallLog } from "../types"; +import type { MtplxJobState } from "../api"; +import { InstallLogPanel } from "./InstallLogPanel"; import { SliderField } from "./SliderField"; import { PerformancePreview } from "./PerformancePreview"; import { @@ -136,6 +138,14 @@ interface RuntimeControlsProps { turboInstalled?: boolean; /** Whether an update is available for llama-server-turbo. */ turboUpdateAvailable?: boolean; + /** MTPLX install state + model compatibility for the current selection. */ + mtplxInfo?: { + available: boolean; + modelSupported: boolean; + }; + onInstallMtplx?: () => void; + installingMtplx?: boolean; + mtplxJob?: MtplxJobState | null; } function StrategyInstallTerminal({ @@ -232,6 +242,10 @@ export function RuntimeControls({ selectedModelName, turboInstalled, turboUpdateAvailable, + mtplxInfo, + onInstallMtplx, + installingMtplx, + mtplxJob, }: RuntimeControlsProps) { const { t } = useTranslation("runtime"); const effectiveMaxContext = Math.max(2048, maxContext ?? 262144); @@ -254,6 +268,13 @@ export function RuntimeControls({ const canInstallDflashForModel = dflashSupport.modelSupported === true; const dflashInstallLog = installLogs?.["dflash-mlx"] ?? installLogs?.dflash; const showDflashInstallTerminal = Boolean(dflashInstallLog || (!dflashInstalled && !isGgufBackend && canInstallDflashForModel && onInstallPackage)); + // Backend `_select_engine` routes to MtplxEngine whenever the model has MTP + // heads AND the MTPLX venv is installed — regardless of which checkbox the + // user clicked. 
Both toggles bind to the same speculativeDecoding flag, so + // showing DFlash alongside an installed MTPLX is confusing (it appears + // "ticked" even though MTPLX is what actually runs). Hide DFlash in that + // case; MTPLX takes precedence. + const mtplxSupersedesDflash = (mtplxInfo?.modelSupported ?? false) && (mtplxInfo?.available ?? false); const specActive = settings.speculativeDecoding && dflashAvailable; const strategies = (availableCacheStrategies ?? [{id: "native", name: "Native f16", available: true, bitRange: null, defaultBits: null, supportsFp16Layers: false}]) .filter((s) => !s.appliesTo || s.appliesTo.length === 0 || s.appliesTo.includes("text")); @@ -612,7 +633,7 @@ export function RuntimeControls({ ready in one click. ``canInstallDflashForModel`` is True whenever the model is in the draft map AND the runtime gap is the missing pip package. */} - {dflashAvailable || canInstallDflashForModel ? ( + {!mtplxSupersedesDflash && (dflashAvailable || canInstallDflashForModel) ? (
) : null} - {expandedInfo === "dflash" && (dflashAvailable || canInstallDflashForModel) ? ( + {expandedInfo === "dflash" && !mtplxSupersedesDflash && (dflashAvailable || canInstallDflashForModel) ? (

{t("dflash.body", { @@ -701,10 +722,10 @@ export function RuntimeControls({ ) : null}

) : null} - {showDflashInstallTerminal ? ( + {showDflashInstallTerminal && !mtplxSupersedesDflash ? ( ) : null} - {settings.speculativeDecoding && dflashAvailable ? ( + {settings.speculativeDecoding && dflashAvailable && !mtplxSupersedesDflash ? (
) : null} + {/* MTPLX: native in-model MTP speculative decoding. Hidden when the + model has no MTP heads (no install button helps that case). Shown + with an install button when the model is supported but the venv + is not yet installed. Uses the same speculativeDecoding field as + DFlash — the controller auto-routes to MtplxEngine when both the + model has MTP heads and the venv is installed. */} + {mtplxInfo?.modelSupported ? ( +
+ + {!mtplxInfo.available && onInstallMtplx ? ( + + ) : null} + +
+ ) : null} + {/* Show terminal during install + on error. Hide on "done" so a + completed install doesn't keep re-appearing on every modal open + — capabilities probe already surfaces the installed state via + the MTPLX section's status copy. */} + {mtplxJob && mtplxJob.phase !== "idle" && mtplxJob.phase !== "done" ? ( + + ) : null} + {expandedInfo === "mtplx" && mtplxInfo?.modelSupported ? ( +
+

+ {t("mtplx.body", { + defaultValue: + "MTPLX uses baked-in Multi-Token Prediction (MTP) heads trained directly into the model. " + + "At inference time the heads draft 1–3 tokens per forward pass, which are verified in the same pass — " + + "no separate draft model needed. Gives ~1.8–2.2x faster generation with zero quality loss.", + })} +

+
+ {t("mtplx.requiresLabel", { defaultValue: "Requires:" })} + + {t("mtplx.requiresBody", { + defaultValue: "Apple Silicon + MTPLX isolated venv. Model must have baked-in MTP heads (Qwen3.5/3.6, DeepSeek V3/R1, Qwen3-Coder-Next).", + })} + +
+
+ {t("mtplx.statusLabel", { defaultValue: "Status:" })} + + {mtplxInfo.available + ? t("mtplx.statusInstalled", { defaultValue: "Installed — active when speculativeDecoding is enabled." }) + : t("mtplx.statusNotInstalled", { defaultValue: "Not installed. Click Install MTPLX to set up the isolated venv (~500 MB)." })} + +
+
+ ) : null}
{showPreview ? ( ; + onInstallMtplx?: () => void; + installingMtplx?: boolean; + mtplxJob?: MtplxJobState | null; onBenchmarkDraftChange: (key: K, value: BenchmarkRunPayload[K]) => void; onBenchmarkPromptIdChange: (id: string) => void; onBenchmarkModelKeyChange: (key: string) => void; @@ -67,6 +72,9 @@ export function BenchmarkRunTab({ showBenchmarkModal, installingPackage, installLogs, + onInstallMtplx, + installingMtplx, + mtplxJob, onBenchmarkDraftChange, onBenchmarkPromptIdChange, onBenchmarkModelKeyChange, @@ -554,6 +562,10 @@ export function BenchmarkRunTab({ installingPackage={installingPackage} installLogs={installLogs} turboInstalled={Boolean(workspace.system.llamaServerTurboPath)} + mtplxSystemInfo={workspace.system.mtplx} + onInstallMtplx={onInstallMtplx} + installingMtplx={installingMtplx} + mtplxJob={mtplxJob} onSelectedKeyChange={(key) => { onBenchmarkModelKeyChange(key); }} diff --git a/src/features/chat/ChatSidebar.tsx b/src/features/chat/ChatSidebar.tsx index b7a2315..7f45f8a 100644 --- a/src/features/chat/ChatSidebar.tsx +++ b/src/features/chat/ChatSidebar.tsx @@ -19,7 +19,6 @@ export interface ChatSidebarProps { onCreateSession: () => void; onToggleThreadPin: (session: ChatSession) => void; onDeleteSession: (sessionId: string) => void; - onCompareMode: () => void; onToggleCollapsed: () => void; } @@ -33,7 +32,6 @@ export function ChatSidebar({ onCreateSession, onToggleThreadPin, onDeleteSession, - onCompareMode, onToggleCollapsed, }: ChatSidebarProps) { const { t } = useTranslation("chat"); @@ -49,15 +47,6 @@ export function ChatSidebar({ - + {ACCEL_FILTERS.map((af) => { + const count = searchResults.filter((f) => familyMatchesAccel(f, af.id)).length; + return ( + + ); + })} + {searchError ? (

{searchError}

diff --git a/src/hooks/useModels.ts b/src/hooks/useModels.ts index 526ccae..c603df7 100644 --- a/src/hooks/useModels.ts +++ b/src/hooks/useModels.ts @@ -45,6 +45,7 @@ export function useModels( const [activeDownloads, setActiveDownloads] = useState>({}); const [discoverCapFilter, setDiscoverCapFilter] = useState(null); const [discoverFormatFilter, setDiscoverFormatFilter] = useState(null); + const [discoverAccelFilter, setDiscoverAccelFilter] = useState(null); // Keep curated families in sync when workspace refreshes (without // retriggering the search effect, which would cancel in-flight API calls). @@ -251,6 +252,8 @@ export function useModels( setDiscoverCapFilter, discoverFormatFilter, setDiscoverFormatFilter, + discoverAccelFilter, + setDiscoverAccelFilter, hasActiveDownloads, handleDownloadModel, handleCancelModelDownload, diff --git a/src/hooks/useMtplxInstall.ts b/src/hooks/useMtplxInstall.ts new file mode 100644 index 0000000..2a00e61 --- /dev/null +++ b/src/hooks/useMtplxInstall.ts @@ -0,0 +1,85 @@ +import { useState, useCallback, useRef, useEffect } from "react"; +import { + getMtplxInstallStatus, + getMtplxStatus, + startMtplxInstall, + type MtplxJobState, + type MtplxStatus, +} from "../api"; + +const POLL_INTERVAL_MS = 1500; + +export interface UseMtplxInstallReturn { + mtplxJob: MtplxJobState | null; + mtplxStatus: MtplxStatus | null; + installingMtplx: boolean; + handleInstallMtplx: () => Promise; + refreshMtplxStatus: () => Promise; +} + +export function useMtplxInstall(): UseMtplxInstallReturn { + const [mtplxJob, setMtplxJob] = useState(null); + const [mtplxStatus, setMtplxStatus] = useState(null); + const [installingMtplx, setInstallingMtplx] = useState(false); + const pollRef = useRef | null>(null); + + const stopPoll = useCallback(() => { + if (pollRef.current !== null) { + clearInterval(pollRef.current); + pollRef.current = null; + } + }, []); + + const startPoll = useCallback(() => { + stopPoll(); + pollRef.current = setInterval(async () => 
#!/usr/bin/env python3
"""Stub MTPLX server for integration tests.

Mimics the surface MtplxEngine talks to:
  - ``mtplx start --model PATH --port N`` CLI shape
  - ``GET /health`` returns 200 OK
  - ``POST /v1/chat/completions`` returns OpenAI-compatible response
    (non-streaming JSON, or SSE when ``stream: true``)

Does not implement actual MTP speculative decoding — just enough surface to
prove the engine boots, proxies a prompt, and parses the response.
"""

from __future__ import annotations

import argparse
import json
import sys
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any


_FAKE_REPLY = "stub-mtplx says hi"
_FAKE_STREAM_CHUNKS = ["stub", "-mtplx", " says", " hi"]


def _prompt_token_count(payload: dict[str, Any]) -> int:
    """Return a whitespace-split word count over all message contents.

    Entries of ``messages`` that are not JSON objects are ignored, and a
    missing or non-list ``messages`` counts as zero, so a sloppy test payload
    never crashes the stub.
    """
    messages = payload.get("messages")
    if not isinstance(messages, list):
        return 0
    return sum(
        len(str(message.get("content", "")).split())
        for message in messages
        if isinstance(message, dict)
    )


class _Handler(BaseHTTPRequestHandler):
    """Request handler implementing ``/health`` and ``/v1/chat/completions``."""

    def log_message(self, *_args, **_kwargs) -> None:
        """Silence the default per-request stderr logging."""
        return

    def _send_body(self, content_type: str, body: bytes) -> None:
        """Send a complete 200 response with an explicit Content-Length."""
        self.send_response(200)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self) -> None:
        """Answer the readiness probe; every other path is a 404."""
        if self.path != "/health":
            self.send_error(404)
            return
        self._send_body("application/json", b'{"ok":true}')

    def do_POST(self) -> None:
        """Validate the request and dispatch to the JSON or SSE responder.

        Responds 404 for unknown paths and 400 for a malformed or negative
        Content-Length, an undecodable or invalid JSON body, or a JSON body
        that is not an object.
        """
        if self.path != "/v1/chat/completions":
            self.send_error(404)
            return

        try:
            length = int(self.headers.get("Content-Length") or 0)
        except ValueError:
            self.send_error(400, "bad Content-Length")
            return
        if length < 0:
            # rfile.read() with a negative size would block until EOF.
            self.send_error(400, "bad Content-Length")
            return

        raw = self.rfile.read(length) if length else b"{}"
        try:
            # ValueError covers JSONDecodeError and the UnicodeDecodeError
            # json.loads raises for bytes that are not valid UTF-8/16/32.
            payload = json.loads(raw or b"{}")
        except ValueError:
            self.send_error(400, "bad JSON")
            return
        if not isinstance(payload, dict):
            self.send_error(400, "JSON body must be an object")
            return

        if payload.get("stream"):
            self._stream_response(payload)
        else:
            self._json_response(payload)

    def _json_response(self, payload: dict) -> None:
        """Send the canned reply as one OpenAI-style JSON completion."""
        prompt_tokens = _prompt_token_count(payload)
        completion_tokens = len(_FAKE_REPLY.split())
        body = json.dumps({
            "id": "stub-mtplx-1",
            "model": payload.get("model", "stub"),
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": _FAKE_REPLY},
                "finish_reason": "stop",
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }).encode("utf-8")
        self._send_body("application/json", body)

    def _stream_response(self, payload: dict) -> None:
        """Send the canned reply as server-sent events.

        Emits one ``data:`` event per chunk, then a final event carrying
        ``finish_reason: "stop"`` plus usage, then the ``[DONE]`` sentinel.
        """
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Cache-Control", "no-cache")
        self.end_headers()
        model = payload.get("model", "stub")
        for chunk in _FAKE_STREAM_CHUNKS:
            data = json.dumps({
                "id": "stub-mtplx-1",
                "model": model,
                "choices": [{"index": 0, "delta": {"content": chunk}, "finish_reason": None}],
            })
            self.wfile.write(f"data: {data}\n\n".encode("utf-8"))
            self.wfile.flush()
        prompt_tokens = _prompt_token_count(payload)
        completion_tokens = len(_FAKE_STREAM_CHUNKS)
        final = json.dumps({
            "id": "stub-mtplx-1",
            "model": model,
            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        })
        self.wfile.write(f"data: {final}\n\n".encode("utf-8"))
        self.wfile.write(b"data: [DONE]\n\n")
        self.wfile.flush()


def _serve(port: int, *, fail_mode: str | None = None) -> None:
    """Serve on 127.0.0.1:*port* until the process is stopped.

    ``fail_mode`` simulates startup problems: ``"crash-before-ready"`` exits
    with status 2 before binding, ``"delay"`` sleeps 0.5 s before binding.
    """
    if fail_mode == "crash-before-ready":
        sys.stderr.write("stub crash: simulated startup failure\n")
        sys.stderr.flush()
        sys.exit(2)
    if fail_mode == "delay":
        time.sleep(0.5)
    server = ThreadingHTTPServer(("127.0.0.1", port), _Handler)
    try:
        server.serve_forever(poll_interval=0.1)
    finally:
        server.server_close()


def main() -> None:
    """Parse the ``start`` sub-command and run the stub server."""
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd", required=True)
    start = sub.add_parser("start")
    start.add_argument("--model", required=True)
    start.add_argument("--port", type=int, required=True)
    start.add_argument("--fail-mode", default=None)
    args = parser.parse_args()
    if args.cmd == "start":
        _serve(args.port, fail_mode=args.fail_mode)


if __name__ == "__main__":
    main()
+ """ + wrapper = tmp_path / "mtplx" + extra = f' --fail-mode {fail_mode}' if fail_mode else "" + wrapper.write_text(textwrap.dedent(f"""\ + #!/usr/bin/env bash + exec {sys.executable} {_STUB_SCRIPT} "$@"{extra} + """)) + wrapper.chmod(wrapper.stat().st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + return wrapper + + +def _make_capabilities(mtplx_python: str) -> BackendCapabilities: + return BackendCapabilities( + pythonExecutable=sys.executable, + mlxAvailable=True, + mlxLmAvailable=True, + mlxUsable=True, + mtplxAvailable=True, + mtplxPythonPath=mtplx_python, + ) + + +class MtplxEngineIntegrationTests(unittest.TestCase): + def setUp(self) -> None: + import tempfile + + self._tmp = tempfile.TemporaryDirectory() + self.tmp_path = Path(self._tmp.name) + self.wrapper = _make_mtplx_wrapper(self.tmp_path) + # Place a fake python sibling so capabilities resolver is satisfied + (self.tmp_path / "python").write_text("#!/usr/bin/env bash\n") + (self.tmp_path / "python").chmod(0o755) + self.capabilities = _make_capabilities(str(self.tmp_path / "python")) + + def tearDown(self) -> None: + self._tmp.cleanup() + + def _make_engine(self) -> MtplxEngine: + engine = MtplxEngine(self.capabilities) + # Patch _mtplx_bin so the engine resolves to the test wrapper rather + # than the production ~/.chaosengine/mtplx-venv path. 
+ engine._mtplx_bin = lambda: str(self.wrapper) # type: ignore[method-assign] + return engine + + def test_load_model_starts_server_and_returns_info(self) -> None: + engine = self._make_engine() + try: + info = engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + speculative_decoding=True, + ) + self.assertEqual(info.ref, "Qwen/Qwen3.5-7B") + self.assertEqual(info.engine, "mtplx") + self.assertTrue(info.speculativeDecoding) + self.assertIsNotNone(info.runtimeNote) + self.assertIn("MTPLX", info.runtimeNote or "") + self.assertIn("draft tokens", info.runtimeNote or "") + self.assertIsNotNone(engine.port) + self.assertIsNotNone(engine.process_pid()) + finally: + engine.unload_model() + self.assertIsNone(engine.process_pid()) + + def test_load_model_raises_when_capability_missing(self) -> None: + caps = BackendCapabilities(pythonExecutable=sys.executable, mlxAvailable=True, mlxLmAvailable=True, mlxUsable=True, mtplxAvailable=False) + engine = MtplxEngine(caps) + engine._mtplx_bin = lambda: str(self.wrapper) # type: ignore[method-assign] + with self.assertRaises(RuntimeError) as ctx: + engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + ) + self.assertIn("not installed", str(ctx.exception).lower()) + + def test_load_model_raises_when_server_exits_during_startup(self) -> None: + crash_dir = self.tmp_path / "crash" + crash_dir.mkdir() + bad_wrapper = _make_mtplx_wrapper(crash_dir, fail_mode="crash-before-ready") + engine = 
MtplxEngine(self.capabilities) + engine._mtplx_bin = lambda: str(bad_wrapper) # type: ignore[method-assign] + with self.assertRaises(RuntimeError): + engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + ) + self.assertIsNone(engine.process_pid()) + + def test_generate_round_trip_returns_text_and_tokens(self) -> None: + engine = self._make_engine() + try: + engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + ) + result = engine.generate( + prompt="Hello", + history=[], + system_prompt="You are a stub.", + max_tokens=32, + temperature=0.7, + ) + self.assertEqual(result.text, "stub-mtplx says hi") + self.assertEqual(result.finishReason, "stop") + self.assertGreater(result.completionTokens, 0) + self.assertGreater(result.promptTokens, 0) + self.assertGreater(result.tokS, 0.0) + self.assertIn("MTPLX", result.runtimeNote or "") + finally: + engine.unload_model() + + def test_stream_generate_yields_text_then_done(self) -> None: + engine = self._make_engine() + try: + engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + ) + text_chunks = [] + done_chunk = None + for chunk in engine.stream_generate( + prompt="Hello", + history=[], + system_prompt=None, + max_tokens=32, 
+ temperature=0.7, + ): + if chunk.text: + text_chunks.append(chunk.text) + if chunk.done: + done_chunk = chunk + joined = "".join(text_chunks) + self.assertIn("stub", joined) + self.assertIn("hi", joined) + self.assertIsNotNone(done_chunk) + assert done_chunk is not None + self.assertEqual(done_chunk.finish_reason, "stop") + self.assertGreater(done_chunk.completion_tokens, 0) + finally: + engine.unload_model() + + def test_generate_after_unload_raises(self) -> None: + engine = self._make_engine() + try: + engine.load_model( + model_ref="Qwen/Qwen3.5-7B", + model_name="Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + source="catalog", + backend="mtplx", + path=None, + runtime_target=None, + cache_strategy="native", + cache_bits=0, + fp16_layers=0, + fused_attention=False, + fit_model_in_memory=True, + context_tokens=8192, + ) + finally: + engine.unload_model() + with self.assertRaises(RuntimeError): + engine.generate( + prompt="Hello", + history=[], + system_prompt=None, + max_tokens=32, + temperature=0.7, + ) + + def test_unload_idempotent(self) -> None: + engine = self._make_engine() + engine.unload_model() + engine.unload_model() + self.assertIsNone(engine.process_pid()) + + +class MtplxEngineControllerFallbackTests(unittest.TestCase): + """Verify the controller falls back to MLXWorkerEngine when MTPLX startup fails.""" + + def test_controller_select_engine_picks_mtplx_when_model_supported(self) -> None: + from backend_service.inference.controller import RuntimeController + + controller = RuntimeController() + controller.capabilities = BackendCapabilities( + pythonExecutable=sys.executable, + mlxAvailable=True, + mlxLmAvailable=True, + mlxUsable=True, + mtplxAvailable=True, + ) + engine = controller._select_engine( + backend="mlx", + runtime_target=None, + path=None, + model_ref="Qwen/Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + speculative_decoding=True, + ) + self.assertEqual(engine.engine_name, "mtplx") + + def 
test_controller_select_engine_falls_through_when_model_not_supported(self) -> None: + from backend_service.inference.controller import RuntimeController + from backend_service.inference.mlx_engine import MLXWorkerEngine + + controller = RuntimeController() + controller.capabilities = BackendCapabilities( + pythonExecutable=sys.executable, + mlxAvailable=True, + mlxLmAvailable=True, + mlxUsable=True, + mtplxAvailable=True, + ) + engine = controller._select_engine( + backend="mlx", + runtime_target=None, + path=None, + model_ref="some/random-model-without-mtp", + canonical_repo="some/random-model-without-mtp", + speculative_decoding=True, + ) + self.assertIsInstance(engine, MLXWorkerEngine) + + def test_controller_select_engine_skips_mtplx_when_speculative_off(self) -> None: + from backend_service.inference.controller import RuntimeController + from backend_service.inference.mlx_engine import MLXWorkerEngine + + controller = RuntimeController() + controller.capabilities = BackendCapabilities( + pythonExecutable=sys.executable, + mlxAvailable=True, + mlxLmAvailable=True, + mlxUsable=True, + mtplxAvailable=True, + ) + engine = controller._select_engine( + backend="mlx", + runtime_target=None, + path=None, + model_ref="Qwen/Qwen3.5-7B", + canonical_repo="Qwen/Qwen3.5-7B", + speculative_decoding=False, + ) + self.assertIsInstance(engine, MLXWorkerEngine) + + +if __name__ == "__main__": + unittest.main()