diff --git a/src/maxdiffusion/tests/legacy_hf_tests/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/tests/conftest.py b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py similarity index 79% rename from tests/conftest.py rename to src/maxdiffusion/tests/legacy_hf_tests/conftest.py index 42d0bac8..730f6f27 100644 --- a/tests/conftest.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py @@ -31,14 +31,14 @@ def pytest_addoption(parser): - from maxdiffusion.utils.testing_utils import pytest_addoption_shared + from maxdiffusion.utils.testing_utils import pytest_addoption_shared - pytest_addoption_shared(parser) + pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main + from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main - make_reports = terminalreporter.config.getoption("--make-reports") - if make_reports: - pytest_terminal_summary_main(terminalreporter, id=make_reports) + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py new file mode 100644 index 00000000..1caabdbf --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py @@ -0,0 +1,83 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +import inspect + +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + + +@require_flax +class FlaxModelTesterMixin: + + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) + + output = model.apply(variables, inputs_dict["sample"]) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py new file mode 100644 index 00000000..f514f708 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py @@ -0,0 +1,118 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import gc +import unittest + +from maxdiffusion import FlaxUNet2DConditionModel +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow +from parameterized import parameterized + + +if is_flax_available(): + import jax + import jax.numpy as jnp + + +@slow +@require_flax +class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): + + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + revision = "bf16" if fp16 else None + + model, params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", dtype=dtype, revision=revision) + return model, params + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off + [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], + [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], + [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], + [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], + # fmt: on + ] + ) + def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off + [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], + [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], + [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], + [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], + # fmt: on + ] + ) + def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + 
+ assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py new file mode 100644 index 00000000..295ed508 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py @@ -0,0 +1,55 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest + +from maxdiffusion import FlaxAutoencoderKL +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + +from .test_modeling_common_flax import FlaxModelTesterMixin + + +if is_flax_available(): + import jax + + +@require_flax +class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): + model_class = FlaxAutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + prng_key = jax.random.PRNGKey(0) + image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) + + return {"sample": image, "prng_key": prng_key} + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py new file mode 100644 index 00000000..e7c0b714 --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" diff --git a/tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_01.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_01.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_02.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_02.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_03.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_03.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_04.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_04.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy diff --git a/tests/schedulers/rf_scheduler_test_ref/step_05.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy similarity index 100% rename from tests/schedulers/rf_scheduler_test_ref/step_05.npy rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py new file mode 100644 index 00000000..45583a2f --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py @@ -0,0 +1,939 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import inspect +import tempfile +import unittest +from typing import Dict, List, Tuple + +from maxdiffusion import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler +from maxdiffusion.utils import is_flax_available +from maxdiffusion.utils.testing_utils import require_flax + + +if is_flax_available(): + import jax + import jax.numpy as jnp + from jax import random + + jax_device = jax.default_backend() + + +@require_flax +class FlaxSchedulerCommonTest(unittest.TestCase): + scheduler_classes = () + forward_default_kwargs = () + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + key1, key2 = random.split(random.PRNGKey(0)) + sample = random.uniform(key1, (batch_size, num_channels, height, width)) + + return sample, key2 + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = jnp.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + return jnp.transpose(sample, (3, 0, 1, 2)) + + def get_scheduler_config(self): + raise NotImplementedError + + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) + + return model + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, 
sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 + + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) + + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. 
Make sure to either add the `**kwargs`"
+            f" argument to {scheduler_class}.__init__ if there are deprecated arguments or remove the"
+            " deprecated argument from `_deprecated_kwargs = []`"
+        )
+
+
+@require_flax
+class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest):
+  scheduler_classes = (FlaxDDPMScheduler,)
+
+  def get_scheduler_config(self, **kwargs):
+    config = {
+        "num_train_timesteps": 1000,
+        "beta_start": 0.0001,
+        "beta_end": 0.02,
+        "beta_schedule": "linear",
+        "variance_type": "fixed_small",
+        "clip_sample": True,
+    }
+
+    config.update(**kwargs)
+    return config
+
+  def test_timesteps(self):
+    for timesteps in [1, 5, 100, 1000]:
+      self.check_over_configs(num_train_timesteps=timesteps)
+
+  def test_betas(self):
+    for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
+      self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+  def test_schedules(self):
+    for schedule in ["linear", "squaredcos_cap_v2"]:
+      self.check_over_configs(beta_schedule=schedule)
+
+  def test_variance_type(self):
+    for variance in ["fixed_small", "fixed_large", "other"]:
+      self.check_over_configs(variance_type=variance)
+
+  def test_clip_sample(self):
+    for clip_sample in [True, False]:
+      self.check_over_configs(clip_sample=clip_sample)
+
+  def test_time_indices(self):
+    for t in [0, 500, 999]:
+      self.check_over_forward(time_step=t)
+
+  def test_variance(self):
+    scheduler_class = self.scheduler_classes[0]
+    scheduler_config = self.get_scheduler_config()
+    scheduler = scheduler_class(**scheduler_config)
+    state = scheduler.create_state()
+
+    assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5
+    assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5
+    assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5
+
+  def test_full_loop_no_noise(self):
+    scheduler_class = self.scheduler_classes[0]
+    scheduler_config = self.get_scheduler_config()
+    scheduler = scheduler_class(**scheduler_config)
+    state = scheduler.create_state()
+
+    num_trained_timesteps = len(scheduler)
+
+    model = self.dummy_model()
+    sample = self.dummy_sample_deter
+    key1, key2 = random.split(random.PRNGKey(0))
+
+    for t in reversed(range(num_trained_timesteps)):
+      # 1. predict noise residual
+      residual = model(sample, t)
+
+      # 2.
predict previous mean of sample x_t-1 + output = scheduler.step(state, residual, t, sample, key1) + pred_prev_sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 257.28717) < 1.5e-2 + assert abs(result_mean - 0.33500) < 2e-5 + else: + assert abs(result_sum - 257.33148) < 1e-2 + assert abs(result_mean - 0.335057) < 1e-3 + + +@require_flax +class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxDDIMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + key1, key2 = random.split(random.PRNGKey(0)) + + num_inference_steps = 10 + + model = self.dummy_model() + sample = self.dummy_sample_deter + + state = scheduler.set_timesteps(state, num_inference_steps) + + for t in state.timesteps: + residual = model(sample, t) + output = scheduler.step(state, residual, t, sample) + sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + return sample + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, 
"set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() + + def test_steps_trailing(self): + self.check_over_configs(timestep_spacing="trailing") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() + + def test_steps_leading(self): + self.check_over_configs(timestep_spacing="leading") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="leading") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 200, 0])).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 
0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 172.0067) < 1e-2 + assert abs(result_mean - 0.223967) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 149.82944) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + else: + assert abs(result_sum - 149.8295) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + pass + # FIXME: both result_sum and result_mean are nan on TPU + # assert jnp.isnan(result_sum) + # assert jnp.isnan(result_mean) + else: + assert abs(result_sum - 149.0784) < 1e-2 + assert abs(result_mean - 0.1941) < 1e-3 + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + +@require_flax +class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 
num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + state = state.replace(ets=dummy_past_residuals[:]) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + new_state = new_state.replace(ets=dummy_past_residuals[:]) + + (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + pass + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + + # copy over dummy past residual (must be after setting timesteps) + new_state.replace(ets=dummy_past_residuals[:]) + + output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + for i, t in enumerate(state.prk_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_prk(state, residual, t, sample) + + for i, t in enumerate(state.plms_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_plms(state, residual, t, sample) + 
+ return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + state = state.replace(ets=dummy_past_residuals[:]) + + output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 10, shape=()) + assert jnp.equal( + state.timesteps, + jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(state.prk_timesteps[:2]): + sample, state = scheduler.step_prk(state, residual, t, sample) + + def test_inference_plms_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = 
scheduler_class(**scheduler_config) + state = scheduler.create_state() + + scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 198.1275) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + else: + assert abs(result_sum - 198.1318) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9466) < 1e-2 + assert abs(result_mean - 0.24342) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9482) < 1e-2 + assert abs(result_mean - 0.2434) < 1e-3 diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py new file mode 100644 index 00000000..821adcfe --- /dev/null +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py @@ -0,0 +1,104 @@ +""" +Copyright 2025 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import jax.numpy as jnp +from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler +import os +from maxdiffusion import max_logging +import torch +import unittest +from absl.testing import absltest +import numpy as np + + +class rfTest(unittest.TestCase): + + def test_rf_steps(self): + # --- Simulation Parameters --- + latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) + inference_steps_count = 5 # Number of steps for the denoising process + + # --- Run the Simulation --- + max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") + + seed = 42 + device = "cpu" + max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") + + generator = torch.Generator(device=device).manual_seed(seed) + + # 1. Instantiate the scheduler + config = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": None, + "base_resolution": None, + "sampler": "LinearQuadratic", + } + flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) + + # 2. 
Create and set initial state for the scheduler
+    flax_state = flax_scheduler.create_state()
+    flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape)
+    max_logging.log("\nScheduler initialized.")
+    max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}")
+
+    # 3. Prepare the initial noisy latent sample
+    # In a real scenario, this would typically be pure random noise (e.g., N(0,1))
+    # For simulation, we'll generate it.
+
+    sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
+    max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}")
+
+    # 4. Simulate the denoising loop
+    max_logging.log("\nStarting denoising loop:")
+    for i, t in enumerate(flax_state.timesteps):
+      max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}")
+
+      # Simulate model_output (e.g., noise prediction from a UNet)
+      model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
+
+      # Call the scheduler's step function
+      scheduler_output = flax_scheduler.step(
+          state=flax_state,
+          model_output=model_output,
+          timestep=t,  # Pass the current timestep from the scheduler's sequence
+          sample=sample,
+          return_dict=True,  # Return a SchedulerOutput dataclass
+      )
+
+      sample = scheduler_output.prev_sample  # Update the sample for the next step
+      flax_state = scheduler_output.state  # Update the state for the next step
+
+      # Compare with pytorch implementation
+      base_dir = os.path.dirname(__file__)
+      ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref")
+      ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy")
+      if os.path.exists(ref_filename):
+        pt_sample = np.load(ref_filename)
+        torch.testing.assert_close(np.array(sample), pt_sample)
+      else:
+        max_logging.log(f"Warning: Reference file not found: {ref_filename}")
+
+    max_logging.log("\nDenoising loop completed.")
+    max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}")
+    max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}")
+
+    max_logging.log("\nSimulation of RectifiedFlowMultistepScheduler usage complete.")
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py
new file mode 100644
index 00000000..657a5bb8
--- /dev/null
+++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py
@@ -0,0 +1,658 @@
+# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/tests/schedulers/test_scheduler_unipc.py + +import tempfile + +import torch +import jax.numpy as jnp +from typing import Dict, List, Tuple + +from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import ( + FlaxUniPCMultistepScheduler, +) +from maxdiffusion import FlaxDPMSolverMultistepScheduler + +from .test_scheduler_flax import FlaxSchedulerCommonTest + + +class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): + scheduler_classes = (FlaxUniPCMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + sample = torch.rand((batch_size, num_channels, height, width)) + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + jax_sample = jnp.asarray(sample) + return jax_sample + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "solver_order": 2, + "solver_type": "bh2", + "final_sigmas_type": "sigma_min", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + # copy over dummy past residuals + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + # Copy over dummy past residuals to new_state as well + new_state = new_state.replace(model_outputs=initial_model_outputs) + + output_sample, output_state = sample, state + new_output_sample, new_output_state = sample, new_state + # Need to iterate through the steps as UniPC maintains history over steps + # The loop for solver_order + 1 steps is crucial for UniPC's history logic. 
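+ # With the default solver_order=2 from get_scheduler_config, the loop below walks three consecutive timesteps (solver_order plus one) + # and, after each one, checks that the reloaded scheduler produces the same sample and carries the same state + # (step_index, sigmas and the model_outputs history) as the original.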
+ for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + t = output_state.timesteps[i] + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample = self.dummy_sample + residual = 0.1 * sample + dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # copy over dummy past residuals + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + + # Round-trip the config through save_config / from_pretrained so the reloaded scheduler can be compared against the original below + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"): + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + # Copy over dummy past residuals to new_state as well + new_state = new_state.replace(model_outputs=initial_model_outputs) + + output_sample, output_state = sample, state + new_output_sample, new_output_state = sample, new_state + + # Need to iterate through the steps as UniPC maintains history over steps + # The loop for solver_order + 1 steps is crucial for UniPC's history logic.
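+ # Same lockstep comparison as in check_over_configs; any remaining entries in kwargs are forwarded to step() on both schedulers below.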
+ for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + + t = output_state.timesteps[i] + + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + def full_loop(self, scheduler=None, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + if scheduler is None: + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + else: + state = scheduler.create_state() # Ensure state is fresh for the loop + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + + # scheduler.step in common test receives state, residual, t, sample + step_output = scheduler.step( + state=state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + sample = step_output.prev_sample + state = step_output.state # Update state for next iteration + + return sample + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = 
num_inference_steps + + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() # Create initial state + + sample = self.dummy_sample # Get sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # Copy over dummy past residuals (must be done after set_timesteps) + dummy_past_model_outputs = [ + 0.2 * sample, + 0.15 * sample, + 0.10 * sample, + ] + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + + time_step_0 = state.timesteps[5] + time_step_1 = state.timesteps[6] + + output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample + output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler_config = self.get_scheduler_config() + scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) + sample_1 = self.full_loop(scheduler=scheduler_1) + result_mean_1 = jnp.mean(jnp.abs(sample_1)) + + assert abs(result_mean_1.item() - 0.2464) < 1e-3 + + scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance + sample_2 = self.full_loop(scheduler=scheduler_2) + result_mean_2 = jnp.mean(jnp.abs(sample_2)) + + self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency + + assert abs(result_mean_2.item() - 0.2464) < 1e-3 + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + with self.assertRaises(NotImplementedError): + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_rescale_betas_zero_snr(self): + for rescale_zero_terminal_snr in [True, False]: + self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) + + def test_solver_order_and_type(self): + for solver_type in ["bh1", "bh2"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" + + def test_lower_order_final(self): + self.check_over_configs(lower_order_final=True) + self.check_over_configs(lower_order_final=False) + + def test_inference_steps(self): + for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: + self.check_over_forward(time_step=0, 
num_inference_steps=num_inference_steps) + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2464) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2925) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1966) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for prediction_type in ["epsilon", "sample", "v_prediction"]: + scheduler_config = self.get_scheduler_config( + thresholding=False, + dynamic_thresholding_ratio=0, + prediction_type=prediction_type, + solver_order=order, + solver_type=solver_type, + ) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.astype(jnp.bfloat16) + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + # sample is casted to fp32 inside step and output should be fp32. 
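+ # Note: this sweep over solver orders, solver types and prediction types only checks the dtype contract + # (bfloat16 inputs are expected to be promoted so the returned sample is float32); it does not compare numerical values.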
+ self.assertEqual(sample.dtype, jnp.float32) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" + assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + + +class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + torch_sample = torch.rand((batch_size, num_channels, width)) + jax_sample = jnp.asarray(torch_sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + jax_sample = jnp.asarray(sample) + return jax_sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2898) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", 
use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1944) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" + + def test_beta_sigmas(self): + # self.check_over_configs(use_beta_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_beta_sigmas=True) + + def test_exponential_sigmas(self): + # self.check_over_configs(use_exponential_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_exponential_sigmas=True) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ diff --git a/tests/models/__init__.py b/tests/models/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/models/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ diff --git a/tests/models/test_modeling_common_flax.py b/tests/models/test_modeling_common_flax.py deleted file mode 100644 index 0fa55dcf..00000000 --- a/tests/models/test_modeling_common_flax.py +++ /dev/null @@ -1,82 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import inspect - -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - - -if is_flax_available(): - import jax - - -@require_flax -class FlaxModelTesterMixin: - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) - - output = model.apply(variables, inputs_dict["sample"]) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["norm_num_groups"] = 16 - init_dict["block_out_channels"] = (16, 32) - - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) - - output = model.apply(variables, inputs_dict["sample"]) - - if isinstance(output, dict): - output = output.sample - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - def test_deprecated_kwargs(self): - has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters - has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" - " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. 
Make sure to either add the `**kwargs` argument to" - f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" - " from `_deprecated_kwargs = []`" - ) diff --git a/tests/models/test_models_unet_2d_flax.py b/tests/models/test_models_unet_2d_flax.py deleted file mode 100644 index ed1c8d39..00000000 --- a/tests/models/test_models_unet_2d_flax.py +++ /dev/null @@ -1,119 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import gc -import unittest - -from maxdiffusion import FlaxUNet2DConditionModel -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow -from parameterized import parameterized - - -if is_flax_available(): - import jax - import jax.numpy as jnp - - -@slow -@require_flax -class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): - def get_file_format(self, seed, shape): - return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return image - - def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - revision = "bf16" if fp16 else None - - model, params = FlaxUNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", dtype=dtype, revision=revision - ) - return model, params - - def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return hidden_states - - @parameterized.expand( - [ - # fmt: off - [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], - [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], - [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], - [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], - # fmt: on - ] - ) - def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within 
this tolerance, in the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) - - @parameterized.expand( - [ - # fmt: off - [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], - [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], - [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], - [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], - # fmt: on - ] - ) - def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) - latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/tests/models/test_models_vae_flax.py b/tests/models/test_models_vae_flax.py deleted file mode 100644 index b00bd6e9..00000000 --- a/tests/models/test_models_vae_flax.py +++ /dev/null @@ -1,55 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- """ - -import unittest - -from maxdiffusion import FlaxAutoencoderKL -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - -from .test_modeling_common_flax import FlaxModelTesterMixin - - -if is_flax_available(): - import jax - - -@require_flax -class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): - model_class = FlaxAutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - prng_key = jax.random.PRNGKey(0) - image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) - - return {"sample": image, "prng_key": prng_key} - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict diff --git a/tests/schedulers/__init__.py b/tests/schedulers/__init__.py deleted file mode 100644 index b392d39a..00000000 --- a/tests/schedulers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -""" - Copyright 2024 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py deleted file mode 100644 index eab5cb91..00000000 --- a/tests/schedulers/test_scheduler_flax.py +++ /dev/null @@ -1,939 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import inspect -import tempfile -import unittest -from typing import Dict, List, Tuple - -from maxdiffusion import FlaxDDIMScheduler, FlaxDDPMScheduler, FlaxPNDMScheduler -from maxdiffusion.utils import is_flax_available -from maxdiffusion.utils.testing_utils import require_flax - - -if is_flax_available(): - import jax - import jax.numpy as jnp - from jax import random - - jax_device = jax.default_backend() - - -@require_flax -class FlaxSchedulerCommonTest(unittest.TestCase): - scheduler_classes = () - forward_default_kwargs = () - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - key1, key2 = random.split(random.PRNGKey(0)) - sample = random.uniform(key1, (batch_size, num_channels, height, width)) - - return sample, key2 - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = jnp.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - return jnp.transpose(sample, (3, 0, 1, 2)) - - def get_scheduler_config(self): - raise NotImplementedError - - def dummy_model(self): - def model(sample, t, *args): - return sample * t / (t + 1) - - return model - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, 
sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_deprecated_kwargs(self): - for scheduler_class in self.scheduler_classes: - has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters - has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" - " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. 
Make sure to either add the `**kwargs`" - f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" - " deprecated argument from `_deprecated_kwargs = []`" - ) - - -@require_flax -class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDPMScheduler,) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "variance_type": "fixed_small", - "clip_sample": True, - } - - config.update(**kwargs) - return config - - def test_timesteps(self): - for timesteps in [1, 5, 100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_variance_type(self): - for variance in ["fixed_small", "fixed_large", "other"]: - self.check_over_configs(variance_type=variance) - - def test_clip_sample(self): - for clip_sample in [True, False]: - self.check_over_configs(clip_sample=clip_sample) - - def test_time_indices(self): - for t in [0, 500, 999]: - self.check_over_forward(time_step=t) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_trained_timesteps = len(scheduler) - - model = self.dummy_model() - sample = self.dummy_sample_deter - key1, key2 = random.split(random.PRNGKey(0)) - - for t in reversed(range(num_trained_timesteps)): - # 1. predict noise residual - residual = model(sample, t) - - # 2. 
predict previous mean of sample x_t-1 - output = scheduler.step(state, residual, t, sample, key1) - pred_prev_sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - # if t > 0: - # noise = self.dummy_sample_deter - # variance = scheduler.get_variance(t) ** (0.5) * noise - # - # sample = pred_prev_sample + variance - sample = pred_prev_sample - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 257.28717) < 1.5e-2 - assert abs(result_mean - 0.33500) < 2e-5 - else: - assert abs(result_sum - 257.33148) < 1e-2 - assert abs(result_mean - 0.335057) < 1e-3 - - -@require_flax -class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDIMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - key1, key2 = random.split(random.PRNGKey(0)) - - num_inference_steps = 10 - - model = self.dummy_model() - sample = self.dummy_sample_deter - - state = scheduler.set_timesteps(state, num_inference_steps) - - for t in state.timesteps: - residual = model(sample, t) - output = scheduler.step(state, residual, t, sample) - sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - return sample - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, 
"set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 500, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() - - def test_steps_trailing(self): - self.check_over_configs(timestep_spacing="trailing") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() - - def test_steps_leading(self): - self.check_over_configs(timestep_spacing="leading") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="leading") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 200, 0])).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 
0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 10, 49]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum - 172.0067) < 1e-2 - assert abs(result_mean - 0.223967) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 149.82944) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - else: - assert abs(result_sum - 149.8295) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - pass - # FIXME: both result_sum and result_mean are nan on TPU - # assert jnp.isnan(result_sum) - # assert jnp.isnan(result_mean) - else: - assert abs(result_sum - 149.0784) < 1e-2 - assert abs(result_mean - 0.1941) < 1e-3 - - def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - -@require_flax -class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxPNDMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 
num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - state = state.replace(ets=dummy_past_residuals[:]) - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - new_state = new_state.replace(ets=dummy_past_residuals[:]) - - (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - pass - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residuals (must be after setting timesteps) - state = state.replace(ets=dummy_past_residuals[:]) - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - # copy over dummy past residuals - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residuals (must be after setting timesteps) - new_state = new_state.replace(ets=dummy_past_residuals[:]) - - output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - for i, t in enumerate(state.prk_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_prk(state, residual, t, sample) - - for i, t in enumerate(state.plms_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_plms(state, residual, t, sample) - 
- return sample - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - # copy over dummy past residuals (must be done after set_timesteps) - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - state = state.replace(ets=dummy_past_residuals[:]) - - output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 10, shape=()) - assert jnp.equal( - state.timesteps, - jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), - ).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 5, 10]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(num_inference_steps=num_inference_steps) - - def test_pow_of_3_inference_steps(self): - # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 - num_inference_steps = 27 - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # before power of 3 fix, would error on first step, so we only need to do two - for i, t in enumerate(state.prk_timesteps[:2]): - sample, state = scheduler.step_prk(state, residual, t, sample) - - def test_inference_plms_no_past_residuals(self): - with self.assertRaises(ValueError): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = 
scheduler_class(**scheduler_config) - state = scheduler.create_state() - - scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 198.1275) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - else: - assert abs(result_sum - 198.1318) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9466) < 1e-2 - assert abs(result_mean - 0.24342) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9482) < 1e-2 - assert abs(result_mean - 0.2434) < 1e-3 diff --git a/tests/schedulers/test_scheduler_rf.py b/tests/schedulers/test_scheduler_rf.py deleted file mode 100644 index 1d23880f..00000000 --- a/tests/schedulers/test_scheduler_rf.py +++ /dev/null @@ -1,99 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ -import jax.numpy as jnp -from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler -import os -from maxdiffusion import max_logging -import torch -import unittest -from absl.testing import absltest -import numpy as np - - - -class rfTest(unittest.TestCase): - - def test_rf_steps(self): - # --- Simulation Parameters --- - latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) - inference_steps_count = 5 # Number of steps for the denoising process - - # --- Run the Simulation --- - max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") - - seed = 42 - device = 'cpu' - max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") - - generator = torch.Generator(device=device).manual_seed(seed) - - # 1. Instantiate the scheduler - config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'} - flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) - - # 2. 
Create and set initial state for the scheduler - flax_state = flax_scheduler.create_state() - flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) - max_logging.log("\nScheduler initialized.") - max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") - - # 3. Prepare the initial noisy latent sample - # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) - # For simulation, we'll generate it. - - sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") - - # 4. Simulate the denoising loop - max_logging.log("\nStarting denoising loop:") - for i, t in enumerate(flax_state.timesteps): - max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") - - # Simulate model_output (e.g., noise prediction from a UNet) - model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - - # Call the scheduler's step function - scheduler_output = flax_scheduler.step( - state=flax_state, - model_output=model_output, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True # Return a SchedulerOutput dataclass - ) - - sample = scheduler_output.prev_sample # Update the sample for the next step - flax_state = scheduler_output.state # Update the state for the next step - - # Compare with pytorch implementation - base_dir = os.path.dirname(__file__) - ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") - ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") - if os.path.exists(ref_filename): - pt_sample = np.load(ref_filename) - torch.testing.assert_close(np.array(sample), pt_sample) - else: - max_logging.log(f"Warning: Reference file not found: {ref_filename}") - - - max_logging.log("\nDenoising loop completed.") - max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") - max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") - - max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py deleted file mode 100644 index d401f54f..00000000 --- a/tests/schedulers/test_scheduler_unipc.py +++ /dev/null @@ -1,680 +0,0 @@ -# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
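The rectified-flow test above validates each denoised sample against reference arrays saved from the PyTorch implementation. The sketch below isolates that comparison pattern, reusing only calls the test itself makes (os.path.join, np.load, torch.testing.assert_close); the helper name check_against_reference is illustrative and not part of the codebase.

import os

import numpy as np
import torch


def check_against_reference(sample, step_idx, ref_dir):
  # Compare a JAX sample (converted to NumPy) with a saved PyTorch reference.
  ref_path = os.path.join(ref_dir, f"step_{step_idx:02d}.npy")
  if not os.path.exists(ref_path):
    return False  # like the test above, treat a missing reference as a skip
  torch.testing.assert_close(np.array(sample), np.load(ref_path))
  return True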
- -# DISCLAIMER: check https://arxiv.org/abs/2302.04867 and https://github.com/wl-zhao/UniPC for more info -# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/tests/schedulers/test_scheduler_unipc.py - -import tempfile - -import torch -import jax.numpy as jnp -from typing import Dict, List, Tuple - -from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import ( - FlaxUniPCMultistepScheduler, -) -from maxdiffusion import FlaxDPMSolverMultistepScheduler - -from .test_scheduler_flax import FlaxSchedulerCommonTest - - -class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxUniPCMultistepScheduler,) - forward_default_kwargs = (("num_inference_steps", 25),) - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - sample = torch.rand((batch_size, num_channels, height, width)) - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "solver_order": 2, - "solver_type": "bh2", - "final_sigmas_type": "sigma_min", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # copy over dummy past residuals - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. 
- for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - t = output_state.timesteps[i] - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - - # copy over dummy past residuals - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - # What is this doing? - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"): - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. 
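The save_config / from_pretrained round trip a few lines above is the point of both check_over_* helpers: serialize the config, rebuild an equivalent scheduler plus a fresh state, then confirm the original and reloaded schedulers step identically. A condensed sketch of that pattern, using only the calls the tests above already make (the config values are illustrative):

import tempfile

from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler

scheduler = FlaxUniPCMultistepScheduler(num_train_timesteps=1000, solver_order=2)
state = scheduler.create_state()
state = scheduler.set_timesteps(state, 10, (4, 3, 8, 8))

with tempfile.TemporaryDirectory() as tmpdir:
  scheduler.save_config(tmpdir)  # only the config is written to disk
  new_scheduler, new_state = FlaxUniPCMultistepScheduler.from_pretrained(tmpdir)
  new_state = new_scheduler.set_timesteps(new_state, 10, (4, 3, 8, 8))

# Given the same model_output, timestep, and sample, scheduler.step(...) and
# new_scheduler.step(...) are expected to return numerically identical results.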
- for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - - t = output_state.timesteps[i] - - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - - def full_loop(self, scheduler=None, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - if scheduler is None: - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - else: - state = scheduler.create_state() # Ensure state is fresh for the loop - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - for i, t in enumerate(state.timesteps): - residual = model(sample, t) - - # scheduler.step in common test receives state, residual, t, sample - step_output = scheduler.step( - state=state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - sample = step_output.prev_sample - state = step_output.state # Update state for next iteration - - return sample - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = 
num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() # Create initial state - - sample = self.dummy_sample # Get sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif ( - num_inference_steps is not None - and not hasattr(scheduler, "set_timesteps") - ): - kwargs["num_inference_steps"] = num_inference_steps - - # Copy over dummy past residuals (must be done after set_timesteps) - dummy_past_model_outputs = [ - 0.2 * sample, - 0.15 * sample, - 0.10 * sample, - ] - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - time_step_0 = state.timesteps[5] - time_step_1 = state.timesteps[6] - - output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample - output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - sample = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler_config = self.get_scheduler_config() - scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) - sample_1 = self.full_loop(scheduler=scheduler_1) - result_mean_1 = jnp.mean(jnp.abs(sample_1)) - - assert abs(result_mean_1.item() - 0.2464) < 1e-3 - - scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance - sample_2 = self.full_loop(scheduler=scheduler_2) - result_mean_2 = jnp.mean(jnp.abs(sample_2)) - - self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency - - assert abs(result_mean_2.item() - 0.2464) < 1e-3 - - def test_timesteps(self): - for timesteps in [25, 50, 100, 999, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_thresholding(self): - self.check_over_configs(thresholding=False) - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for threshold in [0.5, 1.0, 2.0]: - for prediction_type in ["epsilon", "sample"]: - with self.assertRaises(NotImplementedError): - self.check_over_configs( - thresholding=True, - prediction_type=prediction_type, - sample_max_value=threshold, - solver_order=order, - solver_type=solver_type, - ) - - def test_prediction_type(self): - for prediction_type in ["epsilon", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - def test_rescale_betas_zero_snr(self): - for rescale_zero_terminal_snr in [True, False]: - self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) - - def test_solver_order_and_type(self): - for solver_type in ["bh1", "bh2"]: - for order in [1, 2, 3]: - for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - sample = self.full_loop( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" - - - def test_lower_order_final(self): - self.check_over_configs(lower_order_final=True) - self.check_over_configs(lower_order_final=False) - - def test_inference_steps(self): - for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: - self.check_over_forward(time_step 
= 0, num_inference_steps=num_inference_steps) - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2464) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2925) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1966) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_fp16_support(self): - scheduler_class = self.scheduler_classes[0] - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for prediction_type in ["epsilon", "sample", "v_prediction"]: - scheduler_config = self.get_scheduler_config( - thresholding=False, - dynamic_thresholding_ratio=0, - prediction_type=prediction_type, - solver_order=order, - solver_type=solver_type, - ) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter.astype(jnp.bfloat16) - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - for i, t in enumerate(state.timesteps): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - # sample is casted to fp32 inside step and output should be fp32. 
- self.assertEqual(sample.dtype, jnp.float32) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - t_start_index = 8 - - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) - - for i, t in enumerate(timesteps_for_noise): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" - assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" - - -class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - width = 8 - - torch_sample = torch.rand((batch_size, num_channels, width)) - jax_sample= jnp.asarray(torch_sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - jax_sample= jnp.asarray(sample) - return jax_sample - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) - scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) - - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2898) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", 
use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1944) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - t_start_index = 8 - - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) - - for i, t in enumerate(timesteps_for_noise): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" - assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" - - def test_beta_sigmas(self): - # self.check_over_configs(use_beta_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_beta_sigmas=True) - - def test_exponential_sigmas(self): - #self.check_over_configs(use_exponential_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_exponential_sigmas=True)
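test_full_loop_with_noise above starts sampling from a partially noised sample produced by scheduler.add_noise(state, sample, noise, timesteps). As a mental model (the Flax implementation may parameterize this through sigmas), this corresponds to the standard forward-noising relation x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, sketched below with shapes matching the 1D tests; the helper name is illustrative only.

import jax.numpy as jnp


def add_noise_reference(original_samples, noise, alpha_bar_t):
  # Standard DDPM-style forward noising:
  #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
  return jnp.sqrt(alpha_bar_t) * original_samples + jnp.sqrt(1.0 - alpha_bar_t) * noise


x0 = jnp.ones((4, 3, 8))    # shaped like dummy_sample_deter in the 1D tests
eps = jnp.zeros((4, 3, 8))  # zero "noise" keeps the example deterministic
x_t = add_noise_reference(x0, eps, alpha_bar_t=0.9)  # stays close to x0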