diff --git a/README.md b/README.md index 81b166ae..15667de4 100644 --- a/README.md +++ b/README.md @@ -14,22 +14,23 @@ limitations under the License. --> -[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) +[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? -- **`2026/1/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported +- **`2026/01/29`**: Wan LoRA for inference is now supported +- **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported. - **`2025/10/14`**: NVIDIA DGX Spark Flux support. -- **`2025/8/14`**: LTX-Video img2vid generation is now supported. -- **`2025/7/29`**: LTX-Video text2vid generation is now supported. +- **`2025/08/14`**: LTX-Video img2vid generation is now supported. +- **`2025/07/29`**: LTX-Video text2vid generation is now supported. - **`2025/04/17`**: Flux Finetuning. - **`2025/02/12`**: Flux LoRA for inference. - **`2025/02/08`**: Flux schnell & dev inference. - **`2024/12/12`**: Load multiple LoRAs for inference. - **`2024/10/22`**: LoRA support for Hyper SDXL. -- **`2024/8/1`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. -- **`2024/7/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported. +- **`2024/08/01`**: Orbax is the new default checkpointer. You can still use `pipeline.save_pretrained` after training to save in diffusers format. +- **`2024/07/20`**: Dreambooth training for Stable Diffusion 1.x,2.x is now supported. # Overview @@ -68,14 +69,15 @@ MaxDiffusion supports - [SD 1.4](#stable-diffusion-14-training) - [Dreambooth](#dreambooth) - [Inference](#inference) - - [Wan2.1](#wan21) - - [Wan2.2](#wan22) + - [Wan](#wan-models) - [LTX-Video](#ltx-video) - [Flux](#flux) - [Fused Attention for GPU](#fused-attention-for-gpu) - [SDXL](#stable-diffusion-xl) - [SD 2 base](#stable-diffusion-2-base) - [SD 2.1](#stable-diffusion-21) + - [Wan LoRA](#wan-lora) + - [Flux LoRA](#flux-lora) - [Hyper SDXL LoRA](#hyper-sdxl-lora) - [Load Multiple LoRA](#load-multiple-lora) - [SDXL Lightning](#sdxl-lightning) @@ -482,41 +484,48 @@ To generate images, run the following command: Add conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"] along with other generation parameters in the ltx_video.yml file. Then follow same instruction as above. - ## Wan2.1 + ## Wan Models Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - ### Text2Vid + Supports both Text2Vid and Img2Vid pipelines. - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_14b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - - ### Img2Vid - - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_14b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` - - ## Wan2.2 - - Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). - - ### Text2Vid + The following command will run Wan2.1 T2V: ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_27b.yml attention="flash" num_inference_steps=50 num_frames=81 width=1280 height=720 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=5.0 enable_profiler=True run_name=wan-inference-testing-720p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_reduce=true" \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_14b.yml \ + attention="flash" \ + num_inference_steps=50 \ + num_frames=81 \ + width=1280 \ + height=720 \ + jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \ + per_device_batch_size=.125 \ + ici_data_parallelism=2 \ + ici_context_parallelism=2 \ + flow_shift=5.0 \ + enable_profiler=True \ + run_name=wan-inference-testing-720p \ + output_dir=gs:/jfacevedo-maxdiffusion \ + fps=16 \ + flash_min_seq_length=0 \ + flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \ + seed=118445 ``` - ### Img2Vid + To run other Wan model inference pipelines, change the config file in the command above: - ```bash - HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ - LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_wan.py src/maxdiffusion/configs/base_wan_i2v_27b.yml attention="flash" num_inference_steps=30 num_frames=81 width=832 height=480 jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ per_device_batch_size=.125 ici_data_parallelism=2 ici_fsdp_parallelism=2 flow_shift=3.0 enable_profiler=True run_name=wan-i2v-inference-testing-480p output_dir=gs:/jfacevedo-maxdiffusion fps=16 flash_min_seq_length=0 flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' seed=118445 - ``` + * For Wan2.1 I2V, use `base_wan_i2v_14b.yml`. + * For Wan2.2 T2V, use `base_wan_27b.yml`. + * For Wan2.2 I2V, use `base_wan_i2v_27b.yml`. ## Flux @@ -568,6 +577,33 @@ To generate images, run the following command: ```bash NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt='A cute corgi lives in a house made out of sushi, anime' num_inference_steps=28 split_head_dim=True per_device_batch_size=1 attention="cudnn_flash_te" hardware=gpu ``` + ## Wan LoRA + + Disclaimer: not all LoRA formats have been tested. Currently supports ComfyUI and AI Toolkit formats. If there is a specific LoRA that doesn't load, please let us know. + + First create a copy of the relevant config file eg: `src/maxdiffusion/configs/base_wan_{*}.yml`. Update the prompt and LoRA details in the config. Make sure to set `enable_lora: True`. Then run the following command: + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_reduce=true" \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_i2v_14b.yml \ # --> Change to your copy + jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \ + per_device_batch_size=.125 \ + ici_data_parallelism=2 \ + ici_context_parallelism=2 \ + run_name=wan-lora-inference-testing-720p \ + output_dir=gs:/jfacevedo-maxdiffusion \ + seed=118445 \ + enable_lora=True \ + ``` + + Loading multiple LoRAs is supported as well. ## Flux LoRA diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index f8d2ff95..2a5b0338 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -276,8 +276,8 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." -prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." #"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." +prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. Appearing behind him is a giant, translucent, pink spiritual manifestation (faxiang) that is synchronized with the man's action and pose." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True height: 720 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 5aae2d5a..0bd6a27f 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -277,8 +277,8 @@ profiler_steps: 10 enable_jax_named_scopes: False # Generation parameters -prompt: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." -prompt_2: "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." +prompt_2: "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." #LoRA prompt "orbit 180 around an astronaut on the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" do_classifier_free_guidance: True height: 720