diff --git a/decart/__init__.py b/decart/__init__.py index b64d7f6..b7ce1d0 100644 --- a/decart/__init__.py +++ b/decart/__init__.py @@ -12,7 +12,7 @@ QueueResultError, TokenCreateError, ) -from .models import models, ModelDefinition, VideoRestyleInput +from .models import models, ModelDefinition, CustomModelDefinition, VideoRestyleInput from .types import FileInput, ModelState, Prompt from .queue import ( QueueClient, @@ -69,6 +69,7 @@ "QueueResultError", "models", "ModelDefinition", + "CustomModelDefinition", "VideoRestyleInput", "FileInput", "ModelState", diff --git a/decart/client.py b/decart/client.py index 6280ebe..20ceab7 100644 --- a/decart/client.py +++ b/decart/client.py @@ -1,9 +1,10 @@ import os +from types import TracebackType from typing import Any, Optional import aiohttp from pydantic import ValidationError from .errors import InvalidAPIKeyError, InvalidBaseURLError, InvalidInputError -from .models import ImageModelDefinition, _MODELS +from .models import ModelDefinition from .process.request import send_request from .queue.client import QueueClient from .tokens.client import TokensClient @@ -77,8 +78,7 @@ def __init__( @property def queue(self) -> QueueClient: """ - Queue client for async video editing jobs. - Only video models support the queue API. + Queue client for async jobs. Example: ```python @@ -128,25 +128,29 @@ async def close(self) -> None: if self._session and not self._session.closed: await self._session.close() - async def __aenter__(self): + async def __aenter__(self) -> "DecartClient": """Async context manager entry.""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Async context manager exit.""" await self.close() async def process(self, options: dict[str, Any]) -> bytes: """ - Process image editing synchronously. - Only image models support the process API. + Process synchronously using the model definition's configured endpoint. For video editing, use the queue API instead: result = await client.queue.submit_and_poll({...}) Args: options: Processing options including model and inputs - - model: ImageModelDefinition from models.image() + - model: ModelDefinition from models.image() or constructed directly - prompt: Text instructions describing the requested edit - Additional model-specific inputs @@ -154,21 +158,13 @@ async def process(self, options: dict[str, Any]) -> bytes: Generated/transformed image as bytes Raises: - InvalidInputError: If inputs are invalid or model is not an image model + InvalidInputError: If inputs are invalid ProcessingError: If processing fails """ if "model" not in options: raise InvalidInputError("model is required") - model: ImageModelDefinition = options["model"] - - # Validate that this is an image model (check against registry) - if model.name not in _MODELS["image"]: - raise InvalidInputError( - f"Model '{model.name}' is not supported by process(). " - f"Only image models support sync processing. " - f"For video models, use client.queue.submit_and_poll() instead." - ) + model: ModelDefinition[str] = options["model"] cancel_token = options.get("cancel_token") @@ -181,22 +177,25 @@ async def process(self, options: dict[str, Any]) -> bytes: file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS} non_file_inputs = {k: v for k, v in inputs.items() if k not in FILE_FIELDS} - # Validate non-file inputs and create placeholder for file fields - validation_inputs = { - **non_file_inputs, - **{k: b"" for k in file_inputs.keys()}, # Placeholder bytes for validation - } - - try: - validated_inputs = model.input_schema(**validation_inputs) - except ValidationError as e: - raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e - - # Build final inputs: validated non-file inputs + original file inputs - processed_inputs = { - **validated_inputs.model_dump(exclude_none=True), - **file_inputs, # Override placeholders with actual file data - } + if model.input_schema is None: + processed_inputs = {k: v for k, v in inputs.items() if v is not None} + else: + # Validate non-file inputs and create placeholder for file fields + validation_inputs = { + **non_file_inputs, + **{k: b"" for k in file_inputs.keys()}, # Placeholder bytes for validation + } + + try: + validated_inputs = model.input_schema(**validation_inputs) + except ValidationError as e: + raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e + + # Build final inputs: validated non-file inputs + original file inputs + processed_inputs = { + **validated_inputs.model_dump(exclude_none=True), + **file_inputs, # Override placeholders with actual file data + } session = await self._get_session() response = await send_request( diff --git a/decart/models.py b/decart/models.py index 93f8932..ebb0447 100644 --- a/decart/models.py +++ b/decart/models.py @@ -4,7 +4,6 @@ from .errors import ModelNotFoundError from .types import FileInput, MotionTrajectoryInput - RealTimeModels = Literal[ # Canonical names "lucy", @@ -92,7 +91,7 @@ class ModelDefinition(DecartBaseModel, Generic[ModelT]): fps: int = Field(ge=1) width: int = Field(ge=1) height: int = Field(ge=1) - input_schema: type[BaseModel] + input_schema: Optional[type[BaseModel]] = None # Type aliases for model definitions that support specific APIs @@ -105,6 +104,9 @@ class ModelDefinition(DecartBaseModel, Generic[ModelT]): RealTimeModelDefinition = ModelDefinition[RealTimeModels] """Type alias for model definitions that support realtime streaming.""" +CustomModelDefinition = ModelDefinition[str] +"""Type alias for model definitions with arbitrary (non-registry) model names.""" + class VideoToVideoInput(DecartBaseModel): prompt: str = Field( @@ -193,7 +195,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=704, - input_schema=BaseModel, ), "lucy-2.1": ModelDefinition( name="lucy-2.1", @@ -201,7 +202,6 @@ class ImageToImageInput(DecartBaseModel): fps=20, width=1088, height=624, - input_schema=BaseModel, ), "lucy-2.1-vton": ModelDefinition( name="lucy-2.1-vton", @@ -209,7 +209,6 @@ class ImageToImageInput(DecartBaseModel): fps=20, width=1088, height=624, - input_schema=BaseModel, ), "lucy-restyle": ModelDefinition( name="lucy-restyle", @@ -217,7 +216,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=704, - input_schema=BaseModel, ), "lucy-restyle-2": ModelDefinition( name="lucy-restyle-2", @@ -225,7 +223,6 @@ class ImageToImageInput(DecartBaseModel): fps=22, width=1280, height=704, - input_schema=BaseModel, ), "live-avatar": ModelDefinition( name="live-avatar", @@ -233,7 +230,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=720, - input_schema=BaseModel, ), # Latest aliases (server-side resolution) "lucy-latest": ModelDefinition( @@ -242,7 +238,6 @@ class ImageToImageInput(DecartBaseModel): fps=20, width=1088, height=624, - input_schema=BaseModel, ), "lucy-vton-latest": ModelDefinition( name="lucy-vton-latest", @@ -250,7 +245,6 @@ class ImageToImageInput(DecartBaseModel): fps=20, width=1088, height=624, - input_schema=BaseModel, ), "lucy-restyle-latest": ModelDefinition( name="lucy-restyle-latest", @@ -258,7 +252,6 @@ class ImageToImageInput(DecartBaseModel): fps=22, width=1280, height=704, - input_schema=BaseModel, ), # Deprecated names "mirage": ModelDefinition( @@ -267,7 +260,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=704, - input_schema=BaseModel, ), "mirage_v2": ModelDefinition( name="mirage_v2", @@ -275,7 +267,6 @@ class ImageToImageInput(DecartBaseModel): fps=22, width=1280, height=704, - input_schema=BaseModel, ), "lucy_v2v_720p_rt": ModelDefinition( name="lucy_v2v_720p_rt", @@ -283,7 +274,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=704, - input_schema=BaseModel, ), "live_avatar": ModelDefinition( name="live_avatar", @@ -291,7 +281,6 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=720, - input_schema=BaseModel, ), }, "video": { diff --git a/decart/process/request.py b/decart/process/request.py index 49ff7d2..71034fb 100644 --- a/decart/process/request.py +++ b/decart/process/request.py @@ -80,7 +80,7 @@ async def send_request( session: aiohttp.ClientSession, base_url: str, api_key: str, - model: ModelDefinition, + model: ModelDefinition[str], inputs: dict[str, Any], cancel_token: Optional[asyncio.Event] = None, integration: Optional[str] = None, diff --git a/decart/queue/client.py b/decart/queue/client.py index dfb79db..ef14a77 100644 --- a/decart/queue/client.py +++ b/decart/queue/client.py @@ -4,7 +4,7 @@ import aiohttp from pydantic import ValidationError -from ..models import VideoModelDefinition, _MODELS +from ..models import ModelDefinition from ..errors import InvalidInputError from .request import submit_job, get_job_status, get_job_content from .types import ( @@ -25,8 +25,7 @@ class QueueClient: """ - Queue client for async job-based video editing. - Only video models support the queue API. + Queue client for async jobs. Jobs are submitted and processed asynchronously, allowing you to poll for status and retrieve results when ready. @@ -62,13 +61,12 @@ async def _get_session(self) -> aiohttp.ClientSession: async def submit(self, options: dict[str, Any]) -> JobSubmitResponse: """ - Submit a video editing job to the queue for async processing. - Only video models are supported. + Submit an async queue job. Returns immediately with job_id and initial status. Args: options: Submit options including model and inputs - - model: VideoModelDefinition from models.video() + - model: VideoModelDefinition from models.video(), or a custom ModelDefinition - prompt: Text instructions describing the requested edit - Additional model-specific inputs @@ -76,21 +74,13 @@ async def submit(self, options: dict[str, Any]) -> JobSubmitResponse: JobSubmitResponse with job_id and status Raises: - InvalidInputError: If inputs are invalid or model is not a video model + InvalidInputError: If inputs are invalid QueueSubmitError: If submission fails """ if "model" not in options: raise InvalidInputError("model is required") - model: VideoModelDefinition = options["model"] - - # Validate that this is a video model (check against registry) - if model.name not in _MODELS["video"]: - raise InvalidInputError( - f"Model '{model.name}' is not supported by queue API. " - f"Only video models support async queue processing. " - f"For image models, use client.process() instead." - ) + model: ModelDefinition[str] = options["model"] inputs = {k: v for k, v in options.items() if k not in ("model", "cancel_token")} @@ -101,22 +91,25 @@ async def submit(self, options: dict[str, Any]) -> JobSubmitResponse: file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS} non_file_inputs = {k: v for k, v in inputs.items() if k not in FILE_FIELDS} - # Validate non-file inputs - validation_inputs = { - **non_file_inputs, - **{k: b"" for k in file_inputs.keys()}, - } - - try: - validated_inputs = model.input_schema(**validation_inputs) - except ValidationError as e: - raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e - - # Build final inputs - processed_inputs = { - **validated_inputs.model_dump(exclude_none=True), - **file_inputs, - } + if model.input_schema is None: + processed_inputs = {k: v for k, v in inputs.items() if v is not None} + else: + # Validate non-file inputs + validation_inputs = { + **non_file_inputs, + **{k: b"" for k in file_inputs.keys()}, + } + + try: + validated_inputs = model.input_schema(**validation_inputs) + except ValidationError as e: + raise InvalidInputError(f"Invalid inputs for {model.name}: {str(e)}") from e + + # Build final inputs + processed_inputs = { + **validated_inputs.model_dump(exclude_none=True), + **file_inputs, + } session = await self._get_session() return await submit_job( diff --git a/decart/queue/request.py b/decart/queue/request.py index c242b62..3f30f70 100644 --- a/decart/queue/request.py +++ b/decart/queue/request.py @@ -12,13 +12,13 @@ async def submit_job( session: aiohttp.ClientSession, base_url: str, api_key: str, - model: ModelDefinition, + model: ModelDefinition[str], inputs: dict[str, Any], integration: Optional[str] = None, ) -> JobSubmitResponse: """Submit a job to the queue. - POST /v1/jobs/{model} + POST {model.url_path} """ form_data = aiohttp.FormData() @@ -30,7 +30,7 @@ async def submit_job( else: form_data.add_field(key, str(value)) - endpoint = f"{base_url}/v1/jobs/{model.name}" + endpoint = f"{base_url}{model.url_path}" async with session.post( endpoint, diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 93f1394..77ffa1d 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -19,7 +19,6 @@ ) from .types import ConnectionState, RealtimeConnectOptions from ..types import FileInput -from ..models import RealTimeModels from ..errors import DecartSDKError, InvalidInputError, WebRTCError from ..process.request import file_input_to_bytes @@ -111,7 +110,7 @@ async def connect( ws_url = f"{base_url}{options.model.url_path}" ws_url += f"?api_key={quote(api_key)}&model={quote(options.model.name)}" - model_name: RealTimeModels = options.model.name # type: ignore[assignment] + model_name: str = options.model.name is_avatar_live = model_name in ("live_avatar", "live-avatar") audio_stream_manager: Optional[AudioStreamManager] = None diff --git a/tests/test_models.py b/tests/test_models.py index dc040cf..e253d77 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,6 +1,6 @@ import warnings import pytest -from decart import models, DecartSDKError +from decart import models, DecartSDKError, ModelDefinition from decart.models import _warned_aliases @@ -226,6 +226,19 @@ def test_latest_aliases_no_deprecation_warning() -> None: assert len(w) == 0 +def test_custom_model_definition_allows_arbitrary_model_names() -> None: + model = ModelDefinition( + name="lucy_2_rt_preview", + url_path="/v1/stream", + fps=20, + width=1280, + height=720, + ) + + assert model.name == "lucy_2_rt_preview" + assert model.input_schema is None + + def test_invalid_model() -> None: with pytest.raises(DecartSDKError): models.video("invalid-model") diff --git a/tests/test_process.py b/tests/test_process.py index 5c47a57..a1253a7 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,13 +1,9 @@ -""" -Tests for the process API. -Note: process() only supports image models (i2i). -Video models must use the queue API. -""" +"""Tests for the process API.""" import pytest import asyncio from unittest.mock import AsyncMock, patch, MagicMock -from decart import DecartClient, models, DecartSDKError +from decart import DecartClient, ModelDefinition, models, DecartSDKError @pytest.mark.asyncio @@ -74,20 +70,38 @@ async def test_process_image_to_image_with_reference_image() -> None: @pytest.mark.asyncio -async def test_process_rejects_video_models() -> None: - """Test that process() rejects video models with helpful error message.""" +async def test_process_accepts_custom_model_definition_without_schema() -> None: client = DecartClient(api_key="test-key") + custom_model = ModelDefinition( + name="lucy_image_preview", + url_path="/v1/generate/lucy_image_preview", + fps=25, + width=1280, + height=704, + ) - with pytest.raises(DecartSDKError) as exc_info: - await client.process( + with patch("decart.client.send_request", new_callable=AsyncMock) as mock_send: + mock_send.return_value = b"fake image data" + + result = await client.process( { - "model": models.video("lucy-clip"), - "prompt": "Add cinematic teal-and-orange grading", + "model": custom_model, + "prompt": "Apply a preview model treatment", + "data": b"fake image data", + "custom_strength": 0.7, + "optional": None, } ) - assert "not supported by process()" in str(exc_info.value) - assert "queue" in str(exc_info.value).lower() + assert result == b"fake image data" + mock_send.assert_called_once() + call_kwargs = mock_send.call_args.kwargs + assert call_kwargs["model"] is custom_model + assert call_kwargs["inputs"] == { + "prompt": "Apply a preview model treatment", + "data": b"fake image data", + "custom_strength": 0.7, + } @pytest.mark.asyncio diff --git a/tests/test_queue.py b/tests/test_queue.py index 73280eb..b45ed0c 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -1,12 +1,14 @@ -""" -Tests for the queue API. -Note: queue API only supports video models. -Image models must use the process API. -""" +"""Tests for the queue API.""" import pytest from unittest.mock import AsyncMock, patch, MagicMock -from decart import DecartClient, models, DecartSDKError +from decart import ( + DecartClient, + ModelDefinition, + models, + DecartSDKError, + QueueSubmitError, +) @pytest.mark.asyncio @@ -53,20 +55,38 @@ async def test_queue_submit_video_to_video() -> None: @pytest.mark.asyncio -async def test_queue_rejects_image_models() -> None: - """Test that queue API rejects image models with helpful error message.""" +async def test_queue_submit_accepts_custom_model_definition_without_schema() -> None: client = DecartClient(api_key="test-key") + custom_model = ModelDefinition( + name="lucy_video_preview", + url_path="/v1/jobs/lucy_video_preview", + fps=20, + width=1280, + height=720, + ) - with pytest.raises(DecartSDKError) as exc_info: - await client.queue.submit( + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-custom", status="pending") + + job = await client.queue.submit( { - "model": models.image("lucy-image-2"), - "prompt": "Apply a painterly sunset color grade", + "model": custom_model, + "prompt": "Use the custom video model", + "data": b"fake video data", + "custom_strength": 0.7, + "optional": None, } ) - assert "not supported by queue" in str(exc_info.value) - assert "process" in str(exc_info.value).lower() + assert job.job_id == "job-custom" + mock_submit.assert_called_once() + call_kwargs = mock_submit.call_args.kwargs + assert call_kwargs["model"] is custom_model + assert call_kwargs["inputs"] == { + "prompt": "Use the custom video model", + "data": b"fake video data", + "custom_strength": 0.7, + } @pytest.mark.asyncio @@ -267,6 +287,38 @@ async def test_queue_includes_user_agent_header() -> None: assert "User-Agent" in headers assert headers["User-Agent"].startswith("decart-python-sdk/") + assert mock_session.post.call_args.args[0] == "https://api.decart.ai/v1/jobs/lucy-clip" + + +@pytest.mark.asyncio +async def test_queue_submit_surfaces_backend_error() -> None: + client = DecartClient(api_key="test-key") + + with patch("aiohttp.ClientSession") as mock_session_cls: + mock_response = MagicMock() + mock_response.ok = False + mock_response.status = 400 + mock_response.text = AsyncMock(return_value="unsupported model") + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock() + mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response) + mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_session_cls.return_value = mock_session + + with pytest.raises(QueueSubmitError) as exc_info: + await client.queue.submit( + { + "model": models.video("lucy-clip"), + "prompt": "Apply a cinematic grade", + "data": b"fake video data", + } + ) + + assert "unsupported model" in str(exc_info.value) # Tests for lucy-2.1 diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 49dc054..0486bcc 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -2,7 +2,7 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch -from decart import DecartClient, models +from decart import DecartClient, ModelDefinition, models try: from decart.realtime.client import RealtimeClient @@ -116,6 +116,54 @@ async def test_realtime_client_creation_with_mock(): assert realtime_client.subscribe_token is not None +@pytest.mark.asyncio +async def test_realtime_connect_accepts_custom_model_definition(): + """Custom realtime models can use arbitrary model names, matching the JS SDK escape hatch.""" + client = DecartClient(api_key="test-key") + custom_model = ModelDefinition( + name="lucy_2_rt_preview", + url_path="/v1/stream", + fps=20, + width=1280, + height=720, + ) + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.realtime_base_url, + api_key=client.api_key, + local_track=MagicMock(), + options=RealtimeConnectOptions( + model=custom_model, + on_remote_stream=lambda t: None, + ), + ) + + assert realtime_client is not None + call_args = mock_manager_class.call_args + config = call_args[0][0] if call_args[0] else call_args[1]["configuration"] + assert "model=lucy_2_rt_preview" in config.webrtc_url + assert config.model_name == "lucy_2_rt_preview" + assert config.fps == 20 + + await realtime_client.disconnect() + + @pytest.mark.asyncio async def test_realtime_set_prompt_with_mock(): """Test set_prompt with mocked WebRTC and prompt_ack"""