From 5bfc33f69e2b7abcfafd8daf7769f2d74905d34b Mon Sep 17 00:00:00 2001 From: Garrick Date: Thu, 19 Mar 2026 17:49:15 -0500 Subject: [PATCH] Add @torch.inference_mode() to pipeline __call__ methods (fixes #152) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All 7 pipeline classes were missing the decorator on __call__, causing torch to retain autograd graphs when called from Python (not CLI). This leads to OOM — the text encoder's ~37 GB of activations aren't freed before the transformer loads. Only ti2vid_two_stages_hq.py already had the decorator. The main() functions in each file had inference_mode, but __call__ did not — so CLI usage worked but Python API usage OOMed. Co-Authored-By: Claude Opus 4.6 (1M context) --- packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py | 1 + packages/ltx-pipelines/src/ltx_pipelines/distilled.py | 1 + packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py | 1 + .../ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py | 1 + packages/ltx-pipelines/src/ltx_pipelines/retake.py | 1 + packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py | 1 + packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py | 1 + 7 files changed, 7 insertions(+) diff --git a/packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py b/packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py index 8dcc8c4e..c206e4e1 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/a2vid_two_stage.py @@ -77,6 +77,7 @@ def __init__( device=device, ) + @torch.inference_mode() def __call__( # noqa: PLR0913 self, prompt: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/distilled.py b/packages/ltx-pipelines/src/ltx_pipelines/distilled.py index aa01e8d3..5e849549 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/distilled.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/distilled.py @@ -73,6 +73,7 @@ def __init__( device=device, ) + @torch.inference_mode() def __call__( self, prompt: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py b/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py index 2fa7e874..b244be29 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/ic_lora.py @@ -109,6 +109,7 @@ def __init__( ) self.reference_downscale_factor = scale + @torch.inference_mode() def __call__( # noqa: PLR0913 self, prompt: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py b/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py index 77d18ac8..600bae5a 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/keyframe_interpolation.py @@ -78,6 +78,7 @@ def __init__( device=device, ) + @torch.inference_mode() def __call__( # noqa: PLR0913 self, prompt: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/retake.py b/packages/ltx-pipelines/src/ltx_pipelines/retake.py index 3eef52f7..1d71bef9 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/retake.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/retake.py @@ -197,6 +197,7 @@ def __init__( # Public entry point # # --------------------------------------------------------------------- # + @torch.inference_mode() def __call__( # noqa: PLR0913, PLR0915 self, video_path: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py index df73c0d6..c5a85807 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_one_stage.py @@ -68,6 +68,7 @@ def __init__( device=device, ) + @torch.inference_mode() def __call__( # noqa: PLR0913 self, prompt: str, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py index b486d539..64e100d8 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py @@ -79,6 +79,7 @@ def __init__( device=device, ) + @torch.inference_mode() def __call__( # noqa: PLR0913 self, prompt: str,