diff --git a/.gitignore b/.gitignore index bd4a64b8..d0f0ea7b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__/ *$py.class # C extensions *.so +Gemini.md # tests and logs tests/fixtures/cached_*_text.txt diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py new file mode 100644 index 00000000..a4d12df6 --- /dev/null +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py @@ -0,0 +1,882 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Tuple, Optional, Dict, Union, Any +import contextlib +import math +import jax +import jax.numpy as jnp +from flax import nnx +from .... import common_types +from ...modeling_flax_utils import FlaxModelMixin +from ....configuration_utils import ConfigMixin, register_to_config +from ...normalization_flax import FP32LayerNorm +from ...gradient_checkpoint import GradientCheckpointType +from .transformer_wan import ( + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) + +BlockSizes = common_types.BlockSizes + +WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = { + "4": 512, + "8": 512, + "16": 512, + "32": 512, + "64": 256, + "128": 128, + "256": 64, + "512": 32, + "1024": 16, +} + + +class FlaxFusedLeakyReLU(nnx.Module): + """ + Fused LeakyRelu with scale factor and channel-wise bias. + """ + + def __init__( + self, + rngs: nnx.Rngs, + negative_slope: float = 0.2, + scale: float = 2**0.5, + bias_channels: Optional[int] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.negative_slope = negative_slope + self.scale = scale + self.channels = bias_channels + self.dtype = dtype + + if self.channels is not None: + self.bias = nnx.Param(jnp.zeros((self.channels,), dtype=self.dtype)) + else: + self.bias = None + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + if self.bias is not None: + # Expand self.bias to have all singleton dims except at channel_dim + expanded_shape = [1] * x.ndim + expanded_shape[channel_dim] = self.channels + bias = jnp.reshape(self.bias, expanded_shape) + x = x + bias + x = jax.nn.leaky_relu(x, self.negative_slope) * self.scale + return x + + +class FlaxMotionConv2d(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + blur_kernel: Optional[Tuple[int, ...]] = None, + blur_upsample_factor: int = 1, + use_activation: bool = True, + dtype: jnp.dtype = jnp.float32, + ): + self.use_activation = use_activation + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding_size = padding + self.dtype = dtype + + self.blur = False + if blur_kernel is not None: + p = (len(blur_kernel) - stride) + (kernel_size - 1) + self.blur_padding = ((p + 1) // 2, p // 2) + + kernel = jnp.array(blur_kernel, dtype=jnp.float32) + if kernel.ndim == 1: + kernel = jnp.expand_dims(kernel, 0) * jnp.expand_dims(kernel, 1) + kernel = kernel / kernel.sum() + + if blur_upsample_factor > 1: + kernel = kernel * (blur_upsample_factor**2) + + self.blur_kernel = jnp.array(kernel) + self.blur = True + else: + self.blur_kernel = None + + key = rngs.params() + # Shape: (out_channels, in_channels, kernel, kernel) mapping PyTorch 'OIHW' + self.weight = nnx.Param(jax.random.normal(key, (out_channels, in_channels, kernel_size, kernel_size), dtype=dtype)) + self.scale = 1.0 / math.sqrt(in_channels * kernel_size**2) + + if bias and not self.use_activation: + self.bias = nnx.Param(jnp.zeros((out_channels,), dtype=dtype)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FlaxFusedLeakyReLU(rngs=rngs, bias_channels=out_channels, dtype=dtype) + else: + self.act_fn = None + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + # 1. Blur Pass (Depthwise) + if self.blur: + expanded_kernel = jnp.expand_dims(jnp.expand_dims(self.blur_kernel, 0), 0) + expanded_kernel = jnp.broadcast_to( + expanded_kernel, + ( + self.in_channels, + 1, + expanded_kernel.shape[2], + expanded_kernel.shape[3], + ), + ) + + pad_h, pad_w = self.blur_padding + x = jax.lax.conv_general_dilated( + x, + expanded_kernel, + window_strides=(1, 1), + padding=[(pad_h, pad_h), (pad_w, pad_w)], # Corrected Symmetric Padding + dimension_numbers=("NCHW", "OIHW", "NCHW"), + feature_group_count=self.in_channels, + ) + + # 2. Main Convolution Pass + conv_weight = self.weight * self.scale + x = jax.lax.conv_general_dilated( + x, + conv_weight, + window_strides=(self.stride, self.stride), + padding=[ + (self.padding_size, self.padding_size), + (self.padding_size, self.padding_size), + ], + dimension_numbers=("NCHW", "OIHW", "NCHW"), + ) + + # 3. Bias and Activation + if self.bias is not None: + b = jnp.reshape(self.bias, (1, self.out_channels, 1, 1)) + x = x + b + + if self.use_activation: + x = self.act_fn(x, channel_dim=channel_dim) + + return x + + +class FlaxMotionLinear(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_dim: int, + out_dim: int, + bias: bool = True, + use_activation: bool = False, + dtype: jnp.dtype = jnp.float32, + ): + self.use_activation = use_activation + self.in_dim = in_dim + self.out_dim = out_dim + self.dtype = dtype + + key = rngs.params() + self.weight = nnx.Param(jax.random.normal(key, (out_dim, in_dim), dtype=dtype)) + self.scale = 1.0 / math.sqrt(in_dim) + + if bias and not self.use_activation: + self.bias = nnx.Param(jnp.zeros((out_dim,), dtype=dtype)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FlaxFusedLeakyReLU(rngs=rngs, bias_channels=out_dim, dtype=dtype) + else: + self.act_fn = None + + def __call__(self, inputs: jax.Array, channel_dim: int = 1) -> jax.Array: + # Transpose to (in_dim, out_dim) and apply scale + w = self.weight.T * self.scale + + out = inputs @ w + + if self.bias is not None: + out = out + self.bias + + if self.use_activation: + out = self.act_fn(out, channel_dim=channel_dim) + + return out + + +class FlaxMotionEncoderResBlock(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + kernel_size_skip: int = 1, + blur_kernel: Tuple[int, ...] = (1, 3, 3, 1), + downsample_factor: int = 2, + dtype: jnp.dtype = jnp.float32, + ): + self.downsample_factor = downsample_factor + self.dtype = dtype + + # 3 X 3 Conv + fused leaky ReLU + self.conv1 = FlaxMotionConv2d( + rngs, + in_channels, + in_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + use_activation=True, + dtype=dtype, + ) + + # 3 X 3 Conv + downsample 2x + fused leaky ReLU + self.conv2 = FlaxMotionConv2d( + rngs, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=self.downsample_factor, + padding=0, + blur_kernel=blur_kernel, + use_activation=True, + dtype=dtype, + ) + + # 1 X 1 Conv + downsample 2x in skip connection + self.conv_skip = FlaxMotionConv2d( + rngs, + in_channels, + out_channels, + kernel_size=kernel_size_skip, + stride=self.downsample_factor, + padding=0, + bias=False, + blur_kernel=blur_kernel, + use_activation=False, + dtype=dtype, + ) + + def __call__(self, x: jax.Array, channel_dim: int = 1) -> jax.Array: + x_out = self.conv1(x, channel_dim=channel_dim) + x_out = self.conv2(x_out, channel_dim=channel_dim) + + x_skip = self.conv_skip(x, channel_dim=channel_dim) + + x_out = (x_out + x_skip) / math.sqrt(2.0) + return x_out + + +class FlaxWanAnimateMotionEncoder(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + size: int = 512, + style_dim: int = 512, + motion_dim: int = 20, + out_dim: int = 512, + motion_blocks: int = 5, + channels: Optional[Dict[str, int]] = None, + dtype: jnp.dtype = jnp.float32, + ): + self.size = size + self.dtype = dtype + + if channels is None: + channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES + + self.conv_in = FlaxMotionConv2d(rngs, 3, channels[str(size)], 1, use_activation=True, dtype=dtype) + + res_blocks = [] + in_channels = channels[str(size)] + log_size = int(math.log(size, 2)) + for i in range(log_size, 2, -1): + out_channels = channels[str(2 ** (i - 1))] + res_blocks.append(FlaxMotionEncoderResBlock(rngs, in_channels, out_channels, dtype=dtype)) + in_channels = out_channels + self.res_blocks = nnx.List(res_blocks) + + self.conv_out = FlaxMotionConv2d( + rngs, + in_channels, + style_dim, + 4, + padding=0, + bias=False, + use_activation=False, + dtype=dtype, + ) + + linears = [] + for _ in range(motion_blocks - 1): + linears.append(FlaxMotionLinear(rngs, style_dim, style_dim, dtype=dtype)) + + linears.append(FlaxMotionLinear(rngs, style_dim, motion_dim, dtype=dtype)) + self.motion_network = nnx.List(linears) + + key = rngs.params() + self.motion_synthesis_weight = nnx.Param(jax.random.normal(key, (out_dim, motion_dim), dtype=dtype)) + + def __call__(self, face_image: jax.Array, channel_dim: int = 1) -> jax.Array: + if face_image.shape[-2] != self.size or face_image.shape[-1] != self.size: + raise ValueError(f"Expected {self.size} got {face_image.shape[-1]}") + + x = self.conv_in(face_image, channel_dim=channel_dim) + for block in self.res_blocks: + x = block(x, channel_dim=channel_dim) + x = self.conv_out(x, channel_dim=channel_dim) + + motion_feat = jnp.squeeze(x, axis=(-1, -2)) + + for linear_layer in self.motion_network: + motion_feat = linear_layer(motion_feat, channel_dim=channel_dim) + + weight = self.motion_synthesis_weight[...] + 1e-8 + + original_dtype = motion_feat.dtype + motion_feat_fp32 = motion_feat.astype(jnp.float32) + weight_fp32 = weight.astype(jnp.float32) + + Q, _ = jnp.linalg.qr(weight_fp32) + + motion_vec = jnp.matmul(motion_feat_fp32, jnp.transpose(Q, (1, 0))) + + return motion_vec.astype(original_dtype) + + +class FlaxWanAnimateFaceEncoder(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + in_dim: int, + out_dim: int, + hidden_dim: int = 1024, + num_heads: int = 4, + kernel_size: int = 3, + eps: float = 1e-6, + pad_mode: str = "edge", + dtype: jnp.dtype = jnp.float32, + ): + self.num_heads = num_heads + self.kernel_size = kernel_size + self.pad_mode = pad_mode + self.out_dim = out_dim + self.dtype = dtype + + self.act = jax.nn.silu + + # Added explicit padding="VALID" to exactly mirror PyTorch's padding=0 default + self.conv1_local = nnx.Conv( + in_dim, + hidden_dim * num_heads, + kernel_size=(kernel_size,), + strides=(1,), + padding="VALID", + rngs=rngs, + dtype=dtype, + ) + self.conv2 = nnx.Conv( + hidden_dim, + hidden_dim, + kernel_size=(kernel_size,), + strides=(2,), + padding="VALID", + rngs=rngs, + dtype=dtype, + ) + self.conv3 = nnx.Conv( + hidden_dim, + hidden_dim, + kernel_size=(kernel_size,), + strides=(2,), + padding="VALID", + rngs=rngs, + dtype=dtype, + ) + + self.norm1 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + self.norm2 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + self.norm3 = nnx.LayerNorm( + hidden_dim, + epsilon=eps, + use_bias=False, + use_scale=False, + rngs=rngs, + dtype=dtype, + ) + + self.out_proj = nnx.Linear(hidden_dim, out_dim, rngs=rngs, dtype=dtype) + + self.padding_tokens = nnx.Param(jnp.zeros((1, 1, 1, out_dim), dtype=dtype)) + + def __call__(self, x: jax.Array) -> jax.Array: + batch_size = x.shape[0] + + # Local attention via causal convolution + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv1_local(x) + + x = jnp.reshape(x, (batch_size, x.shape[1], self.num_heads, -1)) + x = jnp.transpose(x, (0, 2, 1, 3)) + x = jnp.reshape(x, (batch_size * self.num_heads, x.shape[2], x.shape[3])) + + x = self.norm1(x) + x = self.act(x) + + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv2(x) + x = self.norm2(x) + x = self.act(x) + + x = jnp.pad(x, ((0, 0), (self.kernel_size - 1, 0), (0, 0)), mode=self.pad_mode) + x = self.conv3(x) + x = self.norm3(x) + x = self.act(x) + + x = self.out_proj(x) + + x = jnp.reshape(x, (batch_size, self.num_heads, x.shape[1], x.shape[2])) + x = jnp.transpose(x, (0, 2, 1, 3)) + + padding = jnp.broadcast_to(self.padding_tokens[...], (batch_size, x.shape[1], 1, self.out_dim)) + x = jnp.concatenate([x, padding], axis=2) + + return x + + +class FlaxWanAnimateFaceBlockCrossAttention(nnx.Module): + + def __init__( + self, + rngs: nnx.Rngs, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-6, + cross_attention_dim_head: Optional[int] = None, + use_bias: bool = True, + dtype: jnp.dtype = jnp.float32, + ): + self.heads = heads + self.inner_dim = dim_head * heads + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + self.dtype = dtype + + self.pre_norm_q = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype) + self.pre_norm_kv = nnx.LayerNorm(dim, epsilon=eps, use_bias=False, use_scale=False, rngs=rngs, dtype=dtype) + + self.to_q = nnx.Linear(dim, self.inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype) + self.to_k = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype) + self.to_v = nnx.Linear(dim, self.kv_inner_dim, use_bias=use_bias, rngs=rngs, dtype=dtype) + + self.to_out = nnx.Linear(self.inner_dim, dim, use_bias=use_bias, rngs=rngs, dtype=dtype) + + self.norm_q = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype) + self.norm_k = nnx.RMSNorm(dim_head, epsilon=eps, use_scale=True, rngs=rngs, dtype=dtype) + + def __call__( + self, + hidden_states: jax.Array, + encoder_hidden_states: jax.Array, + attention_mask: Optional[jax.Array] = None, + ) -> jax.Array: + hidden_states = self.pre_norm_q(hidden_states) + encoder_hidden_states = self.pre_norm_kv(encoder_hidden_states) + + B, T, N, C = encoder_hidden_states.shape + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # Reshape to extract heads + query = jnp.reshape(query, (query.shape[0], query.shape[1], self.heads, -1)) + key = jnp.reshape(key, (B, T, N, self.heads, -1)) + value = jnp.reshape(value, (B, T, N, self.heads, -1)) + + query = self.norm_q(query) + key = self.norm_k(key) + + query_S = query.shape[1] + + # Prepare for attention by folding Time into the Batch dimension + query = jnp.reshape(query, (B * T, query_S // T, self.heads, -1)) + key = jnp.reshape(key, (B * T, N, self.heads, -1)) + value = jnp.reshape(value, (B * T, N, self.heads, -1)) + + attn_output = jax.nn.dot_product_attention(query, key, value) + + # Collapse Time, Seq Length, and Heads straight back to (Batch, Total Sequence, Dim) + attn_output = jnp.reshape(attn_output, (B, query_S, -1)) + + hidden_states = self.to_out(attn_output) + + if attention_mask is not None: + attention_mask = jnp.reshape(attention_mask, (attention_mask.shape[0], -1)) + hidden_states = hidden_states * jnp.expand_dims(attention_mask, axis=-1) + + return hidden_states + + +class NNXWanAnimateTransformer3DModel(nnx.Module, FlaxModelMixin, ConfigMixin): + + @register_to_config + def __init__( + self, + rngs: nnx.Rngs, + model_type="t2v", + patch_size: Tuple[int, int, int] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 36, + latent_channels: int = 16, + out_channels: Optional[int] = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + dropout: float = 0.0, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = 1280, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + image_seq_len: Optional[int] = None, + flash_min_seq_length: int = 4096, + flash_block_sizes: BlockSizes = None, + mesh: jax.sharding.Mesh = None, + dtype: jnp.dtype = jnp.float32, + weights_dtype: jnp.dtype = jnp.float32, + precision: jax.lax.Precision = None, + attention: str = "dot_product", + remat_policy: str = "None", + names_which_can_be_saved: list = [], + names_which_can_be_offloaded: list = [], + mask_padding_tokens: bool = True, + scan_layers: bool = True, + enable_jax_named_scopes: bool = False, + motion_encoder_channel_sizes: Optional[Dict[str, int]] = None, + motion_encoder_size: int = 512, + motion_style_dim: int = 512, + motion_dim: int = 20, + motion_encoder_dim: int = 512, + face_encoder_hidden_dim: int = 1024, + face_encoder_num_heads: int = 4, + inject_face_latents_blocks: int = 5, + motion_encoder_batch_size: int = 8, + ): + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or latent_channels + + self.num_layers = num_layers + self.scan_layers = scan_layers + self.enable_jax_named_scopes = enable_jax_named_scopes + self.patch_size = patch_size + self.inject_face_latents_blocks = inject_face_latents_blocks + self.motion_encoder_batch_size = motion_encoder_batch_size + self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) + + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nnx.Conv( + in_channels, + inner_dim, + kernel_size=patch_size, + strides=patch_size, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + ) + self.pose_patch_embedding = nnx.Conv( + latent_channels, + inner_dim, + kernel_size=patch_size, + strides=patch_size, + rngs=rngs, + dtype=dtype, + param_dtype=weights_dtype, + ) + + self.condition_embedder = WanTimeTextImageEmbedding( + rngs=rngs, + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + flash_min_seq_length=flash_min_seq_length, + dtype=dtype, + weights_dtype=weights_dtype, + ) + + self.motion_encoder = FlaxWanAnimateMotionEncoder( + rngs=rngs, + size=motion_encoder_size, + style_dim=motion_style_dim, + motion_dim=motion_dim, + out_dim=motion_encoder_dim, + channels=motion_encoder_channel_sizes, + dtype=dtype, + ) + self.face_encoder = FlaxWanAnimateFaceEncoder( + rngs=rngs, + in_dim=motion_encoder_dim, + out_dim=inner_dim, + hidden_dim=face_encoder_hidden_dim, + num_heads=face_encoder_num_heads, + dtype=dtype, + ) + + blocks = [] + for _ in range(num_layers): + block = WanTransformerBlock( + rngs=rngs, + dim=inner_dim, + ffn_dim=ffn_dim, + num_heads=num_attention_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + image_seq_len=image_seq_len, + flash_min_seq_length=flash_min_seq_length, + flash_block_sizes=flash_block_sizes, + mesh=mesh, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, + attention=attention, + enable_jax_named_scopes=enable_jax_named_scopes, + ) + blocks.append(block) + self.blocks = nnx.List(blocks) + + face_adapters = [] + num_face_adapters = num_layers // inject_face_latents_blocks + for _ in range(num_face_adapters): + fa = FlaxWanAnimateFaceBlockCrossAttention( + rngs=rngs, + dim=inner_dim, + heads=num_attention_heads, + dim_head=inner_dim // num_attention_heads, + eps=eps, + cross_attention_dim_head=inner_dim // num_attention_heads, + dtype=dtype, + ) + face_adapters.append(fa) + self.face_adapter = nnx.List(face_adapters) + + self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False) + self.proj_out = nnx.Linear( + rngs=rngs, + in_features=inner_dim, + out_features=out_channels * math.prod(patch_size), + dtype=dtype, + param_dtype=weights_dtype, + ) + key = rngs.params() + self.scale_shift_table = nnx.Param(jax.random.normal(key, (1, 2, inner_dim), dtype=dtype) / inner_dim**0.5) + + def conditional_named_scope(self, name: str): + return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + + @jax.named_scope("WanAnimateTransformer3DModel") + def __call__( + self, + hidden_states: jax.Array, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + pose_hidden_states: Optional[jax.Array] = None, + face_pixel_values: Optional[jax.Array] = None, + motion_encode_batch_size: Optional[int] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + deterministic: bool = True, + rngs: nnx.Rngs = None, + ) -> Union[jax.Array, Dict[str, jax.Array]]: + if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]: + raise ValueError( + f"Pose frames + 1 ({pose_hidden_states.shape[2]} + 1) must equal hidden_states frames ({hidden_states.shape[2]})" + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # 1 & 2. Rotary Position & Patch Embedding + hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) + rotary_emb = self.rope(hidden_states) + hidden_states = self.patch_embedding(hidden_states) + + pose_hidden_states = jnp.transpose(pose_hidden_states, (0, 2, 3, 4, 1)) + pose_hidden_states = self.pose_patch_embedding(pose_hidden_states) + pose_pad = jnp.zeros( + ( + batch_size, + 1, + pose_hidden_states.shape[2], + pose_hidden_states.shape[3], + pose_hidden_states.shape[4], + ), + dtype=hidden_states.dtype, + ) + pose_pad = jnp.concatenate([pose_pad, pose_hidden_states], axis=1) + hidden_states = hidden_states + pose_pad + + hidden_states = jnp.reshape(hidden_states, (batch_size, -1, hidden_states.shape[-1])) + + # 3. Condition Embeddings + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + _, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + timestep_proj = timestep_proj.reshape(batch_size, 6, -1) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + + # 4. Batched Face & Motion Encoding + _, face_channels, num_face_frames, face_height, face_width = face_pixel_values.shape + + # Rearrange from (B, C, T, H, W) to (B*T, C, H, W) + face_pixel_values = jnp.transpose(face_pixel_values, (0, 2, 1, 3, 4)) + face_pixel_values = jnp.reshape(face_pixel_values, (-1, face_channels, face_height, face_width)) + + total_face_frames = face_pixel_values.shape[0] + motion_encode_batch_size = motion_encode_batch_size or self.motion_encoder_batch_size + + # Pad sequence if it doesn't divide evenly by encode_bs + pad_len = (motion_encode_batch_size - (total_face_frames % motion_encode_batch_size)) % motion_encode_batch_size + if pad_len > 0: + pad_tensor = jnp.zeros( + (pad_len, face_channels, face_height, face_width), + dtype=face_pixel_values.dtype, + ) + face_pixel_values = jnp.concatenate([face_pixel_values, pad_tensor], axis=0) + + # Reshape into chunks for scan + num_chunks = face_pixel_values.shape[0] // motion_encode_batch_size + face_chunks = jnp.reshape( + face_pixel_values, + ( + num_chunks, + motion_encode_batch_size, + face_channels, + face_height, + face_width, + ), + ) + + # Use jax.lax.scan to iterate over chunks to save memory + def encode_chunk_fn(carry, chunk): + encoded_chunk = self.motion_encoder(chunk) + return carry, encoded_chunk + + _, motion_vec_chunks = jax.lax.scan(encode_chunk_fn, None, face_chunks) + motion_vec = jnp.reshape(motion_vec_chunks, (-1, motion_vec_chunks.shape[-1])) + + # Remove padding if added + if pad_len > 0: + motion_vec = motion_vec[:-pad_len] + + motion_vec = jnp.reshape(motion_vec, (batch_size, num_face_frames, -1)) + + # Apply face encoder + motion_vec = self.face_encoder(motion_vec) + pad_face = jnp.zeros_like(motion_vec[:, :1]) + motion_vec = jnp.concatenate([pad_face, motion_vec], axis=1) + + # 5. Transformer Blocks + for block_idx, block in enumerate(self.blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + deterministic, + rngs, + ) + + # Face adapter integration: apply after every 5th block (0, 5, 10, 15, ...) + if motion_vec is not None and block_idx % self.inject_face_latents_blocks == 0: + face_adapter_block_idx = block_idx // self.inject_face_latents_blocks + face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec) + hidden_states = hidden_states + face_adapter_output + + # 6. Output Norm & Projection + shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1) + hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype) + hidden_states = self.proj_out(hidden_states) + + hidden_states = jnp.reshape( + hidden_states, + ( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ), + ) + hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) + hidden_states = jnp.reshape(hidden_states, (batch_size, -1, num_frames, height, width)) + + if not return_dict: + return (hidden_states,) + return {"sample": hidden_states} diff --git a/src/maxdiffusion/tests/wan_animate/test_transformer_wan_animate.py b/src/maxdiffusion/tests/wan_animate/test_transformer_wan_animate.py new file mode 100644 index 00000000..73fb2095 --- /dev/null +++ b/src/maxdiffusion/tests/wan_animate/test_transformer_wan_animate.py @@ -0,0 +1,692 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +import torch +import numpy as np +import jax.numpy as jnp +from flax import nnx + +# Import from the codebase +# Make sure to set Python path or import correctly +from maxdiffusion.models.wan.transformers.transformer_wan_animate import ( + FlaxMotionConv2d, + FlaxMotionLinear, + FlaxMotionEncoderResBlock, + FlaxWanAnimateMotionEncoder, + FlaxWanAnimateFaceEncoder, + FlaxWanAnimateFaceBlockCrossAttention, + NNXWanAnimateTransformer3DModel, + FlaxFusedLeakyReLU, +) +from maxdiffusion import pyconfig +from maxdiffusion.max_utils import create_device_mesh, get_flash_block_sizes +from jax.sharding import Mesh + + +def transfer_conv_weights(pt_conv, jax_conv): + if hasattr(jax_conv, "weight"): + jax_conv.weight[...] = jnp.array(pt_conv.weight.detach().numpy()) + if pt_conv.bias is not None: + if hasattr(jax_conv, "use_activation") and jax_conv.use_activation: + jax_conv.act_fn.bias[...] = jnp.array(pt_conv.bias.detach().numpy()) + else: + if jax_conv.bias is not None: + jax_conv.bias[...] = jnp.array(pt_conv.bias.detach().numpy()) + elif hasattr(jax_conv, "kernel"): + if pt_conv.weight.ndim == 5: + jax_conv.kernel[...] = jnp.array(pt_conv.weight.detach().numpy().transpose(2, 3, 4, 1, 0)) + else: + jax_conv.kernel[...] = jnp.array(pt_conv.weight.detach().numpy().transpose(2, 3, 1, 0)) + if pt_conv.bias is not None: + jax_conv.bias[...] = jnp.array(pt_conv.bias.detach().numpy()) + + +def transfer_linear_weights(pt_linear, jax_linear): + if hasattr(jax_linear, "weight"): + jax_linear.weight[...] = jnp.array(pt_linear.weight.detach().numpy()) + if pt_linear.bias is not None: + if hasattr(jax_linear, "use_activation") and jax_linear.use_activation: + jax_linear.act_fn.bias[...] = jnp.array(pt_linear.bias.detach().numpy()) + else: + if jax_linear.bias is not None: + jax_linear.bias[...] = jnp.array(pt_linear.bias.detach().numpy()) + elif hasattr(jax_linear, "kernel"): + jax_linear.kernel[...] = jnp.array(pt_linear.weight.detach().numpy().T) + if pt_linear.bias is not None: + jax_linear.bias[...] = jnp.array(pt_linear.bias.detach().numpy()) + + +def transfer_transformer_weights(pt_model, jax_model): + # Patch Embeddings + transfer_conv_weights(pt_model.patch_embedding, jax_model.patch_embedding) + transfer_conv_weights(pt_model.pose_patch_embedding, jax_model.pose_patch_embedding) + + # Condition Embeddings + transfer_linear_weights( + pt_model.condition_embedder.time_embedder.linear_1, + jax_model.condition_embedder.time_embedder.linear_1, + ) + transfer_linear_weights( + pt_model.condition_embedder.time_embedder.linear_2, + jax_model.condition_embedder.time_embedder.linear_2, + ) + transfer_linear_weights(pt_model.condition_embedder.time_proj, jax_model.condition_embedder.time_proj) + transfer_linear_weights( + pt_model.condition_embedder.text_embedder.linear_1, + jax_model.condition_embedder.text_embedder.linear_1, + ) + transfer_linear_weights( + pt_model.condition_embedder.text_embedder.linear_2, + jax_model.condition_embedder.text_embedder.linear_2, + ) + + if pt_model.condition_embedder.image_embedder is not None: + jax_model.condition_embedder.image_embedder.norm1.layer_norm.scale[...] = jnp.array( + pt_model.condition_embedder.image_embedder.norm1.weight.detach().numpy() + ) + jax_model.condition_embedder.image_embedder.norm1.layer_norm.bias[...] = jnp.array( + pt_model.condition_embedder.image_embedder.norm1.bias.detach().numpy() + ) + + transfer_linear_weights( + pt_model.condition_embedder.image_embedder.ff.net[0].proj, + jax_model.condition_embedder.image_embedder.ff.net_0, + ) + transfer_linear_weights( + pt_model.condition_embedder.image_embedder.ff.net[2], + jax_model.condition_embedder.image_embedder.ff.net_2, + ) + + jax_model.condition_embedder.image_embedder.norm2.layer_norm.scale[...] = jnp.array( + pt_model.condition_embedder.image_embedder.norm2.weight.detach().numpy() + ) + jax_model.condition_embedder.image_embedder.norm2.layer_norm.bias[...] = jnp.array( + pt_model.condition_embedder.image_embedder.norm2.bias.detach().numpy() + ) + + # Motion Encoder + transfer_conv_weights(pt_model.motion_encoder.conv_in, jax_model.motion_encoder.conv_in) + for i in range(len(pt_model.motion_encoder.res_blocks)): + transfer_conv_weights( + pt_model.motion_encoder.res_blocks[i].conv1, + jax_model.motion_encoder.res_blocks[i].conv1, + ) + transfer_conv_weights( + pt_model.motion_encoder.res_blocks[i].conv2, + jax_model.motion_encoder.res_blocks[i].conv2, + ) + if pt_model.motion_encoder.res_blocks[i].conv_skip is not None: + transfer_conv_weights( + pt_model.motion_encoder.res_blocks[i].conv_skip, + jax_model.motion_encoder.res_blocks[i].conv_skip, + ) + + transfer_conv_weights(pt_model.motion_encoder.conv_out, jax_model.motion_encoder.conv_out) + + for i in range(len(pt_model.motion_encoder.motion_network)): + transfer_linear_weights( + pt_model.motion_encoder.motion_network[i], + jax_model.motion_encoder.motion_network[i], + ) + + jax_model.motion_encoder.motion_synthesis_weight[...] = jnp.array( + pt_model.motion_encoder.motion_synthesis_weight.detach().numpy() + ) + + # Face Encoder + jax_model.face_encoder.conv1_local.kernel[...] = jnp.array( + pt_model.face_encoder.conv1_local.weight.detach().numpy().transpose(2, 1, 0) + ) + jax_model.face_encoder.conv1_local.bias[...] = jnp.array(pt_model.face_encoder.conv1_local.bias.detach().numpy()) + + jax_model.face_encoder.conv2.kernel[...] = jnp.array( + pt_model.face_encoder.conv2.weight.detach().numpy().transpose(2, 1, 0) + ) + jax_model.face_encoder.conv2.bias[...] = jnp.array(pt_model.face_encoder.conv2.bias.detach().numpy()) + + jax_model.face_encoder.conv3.kernel[...] = jnp.array( + pt_model.face_encoder.conv3.weight.detach().numpy().transpose(2, 1, 0) + ) + jax_model.face_encoder.conv3.bias[...] = jnp.array(pt_model.face_encoder.conv3.bias.detach().numpy()) + + transfer_linear_weights(pt_model.face_encoder.out_proj, jax_model.face_encoder.out_proj) + + jax_model.face_encoder.padding_tokens[...] = jnp.array(pt_model.face_encoder.padding_tokens.detach().numpy()) + + # Blocks + for i in range(len(pt_model.blocks)): + pt_block = pt_model.blocks[i] + jax_block = jax_model.blocks[i] + + # Self Attention + transfer_linear_weights(pt_block.attn1.to_q, jax_block.attn1.query) + transfer_linear_weights(pt_block.attn1.to_k, jax_block.attn1.key) + transfer_linear_weights(pt_block.attn1.to_v, jax_block.attn1.value) + transfer_linear_weights(pt_block.attn1.to_out[0], jax_block.attn1.proj_attn) + + jax_block.attn1.norm_q.scale[...] = jnp.array(pt_block.attn1.norm_q.weight.detach().numpy()) + jax_block.attn1.norm_k.scale[...] = jnp.array(pt_block.attn1.norm_k.weight.detach().numpy()) + + # Cross Attention + if hasattr(pt_block, "attn2"): + transfer_linear_weights(pt_block.attn2.to_q, jax_block.attn2.query) + transfer_linear_weights(pt_block.attn2.to_k, jax_block.attn2.key) + transfer_linear_weights(pt_block.attn2.to_v, jax_block.attn2.value) + transfer_linear_weights(pt_block.attn2.to_out[0], jax_block.attn2.proj_attn) + + jax_block.attn2.norm_q.scale[...] = jnp.array(pt_block.attn2.norm_q.weight.detach().numpy()) + jax_block.attn2.norm_k.scale[...] = jnp.array(pt_block.attn2.norm_k.weight.detach().numpy()) + + jax_block.norm2.layer_norm.scale[...] = jnp.array(pt_block.norm2.weight.detach().numpy()) + jax_block.norm2.layer_norm.bias[...] = jnp.array(pt_block.norm2.bias.detach().numpy()) + + # FFN + transfer_linear_weights(pt_block.ffn.net[0].proj, jax_block.ffn.act_fn.proj) + transfer_linear_weights(pt_block.ffn.net[2], jax_block.ffn.proj_out) + + jax_block.adaln_scale_shift_table[...] = jnp.array(pt_block.scale_shift_table.detach().numpy()) + + # Face Adapter + for i in range(len(pt_model.face_adapter)): + transfer_linear_weights(pt_model.face_adapter[i].to_q, jax_model.face_adapter[i].to_q) + transfer_linear_weights(pt_model.face_adapter[i].to_k, jax_model.face_adapter[i].to_k) + transfer_linear_weights(pt_model.face_adapter[i].to_v, jax_model.face_adapter[i].to_v) + transfer_linear_weights(pt_model.face_adapter[i].to_out, jax_model.face_adapter[i].to_out) + + jax_model.face_adapter[i].norm_q.scale[...] = jnp.array(pt_model.face_adapter[i].norm_q.weight.detach().numpy()) + jax_model.face_adapter[i].norm_k.scale[...] = jnp.array(pt_model.face_adapter[i].norm_k.weight.detach().numpy()) + + # Final Norm & Proj + # jax_model.norm_out.scale[...] = jnp.array(pt_model.norm_out.weight.detach().numpy()) + # jax_model.norm_out.bias[...] = jnp.array(pt_model.norm_out.bias.detach().numpy()) + + transfer_linear_weights(pt_model.proj_out, jax_model.proj_out) + + jax_model.scale_shift_table[...] = jnp.array(pt_model.scale_shift_table.detach().numpy()) + + +class TestWanAnimateTransformer: + + def test_motion_conv_equivalence(self): + B, C_in, H, W = 2, 4, 16, 16 + C_out = 8 + k_size = 3 + stride = 2 + pad = 1 + blur_k = (1, 3, 3, 1) + + x_np = np.random.randn(B, C_in, H, W).astype(np.float32) + x_pt = torch.from_numpy(x_np) + x_jax = jnp.array(x_np) + + from diffusers.models.transformers.transformer_wan_animate import MotionConv2d + + pt_model = MotionConv2d( + in_channels=C_in, + out_channels=C_out, + kernel_size=k_size, + stride=stride, + padding=pad, + blur_kernel=blur_k, + use_activation=False, + ) + + rngs = nnx.Rngs(0) + jax_model = FlaxMotionConv2d( + rngs=rngs, + in_channels=C_in, + out_channels=C_out, + kernel_size=k_size, + stride=stride, + padding=pad, + blur_kernel=blur_k, + use_activation=False, + ) + + jax_model.weight[...] = jnp.array(pt_model.weight.detach().numpy()) + if jax_model.bias is not None: + jax_model.bias[...] = jnp.array(pt_model.bias.detach().numpy()) + + pt_model.eval() + with torch.no_grad(): + out_pt = pt_model(x_pt) + + out_jax = jax_model(x_jax) + + np_out_pt = out_pt.numpy() + np_out_jax = np.array(out_jax) + + max_diff = np.max(np.abs(np_out_pt - np_out_jax)) + print(f"Max absolute difference: {max_diff:.8f}") + + np.testing.assert_allclose( + np_out_pt, np_out_jax, rtol=1e-3, atol=5e-3, err_msg=f"Outputs do not match! max_diff={max_diff}" + ) + + def test_fused_leaky_relu_shape(self): + rngs = nnx.Rngs(0) + x = jnp.ones((2, 4, 16, 16)) + model = FlaxFusedLeakyReLU(rngs=rngs, bias_channels=4) + out = model(x) + np.testing.assert_equal(out.shape, x.shape) + + def test_motion_linear_shape(self): + rngs = nnx.Rngs(0) + x = jnp.ones((2, 4)) + model = FlaxMotionLinear(rngs=rngs, in_dim=4, out_dim=8) + out = model(x) + np.testing.assert_equal(out.shape, (2, 8)) + + def test_motion_encoder_res_block_shape(self): + rngs = nnx.Rngs(0) + x = jnp.ones((2, 4, 16, 16)) + model = FlaxMotionEncoderResBlock(rngs=rngs, in_channels=4, out_channels=8) + out = model(x) + np.testing.assert_equal(out.shape, (2, 8, 8, 8)) + + def test_wan_animate_motion_encoder_shape(self): + rngs = nnx.Rngs(0) + x = jnp.ones((2, 3, 512, 512)) # size size + model = FlaxWanAnimateMotionEncoder(rngs=rngs, size=512, style_dim=512, motion_dim=20, out_dim=512) + out = model(x) + np.testing.assert_equal(out.shape, (2, 512)) + + def test_wan_animate_face_encoder_shape(self): + rngs = nnx.Rngs(0) + x = jnp.ones((2, 10, 512)) # Batch, Time, Dim + model = FlaxWanAnimateFaceEncoder(rngs=rngs, in_dim=512, out_dim=512, num_heads=4) + out = model(x) + np.testing.assert_equal(out.shape, (2, 3, 5, 512)) + + def test_wan_animate_face_block_cross_attention_shape(self): + rngs = nnx.Rngs(0) + hidden_states = jnp.ones((2, 10, 512)) # B, Q_len, Dim + encoder_hidden_states = jnp.ones((2, 1, 5, 512)) # B, T, N, Dim + model = FlaxWanAnimateFaceBlockCrossAttention(rngs=rngs, dim=512, heads=8) + out = model(hidden_states, encoder_hidden_states) + np.testing.assert_equal(out.shape, hidden_states.shape) + + def test_nnx_wan_animate_transformer_3d_model_shape(self): + rngs = nnx.Rngs(0) + pyconfig.initialize( + [ + None, + os.path.join(os.path.dirname(__file__), "..", "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + flash_block_sizes = get_flash_block_sizes(config) + mesh = Mesh(devices_array, config.mesh_axes) + + with mesh: + model = NNXWanAnimateTransformer3DModel( + rngs=rngs, + num_layers=1, + num_attention_heads=4, + attention_head_dim=32, + in_channels=16, + latent_channels=4, + out_channels=16, + image_dim=512, + patch_size=(1, 2, 2), + flash_min_seq_length=1, + scan_layers=False, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + inject_face_latents_blocks=1, + motion_encoder_size=128, + ) + batch_size = 1 + num_frames = 2 + height = 8 + width = 8 + hidden_states = jnp.ones((batch_size, 16, num_frames, height, width)) + timestep = jnp.ones((batch_size,)) + encoder_hidden_states = jnp.ones((batch_size, 10, 4096)) + pose_hidden_states = jnp.ones((batch_size, 4, num_frames - 1, height, width)) + face_pixel_values = jnp.ones((batch_size, 3, num_frames, 128, 128)) + + out = model( + hidden_states, + timestep, + encoder_hidden_states, + pose_hidden_states=pose_hidden_states, + face_pixel_values=face_pixel_values, + return_dict=False, + ) + if isinstance(out, (list, tuple)): + out = out[0] + np.testing.assert_equal(out.shape, (batch_size, 16, num_frames, height, width)) + + def test_nnx_wan_animate_transformer_3d_model_shape_with_face(self): + rngs = nnx.Rngs(0) + pyconfig.initialize( + [ + None, + os.path.join(os.path.dirname(__file__), "..", "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + config = pyconfig.config + devices_array = create_device_mesh(config) + flash_block_sizes = get_flash_block_sizes(config) + mesh = Mesh(devices_array, config.mesh_axes) + + with mesh: + model = NNXWanAnimateTransformer3DModel( + rngs=rngs, + num_layers=1, + num_attention_heads=4, + attention_head_dim=32, + in_channels=16, + latent_channels=4, + out_channels=16, + image_dim=512, + patch_size=(1, 2, 2), + flash_min_seq_length=1, + scan_layers=False, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + inject_face_latents_blocks=1, + motion_encoder_size=128, + ) + batch_size = 1 + num_frames = 2 + height = 8 + width = 8 + hidden_states = jnp.ones((batch_size, 16, num_frames, height, width)) + timestep = jnp.ones((batch_size,)) + encoder_hidden_states = jnp.ones((batch_size, 10, 4096)) + pose_hidden_states = jnp.ones((batch_size, 4, num_frames - 1, height, width)) + face_pixel_values = jnp.ones((batch_size, 3, num_frames, 128, 128)) + + out = model( + hidden_states, + timestep, + encoder_hidden_states, + pose_hidden_states=pose_hidden_states, + face_pixel_values=face_pixel_values, + return_dict=False, + ) + if isinstance(out, (list, tuple)): + out = out[0] + np.testing.assert_equal(out.shape, (batch_size, 16, num_frames, height, width)) + + def test_equivalence_motion_encoder(self): + from diffusers.models.transformers.transformer_wan_animate import ( + WanAnimateMotionEncoder, + ) + + test_size = 128 + batch_size = 2 + + pt_model = WanAnimateMotionEncoder(size=test_size).eval() + rngs = nnx.Rngs(0) + jax_model = FlaxWanAnimateMotionEncoder(rngs, size=test_size) + + transfer_conv_weights(pt_model.conv_in, jax_model.conv_in) + + for pt_res, jax_res in zip(pt_model.res_blocks, jax_model.res_blocks): + transfer_conv_weights(pt_res.conv1, jax_res.conv1) + transfer_conv_weights(pt_res.conv2, jax_res.conv2) + if pt_res.conv_skip is not None: + transfer_conv_weights(pt_res.conv_skip, jax_res.conv_skip) + + transfer_conv_weights(pt_model.conv_out, jax_model.conv_out) + + for pt_lin, jax_lin in zip(pt_model.motion_network, jax_model.motion_network): + transfer_linear_weights(pt_lin, jax_lin) + + jax_model.motion_synthesis_weight[...] = jnp.array(pt_model.motion_synthesis_weight.detach().numpy()) + + dummy_input = torch.randn(batch_size, 3, test_size, test_size) + jax_input = jnp.array(dummy_input.numpy()) + + with torch.no_grad(): + pt_out = pt_model(dummy_input) + + jax_out = jax_model(jax_input) + + np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), rtol=1e-3, atol=5e-3) + + def test_equivalence_face_encoder(self): + from diffusers.models.transformers.transformer_wan_animate import ( + WanAnimateFaceEncoder, + ) + + in_dim = 512 + out_dim = 1024 + batch_size = 2 + seq_len = 10 + + pt_model = WanAnimateFaceEncoder(in_dim=in_dim, out_dim=out_dim).eval() + rngs = nnx.Rngs(0) + jax_model = FlaxWanAnimateFaceEncoder(rngs, in_dim=in_dim, out_dim=out_dim) + + # Transfer weights + jax_model.conv1_local.kernel[...] = jnp.array(pt_model.conv1_local.weight.detach().numpy().transpose(2, 1, 0)) + jax_model.conv1_local.bias[...] = jnp.array(pt_model.conv1_local.bias.detach().numpy()) + + jax_model.conv2.kernel[...] = jnp.array(pt_model.conv2.weight.detach().numpy().transpose(2, 1, 0)) + jax_model.conv2.bias[...] = jnp.array(pt_model.conv2.bias.detach().numpy()) + + jax_model.conv3.kernel[...] = jnp.array(pt_model.conv3.weight.detach().numpy().transpose(2, 1, 0)) + jax_model.conv3.bias[...] = jnp.array(pt_model.conv3.bias.detach().numpy()) + + jax_model.out_proj.kernel[...] = jnp.array(pt_model.out_proj.weight.detach().numpy().T) + jax_model.out_proj.bias[...] = jnp.array(pt_model.out_proj.bias.detach().numpy()) + + jax_model.padding_tokens[...] = jnp.array(pt_model.padding_tokens.detach().numpy()) + + dummy_input = torch.randn(batch_size, seq_len, in_dim) + jax_input = jnp.array(dummy_input.numpy()) + + with torch.no_grad(): + pt_out = pt_model(dummy_input) + + jax_out = jax_model(jax_input) + + np.testing.assert_allclose( + pt_out.numpy(), + np.array(jax_out), + rtol=1e-3, + atol=5e-3, # Slightly higher tolerance for convolutions + ) + + def test_equivalence_face_block_cross_attention(self): + from diffusers.models.transformers.transformer_wan_animate import ( + WanAnimateFaceBlockCrossAttention, + ) + + dim = 512 + heads = 8 + dim_head = 64 + batch_size = 2 + seq_len_q = 10 + seq_len_kv_T = 2 + seq_len_kv_N = 5 + + pt_model = WanAnimateFaceBlockCrossAttention( + dim=dim, + heads=heads, + dim_head=dim_head, + cross_attention_dim_head=dim_head, # Ensure cross attention + ).eval() + rngs = nnx.Rngs(0) + jax_model = FlaxWanAnimateFaceBlockCrossAttention( + rngs=rngs, + dim=dim, + heads=heads, + dim_head=dim_head, + cross_attention_dim_head=dim_head, + ) + + # Transfer weights + transfer_linear_weights(pt_model.to_q, jax_model.to_q) + transfer_linear_weights(pt_model.to_k, jax_model.to_k) + transfer_linear_weights(pt_model.to_v, jax_model.to_v) + transfer_linear_weights(pt_model.to_out, jax_model.to_out) + + jax_model.norm_q.scale[...] = jnp.array(pt_model.norm_q.weight.detach().numpy()) + jax_model.norm_k.scale[...] = jnp.array(pt_model.norm_k.weight.detach().numpy()) + + # Inputs + hidden_states_pt = torch.randn(batch_size, seq_len_q, dim) + # encoder_hidden_states shape: B, T, N, C + encoder_hidden_states_pt = torch.randn(batch_size, seq_len_kv_T, seq_len_kv_N, dim) + + hidden_states_jax = jnp.array(hidden_states_pt.numpy()) + encoder_hidden_states_jax = jnp.array(encoder_hidden_states_pt.numpy()) + + with torch.no_grad(): + pt_out = pt_model(hidden_states_pt, encoder_hidden_states_pt) + + jax_out = jax_model(hidden_states_jax, encoder_hidden_states_jax) + + np.testing.assert_allclose(pt_out.numpy(), np.array(jax_out), rtol=1e-3, atol=5e-3) + + def test_equivalence_wan_animate_transformer(self): + from diffusers.models.transformers.transformer_wan_animate import ( + WanAnimateTransformer3DModel, + ) + + config = { + "patch_size": (1, 2, 2), + "num_attention_heads": 4, + "attention_head_dim": 32, + "in_channels": 12, + "latent_channels": 4, + "out_channels": 4, + "text_dim": 64, + "freq_dim": 32, + "ffn_dim": 64, + "num_layers": 1, + "image_dim": 64, + "motion_encoder_size": 128, + "motion_style_dim": 128, + "motion_dim": 20, + "motion_encoder_dim": 128, + "face_encoder_hidden_dim": 128, + "face_encoder_num_heads": 4, + "inject_face_latents_blocks": 1, + } + + pt_model = WanAnimateTransformer3DModel(**config).eval() + rngs = nnx.Rngs(0) + + pyconfig.initialize( + [ + None, + os.path.join(os.path.dirname(__file__), "..", "..", "configs", "base_wan_14b.yml"), + ], + unittest=True, + ) + pyconfig_config = pyconfig.config + devices_array = create_device_mesh(pyconfig_config) + flash_block_sizes = get_flash_block_sizes(pyconfig_config) + mesh = Mesh(devices_array, pyconfig_config.mesh_axes) + + with mesh: + jax_model = NNXWanAnimateTransformer3DModel( + rngs=rngs, + patch_size=config["patch_size"], + num_attention_heads=config["num_attention_heads"], + attention_head_dim=config["attention_head_dim"], + in_channels=config["in_channels"], + latent_channels=config["latent_channels"], + out_channels=config["out_channels"], + text_dim=config["text_dim"], + freq_dim=config["freq_dim"], + ffn_dim=config["ffn_dim"], + num_layers=config["num_layers"], + image_dim=config["image_dim"], + motion_encoder_size=config["motion_encoder_size"], + motion_style_dim=config["motion_style_dim"], + motion_dim=config["motion_dim"], + motion_encoder_dim=config["motion_encoder_dim"], + face_encoder_hidden_dim=config["face_encoder_hidden_dim"], + face_encoder_num_heads=config["face_encoder_num_heads"], + inject_face_latents_blocks=config["inject_face_latents_blocks"], + flash_min_seq_length=1, + scan_layers=False, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + + transfer_transformer_weights(pt_model, jax_model) + + batch_size = 1 + num_frames = 2 + height = 8 + width = 8 + + hidden_states_pt = torch.randn(batch_size, config["in_channels"], num_frames, height, width) + encoder_hidden_states_pt = torch.randn(batch_size, 10, config["text_dim"]) + pose_hidden_states_pt = torch.randn(batch_size, config["latent_channels"], num_frames - 1, height, width) + face_pixel_values_pt = torch.randn( + batch_size, + 3, + num_frames, + config["motion_encoder_size"], + config["motion_encoder_size"], + ) + + hidden_states_jax = jnp.array(hidden_states_pt.numpy()) + encoder_hidden_states_jax = jnp.array(encoder_hidden_states_pt.numpy()) + pose_hidden_states_jax = jnp.array(pose_hidden_states_pt.numpy()) + face_pixel_values_jax = jnp.array(face_pixel_values_pt.numpy()) + + timesteps = [1.0, 250.0, 500.0] + for ts_val in timesteps: + timestep_pt = torch.tensor([ts_val]) + timestep_jax = jnp.array(timestep_pt.numpy()) + + with torch.no_grad(): + pt_out = pt_model( + hidden_states_pt, + timestep=timestep_pt, + encoder_hidden_states=encoder_hidden_states_pt, + pose_hidden_states=pose_hidden_states_pt, + face_pixel_values=face_pixel_values_pt, + return_dict=False, + ) + if isinstance(pt_out, tuple): + pt_out = pt_out[0] + elif isinstance(pt_out, dict): + pt_out = pt_out["sample"] + + with mesh: + jax_out = jax_model( + hidden_states_jax, + timestep=timestep_jax, + encoder_hidden_states=encoder_hidden_states_jax, + pose_hidden_states=pose_hidden_states_jax, + face_pixel_values=face_pixel_values_jax, + return_dict=False, + ) + if isinstance(jax_out, (list, tuple)): + jax_out = jax_out[0] + elif isinstance(jax_out, dict): + jax_out = jax_out["sample"] + + np_pt = pt_out.detach().numpy() + np_jax = np.array(jax_out) + + np.testing.assert_equal(np_pt.shape, np_jax.shape) + np.testing.assert_allclose(np_pt, np_jax, rtol=1e-3, atol=5e-3)