diff --git a/src/twinkle/server/gateway/twinkle_gateway_handlers.py b/src/twinkle/server/gateway/twinkle_gateway_handlers.py index 9a2c2f27..af5b0939 100644 --- a/src/twinkle/server/gateway/twinkle_gateway_handlers.py +++ b/src/twinkle/server/gateway/twinkle_gateway_handlers.py @@ -24,6 +24,14 @@ def _register_twinkle_routes(app: FastAPI, self_fn: Callable[[], GatewayServer]) -> None: """Register all /twinkle/* routes on the given FastAPI app.""" + @app.get('/twinkle/capacity_info', response_model=types.CapacityInfoResponse) + async def get_capacity_info( + request: Request, + self: GatewayServer = Depends(self_fn), + ) -> types.CapacityInfoResponse: + info = await self.state.get_capacity_info() + return types.CapacityInfoResponse(**info) + @app.get('/twinkle/healthz', response_model=types.HealthResponse) async def healthz(request: Request) -> types.HealthResponse: return types.HealthResponse(status='ok') diff --git a/src/twinkle/server/model/app.py b/src/twinkle/server/model/app.py index ecd841e3..5d0bc228 100644 --- a/src/twinkle/server/model/app.py +++ b/src/twinkle/server/model/app.py @@ -163,6 +163,10 @@ def get_self() -> ModelManagement: @asynccontextmanager async def lifespan(app: FastAPI): + try: + await get_self()._ensure_replica_registered() + except Exception as e: + logger.warning(f'Failed to register replica at startup: {e}') yield try: await get_self().shutdown() diff --git a/src/twinkle/server/model/tinker_handlers.py b/src/twinkle/server/model/tinker_handlers.py index e357d720..1f3d7ad9 100644 --- a/src/twinkle/server/model/tinker_handlers.py +++ b/src/twinkle/server/model/tinker_handlers.py @@ -41,7 +41,8 @@ async def create_model( async def _create_adapter(): _model_id = None try: - _model_id = await self.state.register_model(body.model_dump(), token=token, replica_id=self.replica_id) + _model_id = await self.state.register_model( + body.model_dump(), token=token, replica_id=self.replica_id, session_id=body.session_id) if body.lora_config: # TODO: Make LoraConfig more flexible lora_cfg = LoraConfig(r=body.lora_config.rank, target_modules='all-linear') diff --git a/src/twinkle/server/model/twinkle_handlers.py b/src/twinkle/server/model/twinkle_handlers.py index a441387b..ff161f5f 100644 --- a/src/twinkle/server/model/twinkle_handlers.py +++ b/src/twinkle/server/model/twinkle_handlers.py @@ -496,8 +496,6 @@ async def _task(): config = deserialize_object(body.config) extra_kwargs = body.model_extra or {} training_run_manager = create_training_run_manager(token, client_type='twinkle') - self.register_resource(adapter_name, token, session_id) - self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) lora_config = None if isinstance(config, LoraConfig): @@ -507,6 +505,20 @@ async def _task(): lora_config=lora_config, save_dir=resolved_save_dir, user_metadata={'adapter_name': body.adapter_name}) + await self.state.register_model( + run_config.model_dump(), + token=token, + model_id=adapter_name, + replica_id=self.replica_id, + session_id=session_id, + ) + try: + self.register_resource(adapter_name, token, session_id) + self.model.add_adapter_to_model(adapter_name, config, **extra_kwargs) + except Exception: + self.unregister_resource(adapter_name) + await self.state.unload_model(adapter_name) + raise training_run_manager.save(adapter_name, run_config) return {'status': 'ok', 'adapter_name': adapter_name} diff --git a/src/twinkle/server/utils/state/model_manager.py b/src/twinkle/server/utils/state/model_manager.py index 586e4868..2d5345f7 100644 --- a/src/twinkle/server/utils/state/model_manager.py +++ b/src/twinkle/server/utils/state/model_manager.py @@ -28,6 +28,20 @@ def __init__(self, expiration_timeout: float, per_token_model_limit: int = 30) - # replica_id -> max_loras limit declared at registration time self._replica_max_loras: dict[str, int] = {} + def get_capacity_info(self) -> dict[str, int]: + """Return global LoRA capacity across all registered replicas. + + Returns: + Dict containing 'max_loras', 'used_loras', and 'free_loras'. + """ + total_max_loras = sum(self._replica_max_loras.values()) + total_used_loras = sum(len(self._replica_models.get(rid, set())) for rid in self._replica_max_loras.keys()) + return { + 'max_loras': total_max_loras, + 'used_loras': total_used_loras, + 'free_loras': max(0, total_max_loras - total_used_loras), + } + # ----- Replica Registration ----- def register_replica(self, replica_id: str, max_loras: int) -> None: diff --git a/src/twinkle/server/utils/state/server_state.py b/src/twinkle/server/utils/state/server_state.py index 8e7689b2..fd3a7626 100644 --- a/src/twinkle/server/utils/state/server_state.py +++ b/src/twinkle/server/utils/state/server_state.py @@ -57,6 +57,9 @@ def __init__( self._metrics_running = False self._metrics_update_interval: float = float(kwargs.get('metrics_update_interval', 15.0)) + async def get_capacity_info(self) -> dict[str, int]: + return self._model_mgr.get_capacity_info() + # ----- Session Management ----- async def create_session(self, payload: dict[str, Any]) -> str: @@ -99,7 +102,8 @@ async def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None, - replica_id: str | None = None) -> str: + replica_id: str | None = None, + session_id: str | None = None) -> str: """Register a new model with the server state. Args: @@ -107,6 +111,8 @@ async def register_model(self, token: User token that owns this model. Required. model_id: Optional explicit model_id; otherwise auto-generated. replica_id: Optional replica that is hosting this model. + session_id: Optional owning session; enables cascade cleanup when + the session expires. Falls back to ``payload['session_id']``. Returns: The model_id for the registered model. @@ -117,7 +123,7 @@ async def register_model(self, _model_id = re.sub(r'[^\w\-]', '_', _model_id) record = ModelRecord( - session_id=payload.get('session_id'), + session_id=session_id or payload.get('session_id'), model_seq_id=payload.get('model_seq_id'), base_model=payload.get('base_model'), user_metadata=payload.get('user_metadata') or {}, @@ -374,6 +380,9 @@ class ServerStateProxy: def __init__(self, actor_handle) -> None: self._actor = actor_handle + async def get_capacity_info(self) -> dict[str, int]: + return await self._actor.get_capacity_info.remote() + # ----- Session Management ----- async def create_session(self, payload: dict[str, Any]) -> str: @@ -391,8 +400,9 @@ async def register_model(self, payload: dict[str, Any], token: str, model_id: str | None = None, - replica_id: str | None = None) -> str: - return await self._actor.register_model.remote(payload, token, model_id, replica_id) + replica_id: str | None = None, + session_id: str | None = None) -> str: + return await self._actor.register_model.remote(payload, token, model_id, replica_id, session_id) async def unload_model(self, model_id: str) -> bool: return await self._actor.unload_model.remote(model_id) diff --git a/src/twinkle_client/manager.py b/src/twinkle_client/manager.py index b9398997..12257d63 100644 --- a/src/twinkle_client/manager.py +++ b/src/twinkle_client/manager.py @@ -5,7 +5,7 @@ import threading from typing import Any, Dict, List, Optional, Tuple from twinkle import get_logger -from twinkle_client.types.server import (DeleteCheckpointResponse, GetServerCapabilitiesResponse) +from twinkle_client.types.server import (CapacityInfoResponse, DeleteCheckpointResponse, GetServerCapabilitiesResponse) from twinkle_client.types.session import (CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse) from twinkle_client.types.training import (Checkpoint, Cursor, ParsedCheckpointTwinklePath, TrainingRun, @@ -76,6 +76,21 @@ def __init__( self._heartbeat_thread.start() atexit.register(self.close) + def get_capacity_info(self) -> CapacityInfoResponse: + """ + Get the server's global LoRA capacity information. + + Returns: + :class:`~twinkle_client.types.server.CapacityInfoResponse` with + ``max_loras``, ``used_loras``, and ``free_loras`` fields. + + Raises: + TwinkleClientError: If the request fails. + """ + response = http_get(self._get_url('/capacity_info')) + data = self._handle_response(response) + return CapacityInfoResponse(**data) + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ diff --git a/src/twinkle_client/types/__init__.py b/src/twinkle_client/types/__init__.py index d58bb00c..49673b0e 100644 --- a/src/twinkle_client/types/__init__.py +++ b/src/twinkle_client/types/__init__.py @@ -76,6 +76,7 @@ SupportedModel, WeightsInfoRequest, WeightsInfoResponse as ServerWeightsInfoResponse, + CapacityInfoResponse, ) from .session import CreateSessionRequest, CreateSessionResponse, SessionHeartbeatRequest, SessionHeartbeatResponse from .training import ( diff --git a/src/twinkle_client/types/server.py b/src/twinkle_client/types/server.py index 2d9233b4..1c7c992d 100644 --- a/src/twinkle_client/types/server.py +++ b/src/twinkle_client/types/server.py @@ -40,3 +40,10 @@ class CheckpointPathResponse(BaseModel): """Response body for the /checkpoint_path endpoint.""" path: str twinkle_path: str + + +class CapacityInfoResponse(BaseModel): + """Response body for the /capacity_info endpoint.""" + max_loras: int + used_loras: int + free_loras: int