From 7e92510e092d938b9a728174b1600e9172dbf5ab Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 25 Mar 2026 14:00:58 +0530 Subject: [PATCH 1/3] WAN 2.1 MagCache support --- README.md | 31 +++ src/maxdiffusion/configs/base_wan_14b.yml | 5 + src/maxdiffusion/configs/base_wan_i2v_14b.yml | 6 + src/maxdiffusion/configs/ltx2_video.yml | 2 +- src/maxdiffusion/generate_wan.py | 8 + .../wan/transformers/transformer_wan.py | 102 ++++---- .../pipelines/wan/wan_pipeline.py | 150 +++++++++++- .../pipelines/wan/wan_pipeline_2_1.py | 166 +++++++++---- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 100 +++++++- src/maxdiffusion/tests/wan_magcache_test.py | 227 ++++++++++++++++++ 10 files changed, 698 insertions(+), 99 deletions(-) create mode 100644 src/maxdiffusion/tests/wan_magcache_test.py diff --git a/README.md b/README.md index 9319cbeee..e091e06be 100755 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ [![Unit Tests](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2026/03/25`**: Wan2.1 and Wan2.2 Magcache inference is now supported +- **`2026/03/25`**: LTX-2 Video Inference is now supported - **`2026/01/29`**: Wan LoRA for inference is now supported - **`2026/01/15`**: Wan2.1 and Wan2.2 Img2vid generation is now supported - **`2025/11/11`**: Wan2.2 txt2vid generation is now supported @@ -49,6 +51,7 @@ MaxDiffusion supports * ControlNet inference (Stable Diffusion 1.4 & SDXL). * Dreambooth training support for Stable Diffusion 1.x,2.x. * LTX-Video text2vid, img2vid (inference). +* LTX-2 Video text2vid (inference). * Wan2.1 text2vid (training and inference). * Wan2.2 text2vid (inference). 
@@ -73,6 +76,7 @@ MaxDiffusion supports - [Inference](#inference) - [Wan](#wan-models) - [LTX-Video](#ltx-video) + - [LTX-2 Video](#ltx-2-video) - [Flux](#flux) - [Fused Attention for GPU](#fused-attention-for-gpu) - [SDXL](#stable-diffusion-xl) @@ -497,6 +501,33 @@ To generate images, run the following command: Add conditioning image path as conditioning_media_paths in the form of ["IMAGE_PATH"] along with other generation parameters in the ltx_video.yml file. Then follow same instruction as above. + ## LTX-2 Video + + Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). + + The following command will run LTX-2 T2V: + + ```bash + HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \ + LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \ + --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \ + --xla_tpu_enable_async_collective_fusion_multiple_steps=true \ + --xla_tpu_overlap_compute_collective_tc=true \ + --xla_enable_async_all_reduce=true" \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + python src/maxdiffusion/generate_ltx2.py \ + src/maxdiffusion/configs/ltx2_video.yml \ + attention="flash" \ + num_inference_steps=40 \ + num_frames=121 \ + width=768 \ + height=512 \ + per_device_batch_size=.125 \ + ici_data_parallelism=2 \ + ici_context_parallelism=4 \ + run_name=ltx2-inference + ``` + ## Wan Models Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage). 
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index fa6309610..558708ffa 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -328,6 +328,11 @@ flow_shift: 3.0 # Skips the unconditional forward pass on ~35% of steps via residual compensation. # See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2 use_cfg_cache: False +use_magcache: False +magcache_thresh: 0.12 +magcache_K: 2 +retention_ratio: 0.2 +mag_ratios_base: [1.0, 1.0, 1.02504, 1.03017, 1.00025, 1.00251, 0.9985, 0.99962, 0.99779, 0.99771, 0.9966, 0.99658, 0.99482, 0.99476, 0.99467, 0.99451, 0.99664, 0.99656, 0.99434, 0.99431, 0.99533, 0.99545, 0.99468, 0.99465, 0.99438, 0.99434, 0.99516, 0.99517, 0.99384, 0.9938, 0.99404, 0.99401, 0.99517, 0.99516, 0.99409, 0.99408, 0.99428, 0.99426, 0.99347, 0.99343, 0.99418, 0.99416, 0.99271, 0.99269, 0.99313, 0.99311, 0.99215, 0.99215, 0.99218, 0.99215, 0.99216, 0.99217, 0.99163, 0.99161, 0.99138, 0.99135, 0.98982, 0.9898, 0.98996, 0.98995, 0.9887, 0.98866, 0.98772, 0.9877, 0.98767, 0.98765, 0.98573, 0.9857, 0.98501, 0.98498, 0.9838, 0.98376, 0.98177, 0.98173, 0.98037, 0.98035, 0.97678, 0.97677, 0.97546, 0.97543, 0.97184, 0.97183, 0.96711, 0.96708, 0.96349, 0.96345, 0.95629, 0.95625, 0.94926, 0.94929, 0.93964, 0.93961, 0.92511, 0.92504, 0.90693, 0.90678, 0.8796, 0.87945, 0.86111, 0.86189] # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index d0c1a0140..e2170293e 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -288,6 +288,12 @@ flow_shift: 5.0 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) use_cfg_cache: False +use_magcache: False +magcache_thresh: 0.12 +magcache_K: 2 +retention_ratio: 0.2 +mag_ratios_base_720p: [1.0, 1.0, 0.99428, 0.99498, 0.98588, 0.98621, 0.98273, 0.98281, 0.99018, 0.99023, 0.98911, 0.98917, 0.98646, 0.98652, 0.99454, 0.99456, 0.9891, 0.98909, 0.99124, 0.99127, 0.99102, 0.99103, 0.99215, 0.99212, 0.99515, 0.99515, 0.99576, 0.99572, 0.99068, 0.99072, 0.99097, 0.99097, 0.99166, 0.99169, 0.99041, 0.99042, 0.99201, 0.99198, 0.99101, 0.99101, 0.98599, 0.98603, 0.98845, 0.98844, 0.98848, 0.98851, 0.98862, 0.98857, 0.98718, 0.98719, 0.98497, 0.98497, 0.98264, 0.98263, 0.98389, 0.98393, 0.97938, 0.9794, 0.97535, 0.97536, 0.97498, 0.97499, 0.973, 0.97301, 0.96827, 0.96828, 0.96261, 0.96263, 0.95335, 0.9534, 0.94649, 0.94655, 0.93397, 0.93414, 0.91636, 0.9165, 0.89088, 0.89109, 0.8679, 0.86768] +mag_ratios_base_480p: [1.0, 1.0, 0.98783, 0.98993, 0.97559, 0.97593, 0.98311, 0.98319, 0.98202, 0.98225, 0.9888, 0.98878, 0.98762, 0.98759, 0.98957, 0.98971, 0.99052, 0.99043, 0.99383, 0.99384, 0.98857, 0.9886, 0.99065, 0.99068, 0.98845, 0.98847, 0.99057, 0.99057, 0.98957, 0.98961, 0.98601, 0.9861, 0.98823, 0.98823, 0.98756, 0.98759, 0.98808, 0.98814, 0.98721, 0.98724, 0.98571, 0.98572, 0.98543, 0.98544, 0.98157, 0.98165, 0.98411, 0.98413, 0.97952, 0.97953, 0.98149, 0.9815, 0.9774, 0.97742, 0.97825, 0.97826, 0.97355, 0.97361, 0.97085, 0.97087, 0.97056, 0.97055, 0.96588, 0.96587, 0.96113, 0.96124, 0.9567, 0.95681, 0.94961, 0.94969, 0.93973, 0.93988, 0.93217, 0.93224, 0.91878, 0.91896, 0.90955, 0.90954, 0.92617, 0.92616] # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 5dff87449..5a9ef5210 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -80,7 +80,7 @@ dataset_name: '' train_split: 'train' dataset_type: 'tfrecord' cache_latents_text_encoder_outputs: True -per_device_batch_size: 0.125 +per_device_batch_size: 1 compile_topology_num_slices: -1 quantization_local_shard_count: -1 use_qwix_quantization: False diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 56c5a8a07..3cbfb60ee 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -100,6 +100,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, + use_magcache=config.use_magcache, + magcache_thresh=config.magcache_thresh, + magcache_K=config.magcache_K, + retention_ratio=config.retention_ratio, ) elif model_key == WAN2_2: return pipeline( @@ -127,6 +131,10 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, use_cfg_cache=config.use_cfg_cache, + use_magcache=config.use_magcache, + magcache_thresh=config.magcache_thresh, + magcache_K=config.magcache_K, + retention_ratio=config.retention_ratio, ) elif model_key == WAN2_2: return pipeline( diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index e701ab92c..934167ad7 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -593,8 +593,11 @@ def __call__( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, deterministic: bool = True, - rngs: 
nnx.Rngs = None, - ) -> Union[jax.Array, Dict[str, jax.Array]]: + rngs: Optional[nnx.Rngs] = None, + skip_blocks: Optional[jax.Array] = None, + cached_residual: Optional[jax.Array] = None, + return_residual: bool = False, + ) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]: hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) batch_size, _, num_frames, height, width = hidden_states.shape p_t, p_h, p_w = self.config.patch_size @@ -628,52 +631,66 @@ def __call__( encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) - if self.scan_layers: - - def scan_fn(carry, block): - hidden_states_carry, rngs_carry = carry - hidden_states = block( - hidden_states_carry, - encoder_hidden_states, - timestep_proj, - rotary_emb, - deterministic, - rngs_carry, - encoder_attention_mask, - ) - new_carry = (hidden_states, rngs_carry) - return new_carry, None - - rematted_block_forward = self.gradient_checkpoint.apply( - scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers - ) - initial_carry = (hidden_states, rngs) - final_carry, _ = nnx.scan( - rematted_block_forward, - length=self.num_layers, - in_axes=(nnx.Carry, 0), - out_axes=(nnx.Carry, 0), - )(initial_carry, self.blocks) - - hidden_states, _ = final_carry - else: - for block in self.blocks: + def _run_all_blocks(h): + if self.scan_layers: - def layer_forward(hidden_states): - return block( - hidden_states, + def scan_fn(carry, block): + hidden_states_carry, rngs_carry = carry + hidden_states = block( + hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, - rngs, - encoder_attention_mask=encoder_attention_mask, + rngs_carry, + encoder_attention_mask, ) + new_carry = (hidden_states, rngs_carry) + return new_carry, None + + rematted_block_forward = self.gradient_checkpoint.apply( + 
scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + ) + initial_carry = (h, rngs) + final_carry, _ = nnx.scan( + rematted_block_forward, + length=self.num_layers, + in_axes=(nnx.Carry, 0), + out_axes=(nnx.Carry, 0), + )(initial_carry, self.blocks) + + h_out, _ = final_carry + else: + h_out = h + for block in self.blocks: + + def layer_forward(hidden_states): + return block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs, + encoder_attention_mask=encoder_attention_mask, + ) + + rematted_layer_forward = self.gradient_checkpoint.apply( + layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + ) + h_out = rematted_layer_forward(h_out) + return h_out - rematted_layer_forward = self.gradient_checkpoint.apply( - layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers - ) - hidden_states = rematted_layer_forward(hidden_states) + hidden_states_before_blocks = hidden_states + + if skip_blocks: + if cached_residual is None: + raise ValueError("cached_residual must be provided when skip_blocks is True") + hidden_states = hidden_states + cached_residual + else: + hidden_states = _run_all_blocks(hidden_states) + + residual_x = hidden_states - hidden_states_before_blocks shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) @@ -685,4 +702,7 @@ def layer_forward(hidden_states): ) hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if return_residual: + return hidden_states, residual_x return hidden_states diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py 
import numpy as np


def nearest_interp(src, target_len):
  """Resample ``src`` to ``target_len`` entries via nearest-neighbor interpolation.

  Args:
    src: 1-D sequence or ndarray of calibrated magnitude ratios.
    target_len: Desired output length; must be >= 1.

  Returns:
    np.ndarray of length ``target_len`` sampled from ``src``.
  """
  # np.asarray lets plain Python lists work too (fancy indexing below
  # requires an ndarray); it is a no-op for ndarray input.
  src = np.asarray(src)
  if target_len == 1:
    # Degenerate case: keep the last (final-step) ratio.
    return np.array([src[-1]])
  indices = np.round(np.linspace(0, len(src) - 1, target_len)).astype(np.int32)
  return src[indices]


def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base):
  """Initialize MagCache accumulator state and interpolate magnitude ratios.

  ``mag_ratios_base`` holds interleaved (cond, uncond) calibration pairs.
  When its length does not match ``2 * num_inference_steps``, the cond and
  uncond streams are resampled independently with nearest-neighbor
  interpolation and re-interleaved so that ``mag_ratios[2*s]`` and
  ``mag_ratios[2*s + 1]`` correspond to inference step ``s``.

  Args:
    num_inference_steps: Number of denoising steps for this run.
    retention_ratio: Fraction of initial steps that always run the full
      model (warmup window, never skipped).
    mag_ratios_base: Interleaved (cond, uncond) calibration ratios.

  Returns:
    Tuple of
    ``(accumulated_ratio_cond, accumulated_ratio_uncond,
       accumulated_err_cond, accumulated_err_uncond,
       accumulated_steps_cond, accumulated_steps_uncond,
       cached_residual, skip_warmup, mag_ratios)``.

  Raises:
    ValueError: If ``mag_ratios_base`` is None (e.g. the model config does
      not define calibration ratios).
  """
  if mag_ratios_base is None:
    raise ValueError("mag_ratios_base must be provided when MagCache is enabled.")

  # Steps inside the warmup window always execute all transformer blocks.
  skip_warmup = int(num_inference_steps * retention_ratio)

  mag_ratios_base = np.asarray(mag_ratios_base)
  if len(mag_ratios_base) != num_inference_steps * 2:
    # Resample cond (even indices) and uncond (odd indices) separately,
    # then re-interleave to keep the (cond, uncond) pairing per step.
    mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps)
    mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps)
    mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1)
  else:
    mag_ratios = mag_ratios_base

  return (
      1.0,  # accumulated_ratio_cond
      1.0,  # accumulated_ratio_uncond
      0.0,  # accumulated_err_cond
      0.0,  # accumulated_err_uncond
      0,  # accumulated_steps_cond
      0,  # accumulated_steps_uncond
      None,  # cached_residual (populated after the first full forward pass)
      skip_warmup,
      mag_ratios,
  )


def magcache_step(
    step,
    mag_ratios,
    accumulated_state,
    magcache_thresh,
    magcache_K,
    skip_warmup=0,
    use_magcache=None,
):
  """Advance MagCache accumulators and decide whether this step may skip blocks.

  A step is skippable when the accumulated magnitude-ratio error (for both
  the conditional and unconditional streams) stays below ``magcache_thresh``
  and fewer than ``magcache_K`` consecutive steps have already been skipped.
  On a non-skip step all accumulators reset, so errors never compound across
  a full forward pass.

  Args:
    step: Current inference step index.
    mag_ratios: Interpolated ratios; ``mag_ratios[2*step]`` is cond,
      ``mag_ratios[2*step + 1]`` is uncond.
    accumulated_state: 6-tuple of (ratio_cond, ratio_uncond, err_cond,
      err_uncond, steps_cond, steps_uncond).
    magcache_thresh: Accumulated-error threshold.
    magcache_K: Maximum number of consecutive skipped steps.
    skip_warmup: Steps below this index never skip (unless overridden).
    use_magcache: Optional manual override; when None, caching is eligible
      exactly when ``step >= skip_warmup``. An explicit True/False bypasses
      the warmup check.

  Returns:
    Tuple ``(skip_blocks, new_accumulated_state)`` where ``skip_blocks`` is
    a Python bool and ``new_accumulated_state`` mirrors ``accumulated_state``.
  """
  (
      accumulated_ratio_cond,
      accumulated_ratio_uncond,
      accumulated_err_cond,
      accumulated_err_uncond,
      accumulated_steps_cond,
      accumulated_steps_uncond,
  ) = accumulated_state

  cur_mag_ratio_cond = mag_ratios[step * 2]
  cur_mag_ratio_uncond = mag_ratios[step * 2 + 1]

  if use_magcache is None:
    # Default policy: caching becomes eligible only after the warmup window.
    use_magcache = step >= skip_warmup

  skip_blocks = False
  if use_magcache:
    new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond
    new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond

    # Error is the drift of the accumulated ratio away from 1.0.
    err_cond = np.abs(1.0 - new_ratio_cond)
    err_uncond = np.abs(1.0 - new_ratio_uncond)

    if (
        accumulated_err_cond + err_cond < magcache_thresh
        and accumulated_steps_cond < magcache_K
        and accumulated_err_uncond + err_uncond < magcache_thresh
        and accumulated_steps_uncond < magcache_K
    ):
      skip_blocks = True
      accumulated_ratio_cond = new_ratio_cond
      accumulated_ratio_uncond = new_ratio_uncond
      accumulated_err_cond += err_cond
      accumulated_err_uncond += err_uncond
      accumulated_steps_cond += 1
      accumulated_steps_uncond += 1
    else:
      # Full forward pass this step: reset so errors do not compound.
      accumulated_ratio_cond = 1.0
      accumulated_ratio_uncond = 1.0
      accumulated_err_cond = 0.0
      accumulated_err_uncond = 0.0
      accumulated_steps_cond = 0
      accumulated_steps_uncond = 0

  new_state = (
      accumulated_ratio_cond,
      accumulated_ratio_uncond,
      accumulated_err_cond,
      accumulated_err_uncond,
      accumulated_steps_cond,
      accumulated_steps_uncond,
  )
  return skip_blocks, new_state
the License. -from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step from ...models.wan.transformers.transformer_wan import WanModel -from typing import List, Union, Optional +from typing import List, Union, Optional, Any from ...pyconfig import HyperParameters from functools import partial from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp +import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler @@ -91,7 +92,18 @@ def __call__( negative_prompt_embeds: Optional[jax.Array] = None, vae_only: bool = False, use_cfg_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: Optional[float] = None, + magcache_K: Optional[int] = None, + retention_ratio: Optional[float] = None, ): + if magcache_thresh is None: + magcache_thresh = getattr(self.config, "magcache_thresh", 0.12) + if magcache_K is None: + magcache_K = getattr(self.config, "magcache_K", 2) + if retention_ratio is None: + retention_ratio = getattr(self.config, "retention_ratio", 0.2) + if use_cfg_cache and guidance_scale <= 1.0: raise ValueError( f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). 
" @@ -122,7 +134,12 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state, use_cfg_cache=use_cfg_cache, + use_magcache=use_magcache, + magcache_thresh=magcache_thresh, + magcache_K=magcache_K, + retention_ratio=retention_ratio, height=height, + mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -150,7 +167,12 @@ def run_inference_2_1( scheduler: FlaxUniPCMultistepScheduler, scheduler_state, use_cfg_cache: bool = False, + use_magcache: bool = False, + magcache_thresh: float = 0.12, + magcache_K: int = 2, + retention_ratio: float = 0.2, height: int = 480, + mag_ratios_base: Optional[List[float]] = None, ): """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. @@ -222,55 +244,113 @@ def run_inference_2_1( cached_noise_cond = None cached_noise_uncond = None - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - is_cache_step = step_is_cache[step] + if use_magcache and do_cfg: + ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + cached_residual, + skip_warmup, + mag_ratios, + ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) - if is_cache_step: - # ── Cache step: cond-only forward + FFT frequency compensation ── - w1, w2 = step_w1w2[step] - timestep = jnp.broadcast_to(t, bsz) - noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_cond_embeds, - cached_noise_cond, - cached_noise_uncond, - guidance_scale=guidance_scale, - w1=jnp.float32(w1), - w2=jnp.float32(w2), - ) + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) - elif 
do_cfg: - # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── - latents_doubled = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( - graphdef, - sharded_state, - rest_of_state, - latents_doubled, - timestep, - prompt_embeds_combined, - guidance_scale=guidance_scale, + accumulated_state = ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) + skip_blocks, accumulated_state = magcache_step( + step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup ) + ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) = accumulated_state - else: - # ── No CFG (guidance_scale <= 1.0) ── - timestep = jnp.broadcast_to(t, bsz) - noise_pred, latents = transformer_forward_pass( + outputs = transformer_forward_pass( graphdef, sharded_state, rest_of_state, - latents, + jnp.concatenate([latents] * 2) if do_cfg else latents, timestep, - prompt_cond_embeds, - do_classifier_free_guidance=False, + prompt_embeds_combined if do_cfg else prompt_cond_embeds, + do_classifier_free_guidance=do_cfg, guidance_scale=guidance_scale, + skip_blocks=bool(skip_blocks), + cached_residual=cached_residual, + return_residual=True, ) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents + noise_pred, latents_returned, residual_x_cur = outputs + + if not skip_blocks: + cached_residual = residual_x_cur + + latents = latents_returned + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents + + else: + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + is_cache_step = 
step_is_cache[step] + + if is_cache_step: + w1, w2 = step_w1w2[step] + timestep = jnp.broadcast_to(t, bsz) + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + cached_noise_cond, + cached_noise_uncond, + guidance_scale=guidance_scale, + w1=jnp.float32(w1), + w2=jnp.float32(w2), + ) + + elif do_cfg: + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + graphdef, + sharded_state, + rest_of_state, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + + else: + timestep = jnp.broadcast_to(t, bsz) + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + do_classifier_free_guidance=False, + guidance_scale=guidance_scale, + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0622ec79b..4bc85856e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -14,12 +14,13 @@ from maxdiffusion import max_logging from maxdiffusion.image_processor import PipelineImageInput -from .wan_pipeline import WanPipeline, transformer_forward_pass +from .wan_pipeline import WanPipeline, transformer_forward_pass, nearest_interp, init_magcache, magcache_step from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional, Tuple from ...pyconfig import HyperParameters from functools import partial from flax import nnx +import numpy as np from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp 
@@ -149,7 +150,18 @@ def __call__( last_image: Optional[PipelineImageInput] = None, output_type: Optional[str] = "np", rng: Optional[jax.Array] = None, + use_magcache: bool = False, + magcache_thresh: Optional[float] = None, + magcache_K: Optional[int] = None, + retention_ratio: Optional[float] = None, ): + if magcache_thresh is None: + magcache_thresh = getattr(self.config, "magcache_thresh", 0.04) + if magcache_K is None: + magcache_K = getattr(self.config, "magcache_K", 2) + if retention_ratio is None: + retention_ratio = getattr(self.config, "retention_ratio", 0.2) + height = height or self.config.height width = width or self.config.width num_frames = num_frames or self.config.num_frames @@ -232,6 +244,12 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, scheduler=self.scheduler, + use_magcache=use_magcache, + magcache_thresh=magcache_thresh, + magcache_K=magcache_K, + retention_ratio=retention_ratio, + height=height, + mag_ratios_base=self.config.mag_ratios_base_720p if height >= 720 else self.config.mag_ratios_base_480p, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -264,33 +282,91 @@ def run_inference_2_1_i2v( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, + use_magcache: bool = False, + magcache_thresh: float = 0.04, + magcache_K: int = 2, + retention_ratio: float = 0.2, + height: int = 480, + mag_ratios_base: Optional[List[float]] = None, ): - do_classifier_free_guidance = guidance_scale > 1.0 + do_cfg = guidance_scale > 1.0 + + if use_magcache and do_cfg: + ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + cached_residual, + skip_warmup, + mag_ratios, + ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + + if do_cfg: + prompt_embeds_combined = 
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) + condition_combined = jnp.concatenate([condition] * 2) + else: + prompt_embeds_combined = prompt_embeds + image_embeds_combined = image_embeds + condition_combined = condition - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) - condition = jnp.concatenate([condition] * 2) for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + + skip_blocks = False + if use_magcache and do_cfg: + accumulated_state = ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) + skip_blocks, accumulated_state = magcache_step( + step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup + ) + ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) = accumulated_state + latents_input = latents - if do_classifier_free_guidance: + if do_cfg: latents_input = jnp.concatenate([latents, latents], axis=0) - latent_model_input = jnp.concatenate([latents_input, condition], axis=-1) + latent_model_input = jnp.concatenate([latents_input, condition_combined], axis=-1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) - noise_pred, _ = transformer_forward_pass( + + outputs = transformer_forward_pass( graphdef, sharded_state, rest_of_state, latent_model_input, timestep, - prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, + prompt_embeds_combined, + do_classifier_free_guidance=do_cfg, guidance_scale=guidance_scale, - 
encoder_hidden_states_image=image_embeds, + encoder_hidden_states_image=image_embeds_combined, + skip_blocks=bool(skip_blocks) if use_magcache and do_cfg else None, + cached_residual=cached_residual if use_magcache and do_cfg else None, + return_residual=True if use_magcache and do_cfg else False, ) + if use_magcache and do_cfg: + noise_pred, _, residual_x_cur = outputs + if not skip_blocks: + cached_residual = residual_x_cur + else: + noise_pred, _ = outputs + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False) return latents diff --git a/src/maxdiffusion/tests/wan_magcache_test.py b/src/maxdiffusion/tests/wan_magcache_test.py new file mode 100644 index 000000000..6d6c81c76 --- /dev/null +++ b/src/maxdiffusion/tests/wan_magcache_test.py @@ -0,0 +1,227 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import os +import time +import unittest + +import numpy as np +import pytest + +from maxdiffusion import pyconfig +from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1 +from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p1 import WanCheckpointerI2V_2_1 +from maxdiffusion.utils.loading_utils import load_image +import jax + +try: + jax.distributed.initialize() +except Exception: + pass + +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def calculate_metrics(v1_baseline, v2_cached): + """Helper to calculate Speedup, PSNR and SSIM between baseline and cached videos.""" + num_videos = len(v1_baseline) + all_psnr = [] + all_ssim = [] + + for i in range(num_videos): + v1 = np.array(v1_baseline[i], dtype=np.float64) + v2 = np.array(v2_cached[i], dtype=np.float64) + + # PSNR + mse = np.mean((v1 - v2) ** 2) + psnr = 10.0 * np.log10(1.0 / mse) if mse > 0 else float("inf") + all_psnr.append(psnr) + + # SSIM (per-frame) + C1, C2 = 0.01**2, 0.03**2 + ssim_scores = [] + for f in range(v1.shape[0]): + mu1, mu2 = np.mean(v1[f]), np.mean(v2[f]) + sigma1_sq, sigma2_sq = np.var(v1[f]), np.var(v2[f]) + sigma12 = np.mean((v1[f] - mu1) * (v2[f] - mu2)) + ssim = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)) + ssim_scores.append(float(ssim)) + + mean_ssim = np.mean(ssim_scores) + all_ssim.append(mean_ssim) + + avg_psnr = np.mean(all_psnr) + avg_ssim = np.mean(all_ssim) + print(f"PSNR (avg of {num_videos} videos): {avg_psnr:.2f} dB") + print(f"SSIM (avg of {num_videos} videos): mean={avg_ssim:.4f}") + return avg_psnr, avg_ssim + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class Wan21T2VMagCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: MagCache for Wan 2.1 T2V 14B.""" + + @classmethod + def setUpClass(cls): + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", 
"configs", "base_wan_14b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "num_frames=81", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=2", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=4", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q":3024,"block_kv_compute":1024,"block_kv":2048,"block_q_dkv":3024,"block_kv_dkv":2048,"block_kv_dkv_compute":1024,"use_fused_bwd_kernel":true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointer2_1(cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + print("Warming up paths...") + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + use_magcache=use_cache, + ) + + def _run_pipeline(self, use_magcache): + t0 = time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + use_magcache=use_magcache, + ) + return videos, time.perf_counter() - t0 + + def test_magcache_speedup_and_fidelity(self): + videos_baseline, t_baseline = self._run_pipeline(use_magcache=False) + videos_cached, t_cached = self._run_pipeline(use_magcache=True) + + speedup = t_baseline / t_cached + print(f"[Wan 2.1 T2V 14B] Baseline: {t_baseline:.2f}s, MagCache: {t_cached:.2f}s, Speedup: 
{speedup:.3f}x") + psnr, ssim = calculate_metrics(videos_baseline, videos_cached) + + self.assertGreaterEqual(ssim, 0.98) + self.assertGreater(speedup, 1.0) + self.assertGreaterEqual(psnr, 30.0) + + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Requires TPU v7-8 and model weights") +class Wan21I2VMagCacheSmokeTest(unittest.TestCase): + """End-to-end smoke test: MagCache for Wan 2.1 I2V 14B.""" + + @classmethod + def setUpClass(cls): + + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_i2v_14b.yml"), + "num_inference_steps=50", + "height=720", + "width=1280", + "pretrained_model_name_or_path=Wan-AI/Wan2.1-I2V-14B-720P-Diffusers", + "num_frames=81", + "seed=11234567893", + "attention=flash", + "remat_policy=FULL", + "allow_split_physical_axes=True", + "weights_dtype=bfloat16", + "activations_dtype=bfloat16", + "per_device_batch_size=0.125", + "ici_data_parallelism=2", + "ici_fsdp_parallelism=1", + "ici_context_parallelism=4", + "ici_tensor_parallelism=1", + "flash_min_seq_length=0", + 'flash_block_sizes={"block_q":3024,"block_kv_compute":1024,"block_kv":2048,"block_q_dkv":3024,"block_kv_dkv":2048,"block_kv_dkv_compute":1024,"use_fused_bwd_kernel":true}', + ], + unittest=True, + ) + cls.config = pyconfig.config + checkpoint_loader = WanCheckpointerI2V_2_1(cls.config) + cls.pipeline, _, _ = checkpoint_loader.load_checkpoint() + + cls.image = load_image(cls.config.image_url) + cls.prompt = [cls.config.prompt] * cls.config.global_batch_size_to_train_on + cls.negative_prompt = [cls.config.negative_prompt] * cls.config.global_batch_size_to_train_on + + print("Warming up paths...") + for use_cache in [False, True]: + cls.pipeline( + prompt=cls.prompt, + image=cls.image, + negative_prompt=cls.negative_prompt, + height=cls.config.height, + width=cls.config.width, + num_frames=cls.config.num_frames, + num_inference_steps=cls.config.num_inference_steps, + use_magcache=use_cache, + ) + + def _run_pipeline(self, use_magcache): + t0 = 
time.perf_counter() + videos = self.pipeline( + prompt=self.prompt, + image=self.image, + negative_prompt=self.negative_prompt, + height=self.config.height, + width=self.config.width, + num_frames=self.config.num_frames, + num_inference_steps=self.config.num_inference_steps, + use_magcache=use_magcache, + ) + return videos, time.perf_counter() - t0 + + def test_magcache_speedup_and_fidelity(self): + videos_baseline, t_baseline = self._run_pipeline(use_magcache=False) + videos_cached, t_cached = self._run_pipeline(use_magcache=True) + + speedup = t_baseline / t_cached + print(f"[Wan 2.1 I2V 14B] Baseline: {t_baseline:.2f}s, MagCache: {t_cached:.2f}s, Speedup: {speedup:.3f}x") + psnr, ssim = calculate_metrics(videos_baseline, videos_cached) + + self.assertGreaterEqual(ssim, 0.98) + self.assertGreater(speedup, 1.0) + self.assertGreaterEqual(psnr, 30.0) + From 00a19e2e10e71bbfad832153c22b0b52c1f8e2ce Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 25 Mar 2026 08:32:30 +0000 Subject: [PATCH 2/3] reformatted --- .../wan/transformers/transformer_wan.py | 9 +- .../pipelines/wan/wan_pipeline.py | 235 +++++++++--------- .../pipelines/wan/wan_pipeline_2_1.py | 5 +- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 7 +- src/maxdiffusion/tests/wan_magcache_test.py | 2 - 5 files changed, 131 insertions(+), 127 deletions(-) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 934167ad7..11d7cad2b 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -650,7 +650,7 @@ def scan_fn(carry, block): rematted_block_forward = self.gradient_checkpoint.apply( scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers - ) + ) initial_carry = (h, rngs) final_carry, _ = nnx.scan( rematted_block_forward, @@ -676,7 +676,10 @@ def layer_forward(hidden_states): ) 
rematted_layer_forward = self.gradient_checkpoint.apply( - layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + layer_forward, + self.names_which_can_be_saved, + self.names_which_can_be_offloaded, + prevent_cse=not self.scan_layers, ) h_out = rematted_layer_forward(h_out) return h_out @@ -702,7 +705,7 @@ def layer_forward(hidden_states): ) hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width) - + if return_residual: return hidden_states, residual_x return hidden_states diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 3aae43021..50b37c626 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -775,7 +775,7 @@ def transformer_forward_pass( cached_residual=cached_residual, return_residual=return_residual, ) - + if return_residual: noise_pred, residual_x = outputs else: @@ -899,56 +899,61 @@ def transformer_forward_pass_cfg_cache( noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx) return noise_pred_merged, noise_cond + def nearest_interp(src, target_len): - """Nearest neighbor interpolation for ratio scaling layout.""" - src_len = len(src) - if target_len == 1: - import numpy as np - return np.array([src[-1]]) + """Nearest neighbor interpolation for ratio scaling layout.""" + src_len = len(src) + if target_len == 1: import numpy as np - indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32) - return src[indices] + + return np.array([src[-1]]) + import numpy as np + + indices = np.round(np.linspace(0, src_len - 1, target_len)).astype(np.int32) + return src[indices] + def init_magcache(num_inference_steps, retention_ratio, mag_ratios_base): - """Initialize MagCache variables and interpolate ratios. 
- - Args: - num_inference_steps: Number of inference steps. - retention_ratio: Retention ratio of unchanged steps. - mag_ratios_base: Base magnitude ratios array or list. - """ - import numpy as np - - accumulated_ratio_cond = 1.0 - accumulated_ratio_uncond = 1.0 - accumulated_err_cond = 0.0 - accumulated_err_uncond = 0.0 - accumulated_steps_cond = 0 - accumulated_steps_uncond = 0 - cached_residual = None - - skip_warmup = int(num_inference_steps * retention_ratio) - - mag_ratios_base = np.array(mag_ratios_base) - - if len(mag_ratios_base) != num_inference_steps * 2: - mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps) - mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps) - mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1) - else: - mag_ratios = mag_ratios_base - - return ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - cached_residual, - skip_warmup, - mag_ratios, - ) + """Initialize MagCache variables and interpolate ratios. + + Args: + num_inference_steps: Number of inference steps. + retention_ratio: Retention ratio of unchanged steps. + mag_ratios_base: Base magnitude ratios array or list. 
+ """ + import numpy as np + + accumulated_ratio_cond = 1.0 + accumulated_ratio_uncond = 1.0 + accumulated_err_cond = 0.0 + accumulated_err_uncond = 0.0 + accumulated_steps_cond = 0 + accumulated_steps_uncond = 0 + cached_residual = None + + skip_warmup = int(num_inference_steps * retention_ratio) + + mag_ratios_base = np.array(mag_ratios_base) + + if len(mag_ratios_base) != num_inference_steps * 2: + mag_cond = nearest_interp(mag_ratios_base[0::2], num_inference_steps) + mag_uncond = nearest_interp(mag_ratios_base[1::2], num_inference_steps) + mag_ratios = np.concatenate([mag_cond.reshape(-1, 1), mag_uncond.reshape(-1, 1)], axis=1).reshape(-1) + else: + mag_ratios = mag_ratios_base + + return ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + cached_residual, + skip_warmup, + mag_ratios, + ) + def magcache_step( step, @@ -959,71 +964,71 @@ def magcache_step( skip_warmup=0, use_magcache=None, ): - """Update MagCache accumulated state and decide if to skip. - - Args: - step: Current inference step. - mag_ratios: Interpolated magnitude ratios array. - accumulated_state: Tuple containing accumulated variables. - magcache_thresh: Error threshold. - magcache_K: Max skip steps. - skip_warmup: Warmup steps threshold. - use_magcache: Optional manual override boolean to enable/disable cache for this step. 
- """ - import numpy as np - - ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) = accumulated_state - - cur_mag_ratio_cond = mag_ratios[step * 2] - cur_mag_ratio_uncond = mag_ratios[step * 2 + 1] - - if use_magcache is None: - use_magcache = True - if step < skip_warmup: - use_magcache = False - - skip_blocks = False - if use_magcache: - new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond - new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond - - err_cond = np.abs(1.0 - new_ratio_cond) - err_uncond = np.abs(1.0 - new_ratio_uncond) - - if ( - accumulated_err_cond + err_cond < magcache_thresh - and accumulated_steps_cond < magcache_K - and accumulated_err_uncond + err_uncond < magcache_thresh - and accumulated_steps_uncond < magcache_K - ): - skip_blocks = True - accumulated_ratio_cond = new_ratio_cond - accumulated_ratio_uncond = new_ratio_uncond - accumulated_err_cond += err_cond - accumulated_err_uncond += err_uncond - accumulated_steps_cond += 1 - accumulated_steps_uncond += 1 - else: - accumulated_ratio_cond = 1.0 - accumulated_ratio_uncond = 1.0 - accumulated_err_cond = 0.0 - accumulated_err_uncond = 0.0 - accumulated_steps_cond = 0 - accumulated_steps_uncond = 0 - - new_state = ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) - return skip_blocks, new_state + """Update MagCache accumulated state and decide if to skip. + + Args: + step: Current inference step. + mag_ratios: Interpolated magnitude ratios array. + accumulated_state: Tuple containing accumulated variables. + magcache_thresh: Error threshold. + magcache_K: Max skip steps. + skip_warmup: Warmup steps threshold. + use_magcache: Optional manual override boolean to enable/disable cache for this step. 
+ """ + import numpy as np + + ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) = accumulated_state + + cur_mag_ratio_cond = mag_ratios[step * 2] + cur_mag_ratio_uncond = mag_ratios[step * 2 + 1] + + if use_magcache is None: + use_magcache = True + if step < skip_warmup: + use_magcache = False + + skip_blocks = False + if use_magcache: + new_ratio_cond = accumulated_ratio_cond * cur_mag_ratio_cond + new_ratio_uncond = accumulated_ratio_uncond * cur_mag_ratio_uncond + + err_cond = np.abs(1.0 - new_ratio_cond) + err_uncond = np.abs(1.0 - new_ratio_uncond) + + if ( + accumulated_err_cond + err_cond < magcache_thresh + and accumulated_steps_cond < magcache_K + and accumulated_err_uncond + err_uncond < magcache_thresh + and accumulated_steps_uncond < magcache_K + ): + skip_blocks = True + accumulated_ratio_cond = new_ratio_cond + accumulated_ratio_uncond = new_ratio_uncond + accumulated_err_cond += err_cond + accumulated_err_uncond += err_uncond + accumulated_steps_cond += 1 + accumulated_steps_uncond += 1 + else: + accumulated_ratio_cond = 1.0 + accumulated_ratio_uncond = 1.0 + accumulated_err_cond = 0.0 + accumulated_err_uncond = 0.0 + accumulated_steps_cond = 0 + accumulated_steps_uncond = 0 + + new_state = ( + accumulated_ratio_cond, + accumulated_ratio_uncond, + accumulated_err_cond, + accumulated_err_uncond, + accumulated_steps_cond, + accumulated_steps_uncond, + ) + return skip_blocks, new_state diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index d05bc8e99..19269e5a9 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, nearest_interp, init_magcache, magcache_step +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache, init_magcache, magcache_step from ...models.wan.transformers.transformer_wan import WanModel -from typing import List, Union, Optional, Any +from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp -import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 4bc85856e..baaf1c2cc 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -14,13 +14,12 @@ from maxdiffusion import max_logging from maxdiffusion.image_processor import PipelineImageInput -from .wan_pipeline import WanPipeline, transformer_forward_pass, nearest_interp, init_magcache, magcache_step +from .wan_pipeline import WanPipeline, transformer_forward_pass, init_magcache, magcache_step from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional, Tuple from ...pyconfig import HyperParameters from functools import partial from flax import nnx -import numpy as np from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp @@ -315,7 +314,7 @@ def run_inference_2_1_i2v( for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - + skip_blocks = False if use_magcache and do_cfg: accumulated_state = ( @@ -345,7 +344,7 @@ def run_inference_2_1_i2v( latent_model_input = 
jnp.concatenate([latents_input, condition_combined], axis=-1) timestep = jnp.broadcast_to(t, latents_input.shape[0]) latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) - + outputs = transformer_forward_pass( graphdef, sharded_state, diff --git a/src/maxdiffusion/tests/wan_magcache_test.py b/src/maxdiffusion/tests/wan_magcache_test.py index 6d6c81c76..6413582b3 100644 --- a/src/maxdiffusion/tests/wan_magcache_test.py +++ b/src/maxdiffusion/tests/wan_magcache_test.py @@ -152,7 +152,6 @@ class Wan21I2VMagCacheSmokeTest(unittest.TestCase): @classmethod def setUpClass(cls): - pyconfig.initialize( [ None, @@ -224,4 +223,3 @@ def test_magcache_speedup_and_fidelity(self): self.assertGreaterEqual(ssim, 0.98) self.assertGreater(speedup, 1.0) self.assertGreaterEqual(psnr, 30.0) - From 2a82a6e379ef9c262dbe514f5a6a320363585f57 Mon Sep 17 00:00:00 2001 From: Prisha Jain Date: Wed, 25 Mar 2026 21:37:09 +0530 Subject: [PATCH 3/3] fix and refactor --- .../pipelines/wan/wan_pipeline_2_1.py | 68 ++++++------------- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 39 +++-------- 2 files changed, 30 insertions(+), 77 deletions(-) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 19269e5a9..589ab6076 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -96,12 +96,13 @@ def __call__( magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, ): + config = getattr(self, "config", None) if magcache_thresh is None: - magcache_thresh = getattr(self.config, "magcache_thresh", 0.12) + magcache_thresh = getattr(config, "magcache_thresh", 0.12) if magcache_K is None: - magcache_K = getattr(self.config, "magcache_K", 2) + magcache_K = getattr(config, "magcache_K", 2) if retention_ratio is None: - retention_ratio = getattr(self.config, "retention_ratio", 0.2) + retention_ratio = getattr(config, "retention_ratio", 0.2) if 
use_cfg_cache and guidance_scale <= 1.0: raise ValueError( @@ -138,7 +139,7 @@ def __call__( magcache_K=magcache_K, retention_ratio=retention_ratio, height=height, - mag_ratios_base=self.config.mag_ratios_base if hasattr(self.config, "mag_ratios_base") else None, + mag_ratios_base=getattr(config, "mag_ratios_base", None), ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -244,43 +245,23 @@ def run_inference_2_1( cached_noise_uncond = None if use_magcache and do_cfg: - ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - cached_residual, - skip_warmup, - mag_ratios, - ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + accumulated_state = magcache_init[:6] + cached_residual = magcache_init[6] + skip_warmup = magcache_init[7] + mag_ratios = magcache_init[8] + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + + if use_magcache and do_cfg: timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) - accumulated_state = ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) skip_blocks, accumulated_state = magcache_step( step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup ) - ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) = accumulated_state - - outputs = transformer_forward_pass( + + noise_pred, latents, residual_x_cur = transformer_forward_pass( graphdef, sharded_state, rest_of_state, @@ -294,18 +275,10 @@ 
def run_inference_2_1( return_residual=True, ) - noise_pred, latents_returned, residual_x_cur = outputs - if not skip_blocks: cached_residual = residual_x_cur - latents = latents_returned - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents - - else: - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + else: is_cache_step = step_is_cache[step] if is_cache_step: @@ -351,5 +324,6 @@ def run_inference_2_1( guidance_scale=guidance_scale, ) - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index baaf1c2cc..b98aa2961 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -154,12 +154,13 @@ def __call__( magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, ): + config = getattr(self, "config", None) if magcache_thresh is None: - magcache_thresh = getattr(self.config, "magcache_thresh", 0.04) + magcache_thresh = getattr(config, "magcache_thresh", 0.04) if magcache_K is None: - magcache_K = getattr(self.config, "magcache_K", 2) + magcache_K = getattr(config, "magcache_K", 2) if retention_ratio is None: - retention_ratio = getattr(self.config, "retention_ratio", 0.2) + retention_ratio = getattr(config, "retention_ratio", 0.2) height = height or self.config.height width = width or self.config.width @@ -291,17 +292,11 @@ def run_inference_2_1_i2v( do_cfg = guidance_scale > 1.0 if use_magcache and do_cfg: - ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - 
cached_residual, - skip_warmup, - mag_ratios, - ) = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) + accumulated_state = magcache_init[:6] + cached_residual = magcache_init[6] + skip_warmup = magcache_init[7] + mag_ratios = magcache_init[8] if do_cfg: prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) @@ -317,25 +312,9 @@ def run_inference_2_1_i2v( skip_blocks = False if use_magcache and do_cfg: - accumulated_state = ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) skip_blocks, accumulated_state = magcache_step( step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup ) - ( - accumulated_ratio_cond, - accumulated_ratio_uncond, - accumulated_err_cond, - accumulated_err_uncond, - accumulated_steps_cond, - accumulated_steps_uncond, - ) = accumulated_state latents_input = latents if do_cfg: