src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py
@@ -29,8 +29,8 @@
import jax.numpy as jnp
from jax.sharding import Mesh
from maxdiffusion import pyconfig, max_utils
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from maxdiffusion.video_processor import VideoProcessor
from maxdiffusion.checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1

import tensorflow as tf

@@ -80,7 +80,13 @@ def text_encode(pipeline, prompt: Union[str, List[str]]):
def vae_encode(video, rng, vae, vae_cache):
latent = vae.encode(video, feat_cache=vae_cache)
latent = latent.latent_dist.sample(rng)
return latent
latents = jnp.transpose(latent, (0, 4, 1, 2, 3))
latents_mean = jnp.array(vae.latents_mean).reshape(1, vae.z_dim, 1, 1, 1)
latents_std = jnp.array(vae.latents_std).reshape(1, vae.z_dim, 1, 1, 1)

# Apply normalization: (x - mean) / std
latents = (latents - latents_mean) / latents_std
return latents


def generate_dataset(config, pipeline):
@@ -121,7 +127,6 @@ def generate_dataset(config, pipeline):
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
with mesh:
latents = p_vae_encode(video=video, rng=new_rng)
latents = jnp.transpose(latents, (0, 4, 1, 2, 3))
encoder_hidden_states = text_encode(pipeline, text)
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
writer.write(create_example(latent, encoder_hidden_state))
@@ -138,8 +143,10 @@ def generate_dataset(config, pipeline):


def run(config):
pipeline = WanPipeline.from_pretrained(config, load_transformer=False)
checkpoint_loader = WanCheckpointer2_1(config=config)
pipeline, _, _ = checkpoint_loader.load_checkpoint()
# Don't need the transformer for preprocessing.
del pipeline.transformer
generate_dataset(config, pipeline)
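
For reference, a minimal, self-contained sketch of the per-channel standardization the new vae_encode performs. The shapes and statistics below are invented for illustration; in the pipeline they come from the Wan 2.1 VAE (z_dim, latents_mean, latents_std):

import jax.numpy as jnp

z_dim = 16
# Fake encoder output in (B, T, H, W, C) layout, standing in for
# vae.encode(...).latent_dist.sample(rng).
latent = jnp.ones((1, 4, 8, 8, z_dim))

# Move channels to axis 1: (B, C, T, H, W).
latents = jnp.transpose(latent, (0, 4, 1, 2, 3))

# Hypothetical per-channel statistics; the real values live on the VAE object.
latents_mean = jnp.linspace(-0.5, 0.5, z_dim).reshape(1, z_dim, 1, 1, 1)
latents_std = jnp.full((1, z_dim, 1, 1, 1), 1.2)

# Standardize each channel: (x - mean) / std.
latents = (latents - latents_mean) / latents_std
print(latents.shape)  # (1, 16, 4, 8, 8)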


105 changes: 105 additions & 0 deletions src/maxdiffusion/tests/data_processing_test.py
@@ -0,0 +1,105 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import pytest
import functools
import jax
import jax.numpy as jnp
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh
from .. import pyconfig
from ..max_utils import (
create_device_mesh,
)
import numpy as np
import unittest
from ..data_preprocessing.wan_txt2vid_data_preprocessing import vae_encode
from ..checkpointing.wan_checkpointer_2_1 import WanCheckpointer2_1
from ..utils import load_video
from ..video_processor import VideoProcessor
import flax

THIS_DIR = os.path.dirname(os.path.abspath(__file__))

CACHE_T = 2

IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"

flax.config.update("flax_always_shard_variable", False)


class DataProcessingTest(unittest.TestCase):

def setUp(self):
DataProcessingTest.dummy_data = {}
pyconfig.initialize(
[
None,
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
],
unittest=True,
)
config = pyconfig.config
self.config = config
devices_array = create_device_mesh(config)
self.mesh = Mesh(devices_array, config.mesh_axes)

@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
def test_wan_vae_encode_normalization(self):
"""Test wan vae encode function normalization"""
pyconfig.initialize(
[
None,
os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
],
unittest=True,
)
config = pyconfig.config
devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
checkpoint_loader = WanCheckpointer2_1(config=config)
pipeline, _, _ = checkpoint_loader.load_checkpoint()

vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)

video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
video = load_video(video_path)
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width)]
videos = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))

rng = jax.random.key(config.seed)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
latents = p_vae_encode(videos, rng=rng)
# 1. Verify Channel Count (Wan 2.1 requires 16)
self.assertEqual(latents.shape[1], 16, f"Expected 16 channels, got {latents.shape[1]}")

# 2. Verify Global Stats
# We expect mean near 0 and variance near 1.
# We use a loose threshold (0.2) since this is just one video.
global_mean = jnp.mean(latents)
global_var = jnp.var(latents)

self.assertLess(abs(global_mean), 0.2, f"Global mean {global_mean} is too far from 0")
self.assertAlmostEqual(global_var, 1.0, delta=0.2, msg=f"Global variance {global_var} is too far from 1.0")

# 3. Verify Channel-wise Range
# Ensure no channel is completely "dead" or "exploding"
channel_vars = jnp.var(latents, axis=(0, 2, 3, 4))
self.assertTrue(jnp.all(channel_vars > 0.1), "One or more channels have near-zero variance")
self.assertTrue(jnp.all(channel_vars < 5.0), "One or more channels have exploding variance")
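
For intuition, a tiny standalone sketch of the statistics this test asserts on, run against synthetic standard-normal latents instead of a real VAE encode (the shape is made up; a correctly normalized encode should look roughly like this):

import jax
import jax.numpy as jnp

# Synthetic (B, C, T, H, W) latents with 16 channels, drawn from a standard normal.
latents = jax.random.normal(jax.random.key(0), (1, 16, 4, 32, 32))

global_mean = jnp.mean(latents)                      # expected near 0
global_var = jnp.var(latents)                        # expected near 1
channel_vars = jnp.var(latents, axis=(0, 2, 3, 4))   # no dead or exploding channels

assert abs(global_mean) < 0.2
assert abs(global_var - 1.0) < 0.2
assert bool(jnp.all(channel_vars > 0.1)) and bool(jnp.all(channel_vars < 5.0))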