8 changes: 8 additions & 0 deletions backend_service/helpers/system.py
@@ -143,6 +143,13 @@ def _get_cache_strategies():
from cache_compression import registry
return registry.available()

def _get_mtplx_info():
from backend_service.inference._mtp import MTP_MODEL_MAP, _MTP_ALIASES
from backend_service.inference.capabilities import _detect_mtplx
available, _ = _detect_mtplx()
supported_models = list(MTP_MODEL_MAP.keys()) + list(_MTP_ALIASES.keys())
return {"available": available, "supportedModels": supported_models}

def _get_dflash_info():
try:
from dflash import availability_info
@@ -173,6 +180,7 @@ def _get_dflash_info():
"appVersion": app_version,
"availableCacheStrategies": _get_cache_strategies(),
"dflash": _get_dflash_info(),
"mtplx": _get_mtplx_info(),
"vllmAvailable": native.get("vllmAvailable", False),
"vllmVersion": native.get("vllmVersion"),
"mlxAvailable": native["mlxAvailable"],
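For context, a minimal sketch of how a client of the system-info endpoint might consume the new "mtplx" block built by _get_mtplx_info() above. The handle_system_info helper and the gating behaviour are illustrative assumptions, not part of this PR; only the key names ("available", "supportedModels") come from the code above.

# Illustrative sketch (not part of this diff): reading the new "mtplx" block.
def handle_system_info(info: dict) -> None:
    mtplx = info.get("mtplx", {})
    if mtplx.get("available"):
        # e.g. only offer an MTPLX toggle for repos listed in supportedModels
        supported = set(mtplx.get("supportedModels", []))
        print(f"MTPLX available; {len(supported)} known MTP-capable repos")
    else:
        print("MTPLX not installed; hide the toggle")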
92 changes: 92 additions & 0 deletions backend_service/inference/_mtp.py
@@ -0,0 +1,92 @@
"""MTP (Multi-Token Prediction) model registry.

Maps models that carry baked-in MTP heads to their recommended
``spec-draft-n-max`` count. Used by two separate inference paths:

- MLX path → ``MtplxEngine`` (this module, ``has_mtp_heads``)
- GGUF path → ``LlamaCppEngine._build_command`` (``--spec-type mtp``,
Phase 2 of feature/mtplx)

Only models trained with MTP objectives belong here — standard MLX quants
of base/instruct checkpoints that strip MTP heads at conversion time should
NOT be listed. MTPLX auto-detects MTP heads at load time; this map is used
by ChaosEngineAI to decide whether to offer the MTPLX toggle and, for the
GGUF path, how many draft tokens to request.
"""

from __future__ import annotations

# ---------------------------------------------------------------------------
# MLX / transformers repos with baked-in MTP heads
# ---------------------------------------------------------------------------
#
# Key → canonical HuggingFace repo id (case-sensitive)
# Value → recommended spec-draft-n-max (1–3); start conservatively at 1
# and bump when acceptance rate benchmarks justify it.

MTP_MODEL_MAP: dict[str, int] = {
# ----- Youssofal MTPLX-Optimized (upstream-verified for MTPLX v0.3.5) -----
"Youssofal/Qwen3.6-27B-MTPLX-Optimized-Speed": 1,
"Youssofal/Qwen3.6-27B-MTPLX-Optimized-Speed-FP16": 1,
"Youssofal/Qwen3.6-27B-MTPLX-Optimized-Quality": 1,
# ----- Qwen3.5 family -----
"Qwen/Qwen3.5-4B": 1,
"Qwen/Qwen3.5-7B": 1,
"Qwen/Qwen3.5-9B": 1,
"Qwen/Qwen3.5-14B": 1,
"Qwen/Qwen3.5-27B": 1,
"Qwen/Qwen3.5-35B-A3B": 1,
"Qwen/Qwen3.5-122B-A10B": 1,
# ----- Qwen3.6 family -----
"Qwen/Qwen3.6-27B": 1,
"Qwen/Qwen3.6-35B-A3B": 1,
# ----- Qwen3-Coder-Next -----
"Qwen/Qwen3-Coder-Next": 1,
# ----- DeepSeek V3 / R1 -----
"deepseek-ai/DeepSeek-V3": 1,
"deepseek-ai/DeepSeek-V3-0324": 1,
"deepseek-ai/DeepSeek-R1": 1,
}

# Community MLX conversions that preserve MTP heads.
# Maps community repo → canonical repo (for draft-n lookup).
_MTP_ALIASES: dict[str, str] = {
# Qwen3.5
"mlx-community/Qwen3.5-4B-4bit": "Qwen/Qwen3.5-4B",
"mlx-community/Qwen3.5-4B-8bit": "Qwen/Qwen3.5-4B",
"mlx-community/Qwen3.5-7B-4bit": "Qwen/Qwen3.5-7B",
"mlx-community/Qwen3.5-7B-8bit": "Qwen/Qwen3.5-7B",
"mlx-community/Qwen3.5-9B-4bit": "Qwen/Qwen3.5-9B",
"mlx-community/Qwen3.5-9B-8bit": "Qwen/Qwen3.5-9B",
"mlx-community/Qwen3.5-14B-4bit": "Qwen/Qwen3.5-14B",
"mlx-community/Qwen3.5-14B-8bit": "Qwen/Qwen3.5-14B",
"mlx-community/Qwen3.5-27B-4bit": "Qwen/Qwen3.5-27B",
"mlx-community/Qwen3.5-27B-8bit": "Qwen/Qwen3.5-27B",
# Qwen3.6
"mlx-community/Qwen3.6-27B-4bit": "Qwen/Qwen3.6-27B",
"mlx-community/Qwen3.6-27B-8bit": "Qwen/Qwen3.6-27B",
"mlx-community/Qwen3.6-27B-bf16": "Qwen/Qwen3.6-27B",
"mlx-community/Qwen3.6-35B-A3B-4bit": "Qwen/Qwen3.6-35B-A3B",
"lmstudio-community/Qwen3.6-27B-GGUF": "Qwen/Qwen3.6-27B",
# Qwen3-Coder-Next
"lmstudio-community/Qwen3-Coder-Next-MLX-4bit": "Qwen/Qwen3-Coder-Next",
}


def get_mtp_draft_n(repo: str) -> int | None:
"""Return the recommended spec-draft-n-max for *repo*, or None.

Returns None when the repo is not known to carry MTP heads — callers
should not enable MTP speculative decoding for that model.
"""
if repo in MTP_MODEL_MAP:
return MTP_MODEL_MAP[repo]
canonical = _MTP_ALIASES.get(repo)
if canonical:
return MTP_MODEL_MAP.get(canonical)
return None


def has_mtp_heads(repo: str) -> bool:
"""True when *repo* (or a community alias of it) carries baked-in MTP heads."""
return get_mtp_draft_n(repo) is not None
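A minimal usage sketch of the registry above, assuming the module is importable as backend_service.inference._mtp; the repo ids and expected values follow directly from MTP_MODEL_MAP and _MTP_ALIASES as shown in this diff.

# Illustrative usage of the MTP registry (assumes the module is on the import path).
from backend_service.inference._mtp import get_mtp_draft_n, has_mtp_heads

assert get_mtp_draft_n("Qwen/Qwen3.5-7B") == 1                 # canonical repo
assert get_mtp_draft_n("mlx-community/Qwen3.5-7B-4bit") == 1   # alias resolves to its canonical repo
assert get_mtp_draft_n("some-org/plain-model") is None          # unknown repo: do not enable MTP decoding
assert has_mtp_heads("Qwen/Qwen3.6-27B") is True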
4 changes: 4 additions & 0 deletions backend_service/inference/base.py
@@ -90,6 +90,8 @@ class BackendCapabilities:
converterAvailable: bool = False
vllmAvailable: bool = False
vllmVersion: str | None = None
mtplxAvailable: bool = False
mtplxPythonPath: str | None = None
probing: bool = False

def to_dict(self) -> dict[str, Any]:
@@ -108,6 +110,8 @@ def to_dict(self) -> dict[str, Any]:
"converterAvailable": self.converterAvailable,
"vllmAvailable": self.vllmAvailable,
"vllmVersion": self.vllmVersion,
"mtplxAvailable": self.mtplxAvailable,
"mtplxPythonPath": self.mtplxPythonPath,
"probing": self.probing,
}

24 changes: 24 additions & 0 deletions backend_service/inference/capabilities.py
@@ -13,6 +13,8 @@
import time
from threading import RLock

from pathlib import Path

from backend_service.inference._constants import CAPABILITY_CACHE_TTL_SECONDS
from backend_service.inference.base import BackendCapabilities
from backend_service.inference.binaries import (
@@ -24,10 +26,26 @@
)


_MTPLX_VENV = Path.home() / ".chaosengine" / "mtplx-venv"
_MTPLX_VERSION_FILE = Path.home() / ".chaosengine" / "bin" / "mtplx.version"

_capability_cache: tuple[float, BackendCapabilities] | None = None
_capability_lock = RLock()


def _detect_mtplx() -> tuple[bool, str | None]:
"""Return (available, python_path) for the MTPLX isolated venv.

Cheap file-existence check — no subprocess spawn. The version file is
written by install-mtplx.sh on clean install; its presence together with
the venv python binary is sufficient to confirm a usable install.
"""
python = _MTPLX_VENV / "bin" / "python"
if _MTPLX_VERSION_FILE.exists() and python.exists():
return True, str(python)
return False, None


def _initial_backend_capabilities() -> BackendCapabilities:
"""Cheap capability placeholder used while the real probe runs.

Expand All @@ -40,6 +58,7 @@ def _initial_backend_capabilities() -> BackendCapabilities:
llama_server_path = _resolve_llama_server()
llama_server_turbo_path = _resolve_llama_server_turbo()
llama_cli_path = _resolve_llama_cli()
mtplx_available, mtplx_python = _detect_mtplx()
return BackendCapabilities(
pythonExecutable=python_executable,
mlxAvailable=False,
Expand All @@ -53,6 +72,8 @@ def _initial_backend_capabilities() -> BackendCapabilities:
converterAvailable=False,
vllmAvailable=False,
vllmVersion=None,
mtplxAvailable=mtplx_available,
mtplxPythonPath=mtplx_python,
probing=True,
)

@@ -80,6 +101,7 @@ def _probe_native_backends() -> BackendCapabilities:

from backend_service.vllm_engine import _vllm_importable, _vllm_version

mtplx_available, mtplx_python = _detect_mtplx()
return BackendCapabilities(
pythonExecutable=python_executable,
mlxAvailable=mlx_available,
Expand All @@ -95,6 +117,8 @@ def _probe_native_backends() -> BackendCapabilities:
converterAvailable=mlx_usable,
vllmAvailable=_vllm_importable(),
vllmVersion=_vllm_version(),
mtplxAvailable=mtplx_available,
mtplxPythonPath=mtplx_python,
)


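A small sketch of the detection contract described in the _detect_mtplx docstring, assuming the default install layout under ~/.chaosengine shown above; no subprocess is spawned, only two paths are checked.

# Illustrative sketch of what _detect_mtplx() checks.
from pathlib import Path

venv_python = Path.home() / ".chaosengine" / "mtplx-venv" / "bin" / "python"
version_file = Path.home() / ".chaosengine" / "bin" / "mtplx.version"

expected = (True, str(venv_python)) if (version_file.exists() and venv_python.exists()) else (False, None)
# _detect_mtplx() should return the same tuple:
# from backend_service.inference.capabilities import _detect_mtplx
# assert _detect_mtplx() == expected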
91 changes: 70 additions & 21 deletions backend_service/inference/controller.py
@@ -87,6 +87,7 @@
LlamaCppEngine,
_CACHE_TYPE_CACHE,
_LLAMA_HELP_CACHE,
_LLAMA_HELP_LOCK,
_LLAMA_SAMPLER_KEYS,
_STANDARD_CACHE_TYPES,
_apply_llama_chat_template_fixes,
@@ -99,10 +100,12 @@
_resolve_mmproj_path,
)
from backend_service.inference.mlx_engine import MLXWorkerEngine
from backend_service.inference.mtplx_engine import MtplxEngine
from backend_service.inference.simple_engines import (
MockInferenceEngine,
RemoteOpenAIEngine,
)
from backend_service.inference._mtp import has_mtp_heads


class RuntimeController:
@@ -338,7 +341,8 @@ def prune_stale_backend_children(self) -> None:

is_mlx_worker = "backend_service.mlx_worker" in cmdline or "mlx_worker" in cmdline
is_llama = "llama-server" in name or "llama-server" in cmdline
if not (is_mlx_worker or is_llama):
is_mtplx = "mtplx" in name or "/mtplx-venv/" in cmdline or " mtplx " in f" {cmdline} " or cmdline.endswith(" mtplx")
if not (is_mlx_worker or is_llama or is_mtplx):
continue
if child.pid in tracked:
continue
Expand All @@ -350,10 +354,16 @@ def prune_stale_backend_children(self) -> None:
if now_wall - create_time < self.ORPHAN_DETECTION_GRACE_SECONDS:
continue

if is_mlx_worker:
kind, label = "mlx_worker", "MLX worker"
elif is_mtplx:
kind, label = "mtplx", "MTPLX server"
else:
kind, label = "llama_server", "llama-server"
record = {
"pid": int(child.pid),
"kind": "mlx_worker" if is_mlx_worker else "llama_server",
"label": "MLX worker" if is_mlx_worker else "llama-server",
"kind": kind,
"label": label,
"action": "terminated",
"detectedAt": _now_label(),
# Internal monotonic stamp used for TTL; not serialized.
@@ -475,6 +485,9 @@ def _select_engine(
backend: str,
runtime_target: str | None,
path: str | None,
model_ref: str = "",
canonical_repo: str | None = None,
speculative_decoding: bool = False,
) -> BaseInferenceEngine:
hint = (backend or "auto").lower()
target = runtime_target or path
Expand All @@ -483,6 +496,12 @@ def _select_engine(
return RemoteOpenAIEngine(self.capabilities)
if hint == "mlx":
if self.capabilities.mlxUsable:
if (
speculative_decoding
and self.capabilities.mtplxAvailable
and has_mtp_heads(canonical_repo or model_ref)
):
return MtplxEngine(self.capabilities)
return MLXWorkerEngine(self.capabilities)
reason = self.capabilities.mlxMessage or "MLX is not available in this environment"
raise RuntimeError(
@@ -727,6 +746,9 @@ def _internal_progress(progress: dict[str, Any]) -> None:
backend=backend,
runtime_target=runtime_target,
path=path,
model_ref=model_ref,
canonical_repo=canonical_repo,
speculative_decoding=speculative_decoding,
)

# Never keep multiple warm copies of the same logical model under
@@ -739,25 +761,52 @@
)

self.engine = selected_engine
_load_kwargs: dict[str, Any] = dict(
model_ref=model_ref,
model_name=resolved_name,
canonical_repo=canonical_repo,
source=source,
backend=self.engine.engine_name,
path=path,
runtime_target=runtime_target,
cache_strategy=cache_strategy,
cache_bits=cache_bits,
fp16_layers=fp16_layers,
fused_attention=fused_attention,
fit_model_in_memory=fit_model_in_memory,
context_tokens=context_tokens,
speculative_decoding=speculative_decoding,
tree_budget=tree_budget,
progress_callback=_internal_progress,
)
try:
loaded = self.engine.load_model(
model_ref=model_ref,
model_name=resolved_name,
canonical_repo=canonical_repo,
source=source,
backend=self.engine.engine_name,
path=path,
runtime_target=runtime_target,
cache_strategy=cache_strategy,
cache_bits=cache_bits,
fp16_layers=fp16_layers,
fused_attention=fused_attention,
fit_model_in_memory=fit_model_in_memory,
context_tokens=context_tokens,
speculative_decoding=speculative_decoding,
tree_budget=tree_budget,
progress_callback=_internal_progress,
)
loaded = self.engine.load_model(**_load_kwargs)
except RuntimeError as _mtplx_exc:
# MtplxEngine startup failure → fall back to standard MLX worker.
if isinstance(self.engine, MtplxEngine):
fallback = MLXWorkerEngine(self.capabilities)
self.engine = fallback
_load_kwargs["backend"] = fallback.engine_name
try:
loaded = self.engine.load_model(**_load_kwargs)
loaded.runtimeNote = _append_runtime_note(
loaded.runtimeNote,
f"MTPLX startup failed ({_mtplx_exc}); using standard MLX.",
)
except Exception:
self.loaded_model = None
self.runtime_note = None
self._loading_progress = None
self._loading_log_tail = []
self.prune_stale_backend_children()
raise
else:
self.loaded_model = None
self.runtime_note = None
self._loading_progress = None
self._loading_log_tail = []
self.prune_stale_backend_children()
raise
except Exception:
self.loaded_model = None
self.runtime_note = None
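To summarise the controller changes, a hedged standalone sketch of the new MLX-path decision: MTPLX is attempted only when speculative decoding was requested, the MTPLX venv is installed, and the repo carries MTP heads, with a fallback to the standard MLX worker if MTPLX startup raises RuntimeError. The pick_mlx_engine helper is illustrative, not the controller's actual code; the capability fields and engine names mirror the diff above.

# Illustrative decision sketch (mirrors _select_engine plus the RuntimeError fallback).
from backend_service.inference._mtp import has_mtp_heads

def pick_mlx_engine(caps, repo: str, speculative_decoding: bool) -> str:
    # MTPLX is only attempted when the user asked for speculative decoding,
    # the isolated MTPLX venv is installed, and the repo carries MTP heads.
    if speculative_decoding and caps.mtplxAvailable and has_mtp_heads(repo):
        return "mtplx"   # MtplxEngine; controller retries with "mlx" if startup raises RuntimeError
    return "mlx"         # standard MLXWorkerEngine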