Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,5 @@ docs/distillation/examples/

dmd_t2v_output/
preprocess_output_text/

# Next.js generated files
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
66 changes: 57 additions & 9 deletions fastvideo/entrypoints/video_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,27 @@ class VideoGenerator:
customization options, similar to popular frameworks like HF Diffusers.
"""

def __init__(self, fastvideo_args: FastVideoArgs,
executor_class: type[Executor], log_stats: bool):
def __init__(
self,
fastvideo_args: FastVideoArgs,
executor_class: type[Executor],
log_stats: bool,
*,
log_queue=None,
):
"""
Initialize the video generator.

Args:
fastvideo_args: The inference arguments
executor_class: The executor class to use for inference
log_stats: Whether to log statistics
log_queue: Optional multiprocessing.Queue to forward worker logs to
"""
self.fastvideo_args = fastvideo_args
self.executor = executor_class(fastvideo_args)
self.executor = executor_class(
fastvideo_args, log_queue=log_queue
)

@classmethod
def from_pretrained(cls, model_path: str, **kwargs) -> "VideoGenerator":
Expand All @@ -83,19 +93,25 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "VideoGenerator":
"""
# If users also provide some kwargs, it will override the FastVideoArgs and PipelineConfig.
kwargs['model_path'] = model_path
log_queue = kwargs.pop("log_queue", None)
fastvideo_args = FastVideoArgs.from_kwargs(**kwargs)

return cls.from_fastvideo_args(fastvideo_args)
return cls.from_fastvideo_args(fastvideo_args, log_queue=log_queue)

@classmethod
def from_fastvideo_args(cls,
fastvideo_args: FastVideoArgs) -> "VideoGenerator":
def from_fastvideo_args(
cls,
fastvideo_args: FastVideoArgs,
*,
log_queue=None,
) -> "VideoGenerator":
"""
Create a video generator with the specified arguments.

Args:
fastvideo_args: The inference arguments

log_queue: Optional multiprocessing.Queue to forward worker logs to

Returns:
The created video generator
"""
Expand All @@ -107,6 +123,7 @@ def from_fastvideo_args(cls,
fastvideo_args=fastvideo_args,
executor_class=executor_class,
log_stats=False, # TODO: implement
log_queue=log_queue,
)

def generate_video(
Expand Down Expand Up @@ -144,6 +161,37 @@ def generate_video(
A metadata dictionary for single-prompt generation, or a list of
metadata dictionaries for prompt-file batch generation.
"""
log_queue = kwargs.pop("log_queue", None)
if log_queue is not None and hasattr(
self.executor, "set_log_queue"
):
self.executor.set_log_queue(log_queue)
try:
return self._generate_video_impl(
prompt=prompt,
sampling_param=sampling_param,
mouse_cond=mouse_cond,
keyboard_cond=keyboard_cond,
grid_sizes=grid_sizes,
**kwargs,
)
finally:
if log_queue is not None and hasattr(
self.executor, "clear_log_queue"
):
self.executor.clear_log_queue()

def _generate_video_impl(
self,
prompt: str | None = None,
sampling_param: SamplingParam | None = None,
mouse_cond: torch.Tensor | None = None,
keyboard_cond: torch.Tensor | None = None,
grid_sizes: tuple[int, int, int] | list[int] | torch.Tensor
| None = None,
**kwargs,
) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]]:
"""Internal implementation of generate_video."""
# Handle batch processing from text file
if sampling_param is None:
sampling_param = SamplingParam.from_pretrained(
Expand Down
10 changes: 10 additions & 0 deletions fastvideo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,11 +660,21 @@ def get_sampling_param_cls_for_name(pipeline_name_or_path: str) -> Any | None:

_register_configs()


def get_registered_model_paths() -> list[str]:
    """List every HuggingFace model path known to the registry, sorted.

    Handy for UIs and tooling that want to enumerate the supported models.
    """
    # Iterating the mapping directly yields its keys; sorted() builds a new list.
    return sorted(_MODEL_HF_PATH_TO_NAME)


__all__ = [
"ConfigInfo",
"ModelInfo",
"get_model_info",
"get_pipeline_config_cls_from_name",
"get_registered_model_paths",
"get_sampling_param_cls_for_name",
"get_pipeline_config_classes",
]
8 changes: 7 additions & 1 deletion fastvideo/worker/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,14 @@

class Executor(ABC):

def __init__(self, fastvideo_args: FastVideoArgs):
def __init__(
    self,
    fastvideo_args: FastVideoArgs,
    *,
    log_queue=None,
):
    """Record the configuration, then run subclass initialization.

    Args:
        fastvideo_args: Inference configuration for this executor.
        log_queue: Optional multiprocessing queue that worker log records
            are forwarded to; keyword-only so positional callers are
            unaffected.
    """
    # Attributes must be in place before _init_executor(), which is
    # presumably implemented by subclasses and may read them — TODO confirm.
    self.fastvideo_args = fastvideo_args
    self._log_queue = log_queue
    self._init_executor()

Expand Down
114 changes: 108 additions & 6 deletions fastvideo/worker/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from dataclasses import dataclass
from enum import Enum
import faulthandler
import logging
import logging.handlers
import multiprocessing as mp
from multiprocessing.connection import Connection
from multiprocessing.queues import Queue
Expand Down Expand Up @@ -36,6 +38,11 @@
logger = init_logger(__name__)


def _make_queue_log_handler(log_queue: Queue) -> logging.Handler:
"""Create a QueueHandler that forwards fastvideo logs to a multiprocessing queue."""
return logging.handlers.QueueHandler(log_queue)


class StreamingTaskType(str, Enum):
"""
Enumeration for different streaming task types.
Expand Down Expand Up @@ -100,6 +107,7 @@ def _init_executor(self) -> None:
distributed_init_method=distributed_init_method,
streaming_input_queue=self._streaming_input_queue,
streaming_output_queue=self._streaming_output_queue,
log_queue=self._log_queue,
))

# Workers must be created before wait_for_ready to avoid
Expand Down Expand Up @@ -269,6 +277,14 @@ def merge_lora_weights(self) -> None:
if response["status"] != "lora_adapter_merged":
raise RuntimeError(f"Worker {i} failed to merge LoRA weights")

def set_log_queue(self, log_queue: Queue | None) -> None:
    """Ask every worker to start forwarding its logs to *log_queue*.

    Invoke before generate_video; passing None disables forwarding on the
    worker side.
    """
    self.collective_rpc("set_log_queue", kwargs=dict(log_queue=log_queue))

def clear_log_queue(self) -> None:
    """Stop forwarding worker logs to the queue. Call after generate_video.

    Broadcast counterpart of set_log_queue(); safe to call even when no
    queue was ever installed (workers skip removal if no handler is set).
    """
    self.collective_rpc("clear_log_queue")

def collective_rpc(self,
method: str | Callable,
timeout: float | None = None,
Expand Down Expand Up @@ -327,6 +343,12 @@ def shutdown(self) -> None:
return # Prevent multiple shutdown calls

logger.info("Shutting down MultiprocExecutor...")

# Check if workers were initialized (they might not be if initialization failed)
if not hasattr(self, 'workers') or not self.workers:
logger.info("No workers to shut down.")
return

self.shutting_down = True

# First try gentle termination
Expand Down Expand Up @@ -462,11 +484,14 @@ def __init__(
pipe: Connection,
streaming_input_queue: Queue | None = None,
streaming_output_queue: Queue | None = None,
_initial_log_handler: logging.Handler | None = None,
**kwargs: Any,
):
self.rank = rank
self.pipe = pipe
self.streaming_input_queue = streaming_input_queue
self.streaming_output_queue = streaming_output_queue
self._initial_log_handler = _initial_log_handler
wrapper = WorkerWrapperBase(fastvideo_args=fastvideo_args,
rpc_rank=rank)

Expand Down Expand Up @@ -494,6 +519,7 @@ def make_worker_process(
distributed_init_method: str,
streaming_input_queue: Queue | None = None,
streaming_output_queue: Queue | None = None,
log_queue: Queue | None = None,
) -> UnreadyWorkerProcHandle:
context = get_mp_context()
executor_pipe, worker_pipe = context.Pipe(duplex=True)
Expand All @@ -508,6 +534,7 @@ def make_worker_process(
"ready_pipe": writer,
"streaming_input_queue": streaming_input_queue,
"streaming_output_queue": streaming_output_queue,
"log_queue": log_queue,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=WorkerMultiprocProc.worker_main,
Expand All @@ -524,6 +551,13 @@ def worker_main(*args, **kwargs):
""" Worker initialization and execution loops.
This runs a background process """

log_queue = kwargs.pop("log_queue", None)
# Add log handler before model loading so we capture fsdp_load, cuda, etc.
if log_queue is not None:
_handler = _make_queue_log_handler(log_queue)
logging.getLogger("fastvideo").addHandler(_handler)
kwargs["_initial_log_handler"] = _handler

# Signal handler used for graceful termination.
# SystemExit exception is only raised once to allow this and worker
# processes to terminate without error
Expand Down Expand Up @@ -559,9 +593,21 @@ def signal_handler(signum, frame):

worker.worker_busy_loop()

except Exception:
except Exception as exc:
if ready_pipe is not None:
logger.exception("WorkerMultiprocProc failed to start.")
# Send error status to parent before closing pipe
try:
traceback_str = get_exception_traceback()
ready_pipe.send({
"status": "ERROR",
"error": str(exc),
"traceback": traceback_str,
"rank": rank,
})
except Exception:
# If sending fails, at least log it
pass
else:
logger.exception("WorkerMultiprocProc failed.")

Expand All @@ -571,7 +617,8 @@ def signal_handler(signum, frame):
shutdown_requested = True
traceback = get_exception_traceback()
logger.error("Worker %d hit an exception: %s", rank, traceback)
parent_process.send_signal(signal.SIGQUIT)
if parent_process:
parent_process.send_signal(signal.SIGQUIT)

finally:
if ready_pipe is not None:
Expand All @@ -592,6 +639,7 @@ def wait_for_ready(
pipes = {handle.ready_pipe: handle for handle in unready_proc_handles}
ready_proc_handles: list[WorkerProcHandle
| None] = ([None] * len(unready_proc_handles))
worker_errors: list[str] = []
while pipes:
ready = mp.connection.wait(pipes.keys())
for pipe in ready:
Expand All @@ -600,20 +648,42 @@ def wait_for_ready(
# Wait until the WorkerProc is ready.
unready_proc_handle = pipes.pop(pipe)
response: dict[str, Any] = pipe.recv()
if response["status"] != "READY":
if response["status"] == "ERROR":
# Worker sent error details
error_msg = response.get("error", "Unknown error")
traceback_str = response.get("traceback", "")
rank = response.get("rank", "unknown")
error_info = f"Worker {rank} error: {error_msg}"
if traceback_str:
error_info += f"\n{traceback_str}"
worker_errors.append(error_info)
# Log a concise error message (full traceback will be in the exception)
logger.error("Worker %s initialization failed: %s", rank, error_msg)
# Continue to check other workers, but we'll fail at the end
elif response["status"] != "READY":
worker_errors.append(f"Worker returned unexpected status: {response.get('status', 'unknown')}")
raise e

ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle))
if response["status"] == "READY":
ready_proc_handles[unready_proc_handle.rank] = (
WorkerProcHandle.from_unready_handle(
unready_proc_handle))

except EOFError:
# Pipe closed without sending status - worker crashed
worker_errors.append(f"Worker process crashed (pipe closed unexpectedly)")
e.__suppress_context__ = True
raise e from None

finally:
# Close connection.
pipe.close()

# If any workers failed, raise exception with details
if worker_errors:
error_msg = "WorkerMultiprocProc initialization failed due to exceptions in background processes:\n"
error_msg += "\n".join(f" - {err}" for err in worker_errors)
raise Exception(error_msg) from None

logger.info("%d workers ready", len(ready_proc_handles))
return cast(list[WorkerProcHandle], ready_proc_handles)
Expand All @@ -637,6 +707,14 @@ def worker_busy_loop(self) -> None:
with contextlib.suppress(Exception):
self.pipe.send(response)
break
if method == "set_log_queue":
self._set_log_queue(kwargs.get("log_queue"))
self.pipe.send({"status": "ok"})
continue
if method == "clear_log_queue":
self._clear_log_queue()
self.pipe.send({"status": "ok"})
continue
if method == "start_streaming_queue_loop":
self.pipe.send(
{"status": "streaming_queue_loop_started"})
Expand Down Expand Up @@ -722,6 +800,30 @@ def streaming_queue_loop(self) -> None:
self.streaming_output_queue.put(
StreamingResult(task_type=StreamingTaskType.STEP, error=e))

_log_queue_handler: logging.Handler | None = None

def _set_log_queue(self, log_queue: Queue | None) -> None:
    """Route "fastvideo" log records to *log_queue* via a QueueHandler.

    Any previously installed queue handler is removed first, so passing
    None simply tears down forwarding.
    """
    self._clear_log_queue()
    if log_queue is None:
        return
    fv_logger = logging.getLogger("fastvideo")
    # The handler installed at process start (in worker_main) would
    # duplicate every record once the per-request handler is attached.
    initial = self._initial_log_handler
    if initial is not None:
        fv_logger.removeHandler(initial)
        self._initial_log_handler = None
    new_handler = _make_queue_log_handler(log_queue)
    self._log_queue_handler = new_handler
    fv_logger.addHandler(new_handler)

def _clear_log_queue(self) -> None:
    """Detach the queue-forwarding handler installed by _set_log_queue, if any."""
    handler = self._log_queue_handler
    if handler is not None:
        logging.getLogger("fastvideo").removeHandler(handler)
        self._log_queue_handler = None

@staticmethod
def setup_proc_title_and_log_prefix() -> None:
dp_size = get_dp_group().world_size
Expand Down
Loading