31 changes: 31 additions & 0 deletions README.md
@@ -17,6 +17,8 @@
[![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2026/03/25`**: Wan2.1 and Wan2.2 MagCache inference is now supported
- **`2026/03/25`**: LTX-2 Video Inference is now supported
- **`2026/01/29`**: Wan LoRA for inference is now supported
- **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported
@@ -49,6 +51,7 @@ MaxDiffusion supports
* ControlNet inference (Stable Diffusion 1.4 & SDXL).
* Dreambooth training support for Stable Diffusion 1.x, 2.x.
* LTX-Video text2vid, img2vid (inference).
* LTX-2 Video text2vid (inference).
* Wan2.1 text2vid (training and inference).
* Wan2.2 text2vid (inference).

@@ -73,6 +76,7 @@ MaxDiffusion supports
- [Inference](#inference)
- [Wan](#wan-models)
- [LTX-Video](#ltx-video)
- [LTX-2 Video](#ltx-2-video)
- [Flux](#flux)
- [Fused Attention for GPU](#fused-attention-for-gpu)
- [SDXL](#stable-diffusion-xl)
@@ -497,6 +501,33 @@ To generate images, run the following command:

Add the conditioning image path as `conditioning_media_paths` in the form `["IMAGE_PATH"]`, along with the other generation parameters, in the ltx_video.yml file. Then follow the same instructions as above.

## LTX-2 Video

Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

The following command will run LTX-2 T2V:

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_ltx2.py \
src/maxdiffusion/configs/ltx2_video.yml \
attention="flash" \
num_inference_steps=40 \
num_frames=121 \
width=768 \
height=512 \
per_device_batch_size=.125 \
ici_data_parallelism=2 \
ici_context_parallelism=4 \
run_name=ltx2-inference
```
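The fractional `per_device_batch_size=.125` above is a global-batch trick: the command targets an 8-device mesh (`ici_data_parallelism=2` × `ici_context_parallelism=4`), so one eighth of a sample per device works out to a single video per run. A quick sanity check of the arithmetic (plain Python, assuming the usual global-batch = per-device × device-count convention; not part of maxdiffusion):

```python
# Sanity check of the batch arithmetic in the command above.
# Assumption: global batch = per_device_batch_size * number of devices,
# the usual convention that makes fractional per-device batch sizes useful.
ici_data_parallelism = 2
ici_context_parallelism = 4
num_devices = ici_data_parallelism * ici_context_parallelism  # 8-device mesh
per_device_batch_size = 0.125
global_batch_size = int(per_device_batch_size * num_devices)  # one video per run
```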

## Wan Models

Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -328,6 +328,11 @@ flow_shift: 3.0
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
use_cfg_cache: False
use_magcache: False
magcache_thresh: 0.12
magcache_K: 2
retention_ratio: 0.2
mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189]

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
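The new config keys (`use_magcache`, `magcache_thresh`, `magcache_K`, `retention_ratio`, `mag_ratios_base`) parameterize a MagCache-style step-skipping schedule: the per-step magnitude ratios are calibrated offline, and a step can reuse a cached residual when the accumulated deviation from 1.0 stays small. A hypothetical sketch of how such a schedule could be derived from these knobs (illustrative only — `magcache_skip_schedule` is not the maxdiffusion implementation):

```python
import math

def magcache_skip_schedule(mag_ratios, thresh=0.12, K=2, retention_ratio=0.2):
    """Sketch of a MagCache-style skip schedule (assumed semantics).

    A step is skipped (its transformer pass replaced by a cached residual)
    when the accumulated deviation of the magnitude ratio from 1.0 stays
    below `thresh`, at most `K` steps are skipped in a row, and no step in
    the first `retention_ratio` fraction of the trajectory is ever skipped.
    """
    num_steps = len(mag_ratios)
    retained = math.ceil(retention_ratio * num_steps)  # always-computed prefix
    skips = []
    accumulated_err = 0.0
    consecutive = 0
    for step, ratio in enumerate(mag_ratios):
        err = abs(1.0 - ratio)
        if step >= retained and accumulated_err + err <= thresh and consecutive < K:
            accumulated_err += err
            consecutive += 1
            skips.append(True)      # reuse cached residual on this step
        else:
            accumulated_err = 0.0   # full forward pass resets the error budget
            consecutive = 0
            skips.append(False)
    return skips
```

With the config defaults (`thresh=0.12`, `K=2`, `retention_ratio=0.2`), early near-1.0 ratios yield frequent skips while the noisier final steps run in full.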
6 changes: 6 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
@@ -288,6 +288,12 @@ flow_shift: 5.0

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
use_cfg_cache: False
use_magcache: False
magcache_thresh: 0.12
magcache_K: 2
retention_ratio: 0.2
mag_ratios_base_720p: [1.0, 1.0, 0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768]
mag_ratios_base_480p: [1.0, 1.0, 0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616]

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
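Unlike the T2V config, the i2v config ships two calibration tables, `mag_ratios_base_720p` and `mag_ratios_base_480p`, so the pipeline presumably selects one based on the output resolution. A hypothetical selector (illustrative only — `select_mag_ratios` is not the maxdiffusion API):

```python
def select_mag_ratios(height, width, ratios_720p, ratios_480p):
    """Hypothetical helper: pick the magnitude-ratio table matching the
    output resolution, since the i2v config defines separate 720p and
    480p calibrations. The >= 720 cutoff is an assumption."""
    return ratios_720p if min(height, width) >= 720 else ratios_480p
```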
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
@@ -80,7 +80,7 @@ dataset_name: ''
train_split: 'train'
dataset_type: 'tfrecord'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 0.125
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
8 changes: 8 additions & 0 deletions src/maxdiffusion/generate_wan.py
Expand Up @@ -100,6 +100,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
use_magcache=config.use_magcache,
magcache_thresh=config.magcache_thresh,
magcache_K=config.magcache_K,
retention_ratio=config.retention_ratio,
)
elif model_key == WAN2_2:
return pipeline(
@@ -127,6 +131,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
use_cfg_cache=config.use_cfg_cache,
use_magcache=config.use_magcache,
magcache_thresh=config.magcache_thresh,
magcache_K=config.magcache_K,
retention_ratio=config.retention_ratio,
)
elif model_key == WAN2_2:
return pipeline(
103 changes: 63 additions & 40 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -593,8 +593,11 @@ def __call__(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
deterministic: bool = True,
rngs: nnx.Rngs = None,
) -> Union[jax.Array, Dict[str, jax.Array]]:
rngs: Optional[nnx.Rngs] = None,
skip_blocks: Optional[jax.Array] = None,
cached_residual: Optional[jax.Array] = None,
return_residual: bool = False,
) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]:
hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None))
batch_size, _, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
@@ -628,52 +631,69 @@
encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1)
encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype)

if self.scan_layers:

def scan_fn(carry, block):
hidden_states_carry, rngs_carry = carry
hidden_states = block(
hidden_states_carry,
encoder_hidden_states,
timestep_proj,
rotary_emb,
deterministic,
rngs_carry,
encoder_attention_mask,
)
new_carry = (hidden_states, rngs_carry)
return new_carry, None

rematted_block_forward = self.gradient_checkpoint.apply(
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
)
initial_carry = (hidden_states, rngs)
final_carry, _ = nnx.scan(
rematted_block_forward,
length=self.num_layers,
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
)(initial_carry, self.blocks)

hidden_states, _ = final_carry
else:
for block in self.blocks:
def _run_all_blocks(h):
if self.scan_layers:

def layer_forward(hidden_states):
return block(
hidden_states,
def scan_fn(carry, block):
hidden_states_carry, rngs_carry = carry
hidden_states = block(
hidden_states_carry,
encoder_hidden_states,
timestep_proj,
rotary_emb,
deterministic,
rngs,
encoder_attention_mask=encoder_attention_mask,
rngs_carry,
encoder_attention_mask,
)
new_carry = (hidden_states, rngs_carry)
return new_carry, None

rematted_layer_forward = self.gradient_checkpoint.apply(
layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
rematted_block_forward = self.gradient_checkpoint.apply(
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
)
hidden_states = rematted_layer_forward(hidden_states)
initial_carry = (h, rngs)
final_carry, _ = nnx.scan(
rematted_block_forward,
length=self.num_layers,
in_axes=(nnx.Carry, 0),
out_axes=(nnx.Carry, 0),
)(initial_carry, self.blocks)

h_out, _ = final_carry
else:
h_out = h
for block in self.blocks:

def layer_forward(hidden_states):
return block(
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
deterministic,
rngs,
encoder_attention_mask=encoder_attention_mask,
)

rematted_layer_forward = self.gradient_checkpoint.apply(
layer_forward,
self.names_which_can_be_saved,
self.names_which_can_be_offloaded,
prevent_cse=not self.scan_layers,
)
h_out = rematted_layer_forward(h_out)
return h_out

hidden_states_before_blocks = hidden_states

if skip_blocks:
if cached_residual is None:
raise ValueError("cached_residual must be provided when skip_blocks is True")
hidden_states = hidden_states + cached_residual
else:
hidden_states = _run_all_blocks(hidden_states)

residual_x = hidden_states - hidden_states_before_blocks

shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
@@ -685,4 +705,7 @@ def layer_forward(hidden_states):
)
hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6))
hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width)

if return_residual:
return hidden_states, residual_x
return hidden_states
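The new `skip_blocks` / `cached_residual` / `return_residual` plumbing lets a caller cache the residual added by the full block stack and replay it on skipped steps. A minimal sketch of such a driver loop, with illustrative names (not the maxdiffusion pipeline API):

```python
def denoise_with_residual_cache(run_blocks, x, skip_schedule):
    """Hypothetical driver mirroring the transformer interface above:
    a full step runs the block stack and caches the residual it added
    (return_residual=True); a skipped step reuses that cached residual
    instead of running the blocks (skip_blocks + cached_residual)."""
    cached_residual = None
    for skip in skip_schedule:
        if skip and cached_residual is not None:
            x = x + cached_residual      # cheap step: replay cached residual
        else:
            out = run_blocks(x)          # full pass through the block stack
            cached_residual = out - x    # residual the blocks contributed
            x = out
    return x
```

Note the first step can never be skipped usefully, since no residual has been cached yet — matching the ValueError the transformer raises when `skip_blocks` is set without a `cached_residual`.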