From c8ba90736a653d96af729978d03e1d60a27ebed5 Mon Sep 17 00:00:00 2001
From: michelle-yooh
Date: Thu, 29 Jan 2026 00:05:20 +0000
Subject: [PATCH 1/2] Move the top level 'tests/' into `src/maxdiffusion/tests/` as a legacy

---
 .../maxdiffusion/tests/legacy_hf_tests}/__init__.py | 0
 .../maxdiffusion/tests/legacy_hf_tests}/conftest.py | 0
 .../tests/legacy_hf_tests}/models/__init__.py       | 0
 .../models/test_modeling_common_flax.py             | 0
 .../models/test_models_unet_2d_flax.py              | 0
 .../legacy_hf_tests}/models/test_models_vae_flax.py | 0
 .../tests/legacy_hf_tests}/schedulers/__init__.py   | 0
 .../rf_scheduler_test_ref/step_00_noisy_input.npy   | Bin
 .../schedulers/rf_scheduler_test_ref/step_01.npy    | Bin
 .../schedulers/rf_scheduler_test_ref/step_02.npy    | Bin
 .../schedulers/rf_scheduler_test_ref/step_03.npy    | Bin
 .../schedulers/rf_scheduler_test_ref/step_04.npy    | Bin
 .../schedulers/rf_scheduler_test_ref/step_05.npy    | Bin
 .../schedulers/test_scheduler_flax.py               | 0
 .../schedulers/test_scheduler_rf.py                 | 0
 .../schedulers/test_scheduler_unipc.py              | 0
 16 files changed, 0 insertions(+), 0 deletions(-)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/__init__.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/conftest.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/models/__init__.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/models/test_modeling_common_flax.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/models/test_models_unet_2d_flax.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/models/test_models_vae_flax.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/__init__.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_01.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_02.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_03.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_04.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/rf_scheduler_test_ref/step_05.npy (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/test_scheduler_flax.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/test_scheduler_rf.py (100%)
 rename {tests => src/maxdiffusion/tests/legacy_hf_tests}/schedulers/test_scheduler_unipc.py (100%)

diff --git a/tests/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to src/maxdiffusion/tests/legacy_hf_tests/__init__.py
diff --git a/tests/conftest.py b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py
similarity index 100%
rename from tests/conftest.py
rename to src/maxdiffusion/tests/legacy_hf_tests/conftest.py
diff --git a/tests/models/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py
similarity index 100%
rename from tests/models/__init__.py
rename to src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py
diff --git a/tests/models/test_modeling_common_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py
similarity index 100%
rename from tests/models/test_modeling_common_flax.py
rename to src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py
diff --git a/tests/models/test_models_unet_2d_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py
similarity index 100%
rename from tests/models/test_models_unet_2d_flax.py
rename to src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py
diff --git a/tests/models/test_models_vae_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py
similarity index 100%
rename from tests/models/test_models_vae_flax.py
rename to src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py
diff --git a/tests/schedulers/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py
similarity index 100%
rename from tests/schedulers/__init__.py
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_00_noisy_input.npy
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_01.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_01.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_01.npy
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_02.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_02.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_02.npy
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_03.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_03.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_03.npy
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_04.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_04.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_04.npy
diff --git a/tests/schedulers/rf_scheduler_test_ref/step_05.npy b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy
similarity index 100%
rename from tests/schedulers/rf_scheduler_test_ref/step_05.npy
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/rf_scheduler_test_ref/step_05.npy
diff --git a/tests/schedulers/test_scheduler_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py
similarity index 100%
rename from tests/schedulers/test_scheduler_flax.py
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py
diff --git a/tests/schedulers/test_scheduler_rf.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py
similarity index 100%
rename from tests/schedulers/test_scheduler_rf.py
rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py
diff --git a/tests/schedulers/test_scheduler_unipc.py
b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py similarity index 100% rename from tests/schedulers/test_scheduler_unipc.py rename to src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py From 1fc4284920febe760094be25bbaa874713560c20 Mon Sep 17 00:00:00 2001 From: michelle-yooh Date: Thu, 29 Jan 2026 00:15:52 +0000 Subject: [PATCH 2/2] Format by pyink --- .../tests/legacy_hf_tests/__init__.py | 22 +- .../tests/legacy_hf_tests/conftest.py | 12 +- .../tests/legacy_hf_tests/models/__init__.py | 22 +- .../models/test_modeling_common_flax.py | 107 +- .../models/test_models_unet_2d_flax.py | 181 +- .../models/test_models_vae_flax.py | 72 +- .../legacy_hf_tests/schedulers/__init__.py | 22 +- .../schedulers/test_scheduler_flax.py | 1638 ++++++++--------- .../schedulers/test_scheduler_rf.py | 163 +- .../schedulers/test_scheduler_unipc.py | 1230 ++++++------- 10 files changed, 1726 insertions(+), 1743 deletions(-) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/__init__.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/conftest.py b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py index 42d0bac8..730f6f27 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/conftest.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/conftest.py @@ -31,14 +31,14 @@ def pytest_addoption(parser): - from maxdiffusion.utils.testing_utils import pytest_addoption_shared + from maxdiffusion.utils.testing_utils import pytest_addoption_shared - pytest_addoption_shared(parser) + pytest_addoption_shared(parser) def pytest_terminal_summary(terminalreporter): - from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main + from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main - make_reports = terminalreporter.config.getoption("--make-reports") - if make_reports: - pytest_terminal_summary_main(terminalreporter, id=make_reports) + make_reports = terminalreporter.config.getoption("--make-reports") + if make_reports: + pytest_terminal_summary_main(terminalreporter, id=make_reports) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py index 0fa55dcf..1caabdbf 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_modeling_common_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import inspect @@ -21,62 +21,63 @@ if is_flax_available(): - import jax + import jax @require_flax class FlaxModelTesterMixin: - def test_output(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) + def test_output(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) - output = model.apply(variables, inputs_dict["sample"]) + output = model.apply(variables, inputs_dict["sample"]) - if isinstance(output, dict): - output = output.sample + if isinstance(output, dict): + output = output.sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + def test_forward_with_norm_groups(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - init_dict["norm_num_groups"] = 16 - init_dict["block_out_channels"] = (16, 32) + init_dict["norm_num_groups"] = 16 + init_dict["block_out_channels"] = (16, 32) - model = self.model_class(**init_dict) - variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) - jax.lax.stop_gradient(variables) + model = self.model_class(**init_dict) + variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"]) + jax.lax.stop_gradient(variables) - output = model.apply(variables, inputs_dict["sample"]) + output = model.apply(variables, inputs_dict["sample"]) - if isinstance(output, dict): - output = output.sample + if isinstance(output, dict): + output = output.sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + self.assertIsNotNone(output) + expected_shape = inputs_dict["sample"].shape + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - def test_deprecated_kwargs(self): - has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters - has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 + def test_deprecated_kwargs(self): + has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters + 
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0 - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" - " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are" + " no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" - " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" - f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" - " from `_deprecated_kwargs = []`" - ) + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs" + " under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to" + f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument" + " from `_deprecated_kwargs = []`" + ) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py index ed1c8d39..f514f708 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_unet_2d_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" import gc import unittest @@ -24,96 +24,95 @@ if is_flax_available(): - import jax - import jax.numpy as jnp + import jax + import jax.numpy as jnp @slow @require_flax class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase): - def get_file_format(self, seed, shape): - return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" - - def tearDown(self): - # clean up the VRAM after each test - super().tearDown() - gc.collect() - - def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return image - - def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - revision = "bf16" if fp16 else None - - model, params = FlaxUNet2DConditionModel.from_pretrained( - model_id, subfolder="unet", dtype=dtype, revision=revision - ) - return model, params - - def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): - dtype = jnp.bfloat16 if fp16 else jnp.float32 - hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) - return hidden_states - - @parameterized.expand( - [ - # fmt: off + + def get_file_format(self, seed, shape): + return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy" + + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + + def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return image + + def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + revision = "bf16" if fp16 else None + + model, params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", dtype=dtype, revision=revision) + return model, params + + def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False): + dtype = jnp.bfloat16 if fp16 else jnp.float32 + hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype) + return hidden_states + + @parameterized.expand( + [ + # fmt: off [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]], [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]], [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]], [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]], - # fmt: on - ] - ) - def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) - latents = self.get_latents(seed, fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) - - @parameterized.expand( - [ - # fmt: 
off + # fmt: on + ] + ) + def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True) + latents = self.get_latents(seed, fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, in the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + + @parameterized.expand( + [ + # fmt: off [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]], [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]], [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]], [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]], - # fmt: on - ] - ) - def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): - model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) - latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) - encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) - - sample = model.apply( - {"params": params}, - latents, - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=encoder_hidden_states, - ).sample - - assert sample.shape == latents.shape - - output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) - expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) - - # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware - assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) + # fmt: on + ] + ) + def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice): + model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True) + latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True) + encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True) + + sample = model.apply( + {"params": params}, + latents, + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=encoder_hidden_states, + ).sample + + assert sample.shape == latents.shape + + output_slice = jnp.asarray(jax.device_get((sample[-1, -2:, -2:, :2].flatten())), dtype=jnp.float32) + expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32) + + # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware + assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2) diff --git a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py index b00bd6e9..295ed508 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/models/test_models_vae_flax.py @@ -1,18 +1,18 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not 
use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import unittest @@ -24,32 +24,32 @@ if is_flax_available(): - import jax + import jax @require_flax class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase): - model_class = FlaxAutoencoderKL - - @property - def dummy_input(self): - batch_size = 4 - num_channels = 3 - sizes = (32, 32) - - prng_key = jax.random.PRNGKey(0) - image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) - - return {"sample": image, "prng_key": prng_key} - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + model_class = FlaxAutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + prng_key = jax.random.PRNGKey(0) + image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes)) + + return {"sample": image, "prng_key": prng_key} + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 4, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py index b392d39a..e7c0b714 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py @@ -1,15 +1,15 @@ """ - Copyright 2024 Google LLC +Copyright 2024 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py index eab5cb91..45583a2f 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_flax.py @@ -23,917 +23,917 @@ if is_flax_available(): - import jax - import jax.numpy as jnp - from jax import random + import jax + import jax.numpy as jnp + from jax import random - jax_device = jax.default_backend() + jax_device = jax.default_backend() @require_flax class FlaxSchedulerCommonTest(unittest.TestCase): - scheduler_classes = () - forward_default_kwargs = () + scheduler_classes = () + forward_default_kwargs = () - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 - key1, key2 = random.split(random.PRNGKey(0)) - sample = random.uniform(key1, (batch_size, num_channels, height, width)) + key1, key2 = random.split(random.PRNGKey(0)) + sample = random.uniform(key1, (batch_size, num_channels, height, width)) - return sample, key2 + return sample, key2 - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 - num_elems = batch_size * num_channels * height * width - sample = jnp.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - return jnp.transpose(sample, (3, 0, 1, 2)) + num_elems = batch_size * num_channels * height * width + sample = jnp.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + return jnp.transpose(sample, (3, 0, 1, 2)) - def get_scheduler_config(self): - raise NotImplementedError + def get_scheduler_config(self): + raise NotImplementedError - def dummy_model(self): - def model(sample, t, *args): - return sample * t / (t + 1) + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) - return model + return model - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + num_inference_steps = kwargs.pop("num_inference_steps", None) - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - with tempfile.TemporaryDirectory() as 
tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + num_inference_steps = kwargs.pop("num_inference_steps", None) - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif 
num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step(state, residual, time_step, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, key, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) - num_inference_steps = kwargs.pop("num_inference_steps", None) + for scheduler_class in self.scheduler_classes: + sample, key = self.dummy_sample + residual = 0.1 * sample - for scheduler_class in self.scheduler_classes: - sample, key = self.dummy_sample - residual = 0.1 * sample + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample + output = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, key, **kwargs).prev_sample - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + 
kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, key = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - num_inference_steps = kwargs.pop("num_inference_steps", None) + outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - sample, key = self.dummy_sample - residual = 0.1 * sample + outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - output_0 = scheduler.step(state, residual, 0, sample, key, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, key, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) + def test_deprecated_kwargs(self): + for scheduler_class in self.scheduler_classes: + has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters + 
has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) + if has_kwarg_in_model_class and not has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" + " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" + " []`" + ) - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, key = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, key, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, key, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_deprecated_kwargs(self): - for scheduler_class in self.scheduler_classes: - has_kwarg_in_model_class = "kwargs" in inspect.signature(scheduler_class.__init__).parameters - has_deprecated_kwarg = len(scheduler_class._deprecated_kwargs) > 0 - - if has_kwarg_in_model_class and not has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} has `**kwargs` in its __init__ method but has not defined any deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if" - " there are no deprecated arguments or add the deprecated argument with `_deprecated_kwargs =" - " []`" - ) - - if not has_kwarg_in_model_class and has_deprecated_kwarg: - raise ValueError( - f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" - " kwargs under the `_deprecated_kwargs` class attribute. 
Make sure to either add the `**kwargs`" - f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" - " deprecated argument from `_deprecated_kwargs = []`" - ) + if not has_kwarg_in_model_class and has_deprecated_kwarg: + raise ValueError( + f"{scheduler_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated" + " kwargs under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs`" + f" argument to {self.model_class}.__init__ if there are deprecated arguments or remove the" + " deprecated argument from `_deprecated_kwargs = []`" + ) @require_flax class FlaxDDPMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDPMScheduler,) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "variance_type": "fixed_small", - "clip_sample": True, - } - - config.update(**kwargs) - return config - - def test_timesteps(self): - for timesteps in [1, 5, 100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_variance_type(self): - for variance in ["fixed_small", "fixed_large", "other"]: - self.check_over_configs(variance_type=variance) - - def test_clip_sample(self): - for clip_sample in [True, False]: - self.check_over_configs(clip_sample=clip_sample) - - def test_time_indices(self): - for t in [0, 500, 999]: - self.check_over_forward(time_step=t) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_trained_timesteps = len(scheduler) - - model = self.dummy_model() - sample = self.dummy_sample_deter - key1, key2 = random.split(random.PRNGKey(0)) - - for t in reversed(range(num_trained_timesteps)): - # 1. predict noise residual - residual = model(sample, t) - - # 2. 
predict previous mean of sample x_t-1 - output = scheduler.step(state, residual, t, sample, key1) - pred_prev_sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) - - # if t > 0: - # noise = self.dummy_sample_deter - # variance = scheduler.get_variance(t) ** (0.5) * noise - # - # sample = pred_prev_sample + variance - sample = pred_prev_sample - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 257.28717) < 1.5e-2 - assert abs(result_mean - 0.33500) < 2e-5 - else: - assert abs(result_sum - 257.33148) < 1e-2 - assert abs(result_mean - 0.335057) < 1e-3 + scheduler_classes = (FlaxDDPMScheduler,) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "variance_type": "fixed_small", + "clip_sample": True, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [1, 5, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_variance_type(self): + for variance in ["fixed_small", "fixed_large", "other"]: + self.check_over_configs(variance_type=variance) + + def test_clip_sample(self): + for clip_sample in [True, False]: + self.check_over_configs(clip_sample=clip_sample) + + def test_time_indices(self): + for t in [0, 500, 999]: + self.check_over_forward(time_step=t) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_trained_timesteps = len(scheduler) + + model = self.dummy_model() + sample = self.dummy_sample_deter + key1, key2 = random.split(random.PRNGKey(0)) + + for t in reversed(range(num_trained_timesteps)): + # 1. predict noise residual + residual = model(sample, t) + + # 2. 
predict previous mean of sample x_t-1 + output = scheduler.step(state, residual, t, sample, key1) + pred_prev_sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) + + # if t > 0: + # noise = self.dummy_sample_deter + # variance = scheduler.get_variance(t) ** (0.5) * noise + # + # sample = pred_prev_sample + variance + sample = pred_prev_sample + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 257.28717) < 1.5e-2 + assert abs(result_mean - 0.33500) < 2e-5 + else: + assert abs(result_sum - 257.33148) < 1e-2 + assert abs(result_mean - 0.335057) < 1e-3 @require_flax class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxDDIMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) + scheduler_classes = (FlaxDDIMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + key1, key2 = random.split(random.PRNGKey(0)) + + num_inference_steps = 10 - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } + model = self.dummy_model() + sample = self.dummy_sample_deter - config.update(**kwargs) - return config + state = scheduler.set_timesteps(state, num_inference_steps) - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - key1, key2 = random.split(random.PRNGKey(0)) + for t in state.timesteps: + residual = model(sample, t) + output = scheduler.step(state, residual, t, sample) + sample = output.prev_sample + state = output.state + key1, key2 = random.split(key2) - num_inference_steps = 10 + return sample - model = self.dummy_model() - sample = self.dummy_sample_deter + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - for t in state.timesteps: - residual = model(sample, t) - output = scheduler.step(state, residual, t, sample) - sample = output.prev_sample - state = output.state - key1, key2 = random.split(key2) + output = scheduler.step(state, residual, time_step, sample, 
**kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - return sample + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + num_inference_steps = kwargs.pop("num_inference_steps", None) - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + num_inference_steps = kwargs.pop("num_inference_steps", None) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, 
num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - num_inference_steps = kwargs.pop("num_inference_steps", None) + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) - output = scheduler.step(state, residual, time_step, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, time_step, sample, **kwargs).prev_sample + num_inference_steps = kwargs.pop("num_inference_steps", None) - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) + sample, _ = self.dummy_sample + residual = 0.1 * sample - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. 
Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample - output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 500, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() - - def test_steps_trailing(self): - self.check_over_configs(timestep_spacing="trailing") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() - - def test_steps_leading(self): - self.check_over_configs(timestep_spacing="leading") - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(timestep_spacing="leading") - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 5) - assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 
200, 0])).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 10, 49]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): - self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) - - def test_variance(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 - assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_sum - 172.0067) < 1e-2 - assert abs(result_mean - 0.223967) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 149.82944) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - else: - assert abs(result_sum - 149.8295) < 1e-2 - assert abs(result_mean - 0.1951) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - pass - # FIXME: both result_sum and result_mean are nan on TPU - # assert jnp.isnan(result_sum) - # assert jnp.isnan(result_mean) - else: - assert abs(result_sum - 149.0784) < 1e-2 - assert abs(result_mean - 0.1941) < 1e-3 - - def test_prediction_type(self): - for prediction_type in ["epsilon", "sample", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + output_0 = scheduler.step(state, residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 500, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = 
self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([801, 601, 401, 201, 1])).all() + + def test_steps_trailing(self): + self.check_over_configs(timestep_spacing="trailing") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="trailing") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([999, 799, 599, 399, 199])).all() + + def test_steps_leading(self): + self.check_over_configs(timestep_spacing="leading") + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(timestep_spacing="leading") + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, 5) + assert jnp.equal(state.timesteps, jnp.array([800, 600, 400, 200, 0])).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 10, 49]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_variance(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 420, 400) - 0.14771)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 980, 960) - 0.32460)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 0, 0) - 0.0)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 487, 486) - 0.00979)) < 1e-5 + assert jnp.sum(jnp.abs(scheduler._get_variance(state, 999, 998) - 0.02)) < 1e-5 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum - 172.0067) < 1e-2 + assert abs(result_mean - 0.223967) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 149.82944) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + else: + assert abs(result_sum - 149.8295) < 1e-2 + assert abs(result_mean - 0.1951) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + pass + # FIXME: both result_sum and result_mean are nan on TPU + # assert jnp.isnan(result_sum) + # assert jnp.isnan(result_mean) + else: + assert abs(result_sum - 
149.0784) < 1e-2 + assert abs(result_mean - 0.1941) < 1e-3 + + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) @require_flax class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxPNDMScheduler,) - forward_default_kwargs = (("num_inference_steps", 50),) - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - state = state.replace(ets=dummy_past_residuals[:]) - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - # copy over dummy past residuals - new_state = new_state.replace(ets=dummy_past_residuals[:]) - - (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_from_save_pretrained(self): - pass - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
- ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - sample, _ = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample, _ = self.dummy_sample - residual = 0.1 * sample - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residuals (must be after setting timesteps) - scheduler.ets = dummy_past_residuals[:] - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - # copy over dummy past residuals - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - - # copy over dummy past residual (must be after setting timesteps) - new_state.replace(ets=dummy_past_residuals[:]) - - output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) - new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) - new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def full_loop(self, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter + scheduler_classes = (FlaxPNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + 
kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + state = state.replace(ets=dummy_past_residuals[:]) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) + # copy over dummy past residuals + new_state = new_state.replace(ets=dummy_past_residuals[:]) + + (prev_sample, state) = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + (new_prev_sample, new_state) = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(prev_sample - new_prev_sample)) < 1e-5, "Scheduler outputs are not identical" + + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) + + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_from_save_pretrained(self): + pass + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + sample, _ = self.dummy_sample + residual = 0.1 * sample + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - for i, t in enumerate(state.prk_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_prk(state, residual, t, sample) + # copy over dummy past residuals (must be after setting timesteps) + scheduler.ets = dummy_past_residuals[:] - for i, t in enumerate(state.plms_timesteps): - residual = model(sample, t) - sample, state = scheduler.step_plms(state, residual, t, sample) + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + # copy over dummy past residuals + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape) - return sample + # copy over dummy past residual (must be after setting timesteps) + new_state.replace(ets=dummy_past_residuals[:]) - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) + output, state = scheduler.step_prk(state, residual, time_step, sample, **kwargs) + new_output, new_state = new_scheduler.step_prk(new_state, residual, time_step, sample, **kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + output, _ = scheduler.step_plms(state, residual, time_step, sample, **kwargs) + new_output, _ = new_scheduler.step_plms(new_state, residual, time_step, sample, **kwargs) - sample, _ = self.dummy_sample - residual = 0.1 * sample + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - if num_inference_steps 
is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - # copy over dummy past residuals (must be done after set_timesteps) - dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) - state = state.replace(ets=dummy_past_residuals[:]) - - output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) - output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_timesteps(self): - for timesteps in [100, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_steps_offset(self): - for steps_offset in [0, 1]: - self.check_over_configs(steps_offset=steps_offset) - - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(steps_offset=1) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - state = scheduler.set_timesteps(state, 10, shape=()) - assert jnp.equal( - state.timesteps, - jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), - ).all() - - def test_betas(self): - for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): - self.check_over_configs(beta_start=beta_start, beta_end=beta_end) - - def test_schedules(self): - for schedule in ["linear", "squaredcos_cap_v2"]: - self.check_over_configs(beta_schedule=schedule) - - def test_time_indices(self): - for t in [1, 5, 10]: - self.check_over_forward(time_step=t) - - def test_inference_steps(self): - for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): - self.check_over_forward(num_inference_steps=num_inference_steps) - - def test_pow_of_3_inference_steps(self): - # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 - num_inference_steps = 27 - - for scheduler_class in self.scheduler_classes: - sample, _ = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) - - # before power of 3 fix, would error on first step, so we only need to do two - for i, t in enumerate(state.prk_timesteps[:2]): - sample, state = scheduler.step_prk(state, residual, t, sample) - - def test_inference_plms_no_past_residuals(self): - with self.assertRaises(ValueError): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 198.1275) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - else: - assert 
abs(result_sum - 198.1318) < 1e-2 - assert abs(result_mean - 0.2580) < 1e-3 - - def test_full_loop_with_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9466) < 1e-2 - assert abs(result_mean - 0.24342) < 1e-3 - - def test_full_loop_with_no_set_alpha_to_one(self): - # We specify different beta, so that the first alpha is 0.99 - sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) - - if jax_device == "tpu": - assert abs(result_sum - 186.83226) < 8e-2 - assert abs(result_mean - 0.24327) < 1e-3 - else: - assert abs(result_sum - 186.9482) < 1e-2 - assert abs(result_mean - 0.2434) < 1e-3 + def full_loop(self, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + for i, t in enumerate(state.prk_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_prk(state, residual, t, sample) + + for i, t in enumerate(state.plms_timesteps): + residual = model(sample, t) + sample, state = scheduler.step_plms(state, residual, t, sample) + + return sample + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + sample, _ = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = jnp.array([residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]) + state = state.replace(ets=dummy_past_residuals[:]) + + output_0, state = scheduler.step_prk(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_prk(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + output_0, state = scheduler.step_plms(state, residual, 0, sample, **kwargs) + output_1, state = scheduler.step_plms(state, residual, 1, sample, **kwargs) + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_steps_offset(self): + for steps_offset in [0, 1]: + self.check_over_configs(steps_offset=steps_offset) + + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(steps_offset=1) + scheduler = scheduler_class(**scheduler_config) 
+ state = scheduler.create_state() + state = scheduler.set_timesteps(state, 10, shape=()) + assert jnp.equal( + state.timesteps, + jnp.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]), + ).all() + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(num_inference_steps=num_inference_steps) + + def test_pow_of_3_inference_steps(self): + # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3 + num_inference_steps = 27 + + for scheduler_class in self.scheduler_classes: + sample, _ = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + state = scheduler.set_timesteps(state, num_inference_steps, shape=sample.shape) + + # before power of 3 fix, would error on first step, so we only need to do two + for i, t in enumerate(state.prk_timesteps[:2]): + sample, state = scheduler.step_prk(state, residual, t, sample) + + def test_inference_plms_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + scheduler.step_plms(state, self.dummy_sample, 1, self.dummy_sample).prev_sample + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 198.1275) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + else: + assert abs(result_sum - 198.1318) < 1e-2 + assert abs(result_mean - 0.2580) < 1e-3 + + def test_full_loop_with_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9466) < 1e-2 + assert abs(result_mean - 0.24342) < 1e-3 + + def test_full_loop_with_no_set_alpha_to_one(self): + # We specify different beta, so that the first alpha is 0.99 + sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01) + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + if jax_device == "tpu": + assert abs(result_sum - 186.83226) < 8e-2 + assert abs(result_mean - 0.24327) < 1e-3 + else: + assert abs(result_sum - 186.9482) < 1e-2 + assert abs(result_mean - 0.2434) < 1e-3 diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py index 1d23880f..821adcfe 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_rf.py @@ -1,18 +1,18 @@ """ - Copyright 2025 Google 
LLC +Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import jax.numpy as jnp from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler import os @@ -23,76 +23,81 @@ import numpy as np - class rfTest(unittest.TestCase): - def test_rf_steps(self): - # --- Simulation Parameters --- - latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) - inference_steps_count = 5 # Number of steps for the denoising process - - # --- Run the Simulation --- - max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") - - seed = 42 - device = 'cpu' - max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") - - generator = torch.Generator(device=device).manual_seed(seed) - - # 1. Instantiate the scheduler - config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'} - flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) - - # 2. Create and set initial state for the scheduler - flax_state = flax_scheduler.create_state() - flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) - max_logging.log("\nScheduler initialized.") - max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") - - # 3. Prepare the initial noisy latent sample - # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) - # For simulation, we'll generate it. - - sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") - - # 4. 
Simulate the denoising loop - max_logging.log("\nStarting denoising loop:") - for i, t in enumerate(flax_state.timesteps): - max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") - - # Simulate model_output (e.g., noise prediction from a UNet) - model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) - - # Call the scheduler's step function - scheduler_output = flax_scheduler.step( - state=flax_state, - model_output=model_output, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True # Return a SchedulerOutput dataclass - ) - - sample = scheduler_output.prev_sample # Update the sample for the next step - flax_state = scheduler_output.state # Update the state for the next step - - # Compare with pytorch implementation - base_dir = os.path.dirname(__file__) - ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") - ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") - if os.path.exists(ref_filename): - pt_sample = np.load(ref_filename) - torch.testing.assert_close(np.array(sample), pt_sample) - else: - max_logging.log(f"Warning: Reference file not found: {ref_filename}") - - - max_logging.log("\nDenoising loop completed.") - max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") - max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") - - max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") + def test_rf_steps(self): + # --- Simulation Parameters --- + latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) + inference_steps_count = 5 # Number of steps for the denoising process + + # --- Run the Simulation --- + max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") + + seed = 42 + device = "cpu" + max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") + + generator = torch.Generator(device=device).manual_seed(seed) + + # 1. Instantiate the scheduler + config = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": None, + "base_resolution": None, + "sampler": "LinearQuadratic", + } + flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) + + # 2. Create and set initial state for the scheduler + flax_state = flax_scheduler.create_state() + flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) + max_logging.log("\nScheduler initialized.") + max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") + + # 3. Prepare the initial noisy latent sample + # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) + # For simulation, we'll generate it. + + sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) + max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") + + # 4. 
Simulate the denoising loop + max_logging.log("\nStarting denoising loop:") + for i, t in enumerate(flax_state.timesteps): + max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") + + # Simulate model_output (e.g., noise prediction from a UNet) + model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) + + # Call the scheduler's step function + scheduler_output = flax_scheduler.step( + state=flax_state, + model_output=model_output, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + + sample = scheduler_output.prev_sample # Update the sample for the next step + flax_state = scheduler_output.state # Update the state for the next step + + # Compare with pytorch implementation + base_dir = os.path.dirname(__file__) + ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") + ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") + if os.path.exists(ref_filename): + pt_sample = np.load(ref_filename) + torch.testing.assert_close(np.array(sample), pt_sample) + else: + max_logging.log(f"Warning: Reference file not found: {ref_filename}") + + max_logging.log("\nDenoising loop completed.") + max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") + max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") + + max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") if __name__ == "__main__": diff --git a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py index d401f54f..657a5bb8 100644 --- a/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py +++ b/src/maxdiffusion/tests/legacy_hf_tests/schedulers/test_scheduler_unipc.py @@ -30,651 +30,629 @@ class FlaxUniPCMultistepSchedulerTest(FlaxSchedulerCommonTest): - scheduler_classes = (FlaxUniPCMultistepScheduler,) - forward_default_kwargs = (("num_inference_steps", 25),) - - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - sample = torch.rand((batch_size, num_channels, height, width)) - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - height = 8 - width = 8 - - num_elems = batch_size * num_channels * height * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, height, width, batch_size) - sample = sample / num_elems - sample = sample.permute(3, 0, 1, 2) - - jax_sample= jnp.asarray(sample) - return jax_sample - - def get_scheduler_config(self, **kwargs): - config = { - "num_train_timesteps": 1000, - "beta_start": 0.0001, - "beta_end": 0.02, - "beta_schedule": "linear", - "solver_order": 2, - "solver_type": "bh2", - "final_sigmas_type": "sigma_min", - } - - config.update(**kwargs) - return config - - def check_over_configs(self, time_step=0, **config): - kwargs = dict(self.forward_default_kwargs) - 
num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config(**config) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # copy over dummy past residuals - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. - for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - t = output_state.timesteps[i] - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - def check_over_forward(self, time_step=0, **forward_kwargs): - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - sample = self.dummy_sample - residual = 0.1 * sample - dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10] - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - state = scheduler.set_timesteps( - state, num_inference_steps, sample.shape - ) - - # copy over dummy past residuals - 
initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - # What is this doing? - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"): - new_state = new_scheduler.set_timesteps( - new_state, num_inference_steps, sample.shape - ) - # Copy over dummy past residuals to new_state as well - new_state = new_state.replace(model_outputs=initial_model_outputs) - - - output_sample, output_state = sample, state - new_output_sample, new_output_state = sample, new_state - - # Need to iterate through the steps as UniPC maintains history over steps - # The loop for solver_order + 1 steps is crucial for UniPC's history logic. - for i in range(time_step, time_step + scheduler.config.solver_order + 1): - # Ensure time_step + i is within the bounds of timesteps - if i >= len(output_state.timesteps): - break - - t = output_state.timesteps[i] - - step_output = scheduler.step( - state=output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - output_sample = step_output.prev_sample - output_state = step_output.state - - new_step_output = new_scheduler.step( - state=new_output_state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=new_output_sample, - return_dict=True, # Return a SchedulerOutput dataclass - **kwargs, - ) - new_output_sample = new_step_output.prev_sample - new_output_state = new_step_output.state - - self.assertTrue( - jnp.allclose(output_sample, new_output_sample, atol=1e-5), - "Scheduler outputs are not identical", - ) - # Also assert that states are identical - self.assertEqual(output_state.step_index, new_output_state.step_index) - self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) - self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) - # Comparing model_outputs (history) directly: - if output_state.model_outputs is not None and new_output_state.model_outputs is not None: - for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): - self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") - - - def full_loop(self, scheduler=None, **config): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config(**config) - if scheduler is None: - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - else: - state = scheduler.create_state() # Ensure state is fresh for the loop - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter + scheduler_classes = (FlaxUniPCMultistepScheduler,) + forward_default_kwargs = (("num_inference_steps", 25),) + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + sample = torch.rand((batch_size, num_channels, height, width)) + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems).flip(-1) + sample = 
sample.reshape(num_channels, height, width, batch_size)
+    sample = sample / num_elems
+    sample = sample.permute(3, 0, 1, 2)
+
+    jax_sample = jnp.asarray(sample)
+    return jax_sample
+
+  @property
+  def dummy_sample_deter(self):
+    batch_size = 4
+    num_channels = 3
+    height = 8
+    width = 8
+
+    num_elems = batch_size * num_channels * height * width
+    sample = torch.arange(num_elems)
+    sample = sample.reshape(num_channels, height, width, batch_size)
+    sample = sample / num_elems
+    sample = sample.permute(3, 0, 1, 2)
+
+    jax_sample = jnp.asarray(sample)
+    return jax_sample
+
+  def get_scheduler_config(self, **kwargs):
+    config = {
+        "num_train_timesteps": 1000,
+        "beta_start": 0.0001,
+        "beta_end": 0.02,
+        "beta_schedule": "linear",
+        "solver_order": 2,
+        "solver_type": "bh2",
+        "final_sigmas_type": "sigma_min",
+    }
+
+    config.update(**kwargs)
+    return config
+
+  def check_over_configs(self, time_step=0, **config):
+    kwargs = dict(self.forward_default_kwargs)
+    num_inference_steps = kwargs.pop("num_inference_steps", None)
+    sample = self.dummy_sample
+    residual = 0.1 * sample
+    dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10]
+    for scheduler_class in self.scheduler_classes:
+      scheduler_config = self.get_scheduler_config(**config)
+      scheduler = scheduler_class(**scheduler_config)
+      state = scheduler.create_state()
+
+      with tempfile.TemporaryDirectory() as tmpdirname:
+        scheduler.save_config(tmpdirname)
+        new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
+
+      state = scheduler.set_timesteps(state, num_inference_steps, sample.shape)
+      new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape)
+      # copy over dummy past residuals
+      initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order])
+      state = state.replace(model_outputs=initial_model_outputs)
+      # Copy over dummy past residuals to new_state as well
+      new_state = new_state.replace(model_outputs=initial_model_outputs)
+
+      output_sample, output_state = sample, state
+      new_output_sample, new_output_state = sample, new_state
+      # Need to iterate through the steps as UniPC maintains history over steps
+      # The loop for solver_order + 1 steps is crucial for UniPC's history logic.
+      for i in range(time_step, time_step + scheduler.config.solver_order + 1):
+        # Ensure time_step + i is within the bounds of timesteps
+        if i >= len(output_state.timesteps):
+          break
+        t = output_state.timesteps[i]
+        step_output = scheduler.step(
+            state=output_state,
+            model_output=residual,
+            timestep=t,  # Pass the current timestep from the scheduler's sequence
+            sample=output_sample,
+            return_dict=True,  # Return a SchedulerOutput dataclass
+        )
+        output_sample = step_output.prev_sample
+        output_state = step_output.state
+
+        new_step_output = new_scheduler.step(
+            state=new_output_state,
+            model_output=residual,
+            timestep=t,  # Pass the current timestep from the scheduler's sequence
+            sample=new_output_sample,
+            return_dict=True,  # Return a SchedulerOutput dataclass
+        )
+        new_output_sample = new_step_output.prev_sample
+        new_output_state = new_step_output.state
+
+      self.assertTrue(
+          jnp.allclose(output_sample, new_output_sample, atol=1e-5),
+          "Scheduler outputs are not identical",
+      )
+      # Also assert that states are identical
+      self.assertEqual(output_state.step_index, new_output_state.step_index)
+      self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps))
+      self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5))
+      # Comparing model_outputs (history) directly:
+      if output_state.model_outputs is not None and new_output_state.model_outputs is not None:
+        for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs):
+          self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical")
+
+  def check_over_forward(self, time_step=0, **forward_kwargs):
+    kwargs = dict(self.forward_default_kwargs)
+    num_inference_steps = kwargs.pop("num_inference_steps", None)
+    sample = self.dummy_sample
+    residual = 0.1 * sample
+    dummy_past_model_outputs = [residual + 0.2, residual + 0.15, residual + 0.10]
+
+    for scheduler_class in self.scheduler_classes:
+      scheduler_config = self.get_scheduler_config()
+      scheduler = scheduler_class(**scheduler_config)
+      state = scheduler.create_state()
+
+      state = scheduler.set_timesteps(state, num_inference_steps, sample.shape)
+
+      # copy over dummy past residuals
+      initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order])
+      state = state.replace(model_outputs=initial_model_outputs)
+
+      # Round-trip the config through save_config/from_pretrained so we can verify
+      # that a reloaded scheduler reproduces the original scheduler's outputs.
+      with tempfile.TemporaryDirectory() as tmpdirname:
+        scheduler.save_config(tmpdirname)
+        new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
+
+      if num_inference_steps is not None and hasattr(new_scheduler, "set_timesteps"):
+        new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape)
+      # Copy over dummy past residuals to new_state as well
+      new_state = new_state.replace(model_outputs=initial_model_outputs)
+
+      output_sample, output_state = sample, state
+      new_output_sample, new_output_state = sample, new_state
+
+      # Need to iterate through the steps as UniPC maintains history over steps
+      # The loop for solver_order + 1 steps is crucial for UniPC's history logic.
+ for i in range(time_step, time_step + scheduler.config.solver_order + 1): + # Ensure time_step + i is within the bounds of timesteps + if i >= len(output_state.timesteps): + break + + t = output_state.timesteps[i] + + step_output = scheduler.step( + state=output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + output_sample = step_output.prev_sample + output_state = step_output.state + + new_step_output = new_scheduler.step( + state=new_output_state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=new_output_sample, + return_dict=True, # Return a SchedulerOutput dataclass + **kwargs, + ) + new_output_sample = new_step_output.prev_sample + new_output_state = new_step_output.state + + self.assertTrue( + jnp.allclose(output_sample, new_output_sample, atol=1e-5), + "Scheduler outputs are not identical", + ) + # Also assert that states are identical + self.assertEqual(output_state.step_index, new_output_state.step_index) + self.assertTrue(jnp.allclose(output_state.timesteps, new_output_state.timesteps)) + self.assertTrue(jnp.allclose(output_state.sigmas, new_output_state.sigmas, atol=1e-5)) + # Comparing model_outputs (history) directly: + if output_state.model_outputs is not None and new_output_state.model_outputs is not None: + for out1, out2 in zip(output_state.model_outputs, new_output_state.model_outputs): + self.assertTrue(jnp.allclose(out1, out2, atol=1e-5), "Model outputs history not identical") + + def full_loop(self, scheduler=None, **config): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + if scheduler is None: + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + else: + state = scheduler.create_state() # Ensure state is fresh for the loop + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + for i, t in enumerate(state.timesteps): + residual = model(sample, t) + + # scheduler.step in common test receives state, residual, t, sample + step_output = scheduler.step( + state=state, + model_output=residual, + timestep=t, # Pass the current timestep from the scheduler's sequence + sample=sample, + return_dict=True, # Return a SchedulerOutput dataclass + ) + sample = step_output.prev_sample + state = step_output.state # Update state for next iteration + + return sample + + def test_from_save_pretrained(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = 
num_inference_steps - for i, t in enumerate(state.timesteps): - residual = model(sample, t) + output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample + new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - # scheduler.step in common test receives state, residual, t, sample - step_output = scheduler.step( - state=state, - model_output=residual, - timestep=t, # Pass the current timestep from the scheduler's sequence - sample=sample, - return_dict=True, # Return a SchedulerOutput dataclass - ) - sample = step_output.prev_sample - state = step_output.state # Update state for next iteration - - return sample - - def test_from_save_pretrained(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - sample = self.dummy_sample - residual = 0.1 * sample - - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - with tempfile.TemporaryDirectory() as tmpdirname: - scheduler.save_config(tmpdirname) - new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - output = scheduler.step(state, residual, 1, sample, **kwargs).prev_sample - new_output = new_scheduler.step(new_state, residual, 1, sample, **kwargs).prev_sample - - assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - - def test_step_shape(self): - kwargs = dict(self.forward_default_kwargs) - - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() # Create initial state - - sample = self.dummy_sample # Get sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif ( - num_inference_steps is not None - and not hasattr(scheduler, "set_timesteps") - ): - kwargs["num_inference_steps"] = num_inference_steps - - # Copy over dummy past residuals (must be done after set_timesteps) - dummy_past_model_outputs = [ - 0.2 * sample, - 0.15 * sample, - 0.10 * sample, - ] - initial_model_outputs = jnp.stack(dummy_past_model_outputs[ - : scheduler.config.solver_order - ]) - state = state.replace(model_outputs=initial_model_outputs) - - time_step_0 = state.timesteps[5] - time_step_1 = state.timesteps[6] - - output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample - output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample - - self.assertEqual(output_0.shape, sample.shape) - self.assertEqual(output_0.shape, output_1.shape) - - def test_scheduler_outputs_equivalence(self): - def set_nan_tensor_to_zero(t): - return t.at[t != t].set(0) - - def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, (List, Tuple)): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): - 
recursive_check(tuple_iterable_value, dict_iterable_value) - elif isinstance(tuple_object, Dict): - for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): - recursive_check(tuple_iterable_value, dict_iterable_value) - elif tuple_object is None: - return - else: - self.assertTrue( - jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), - msg=( - "Tuple and dict output are not equal. Difference:" - f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" - f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" - f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." - ), - ) - - kwargs = dict(self.forward_default_kwargs) - num_inference_steps = kwargs.pop("num_inference_steps", None) - - for scheduler_class in self.scheduler_classes: - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - - sample = self.dummy_sample - residual = 0.1 * sample - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - - if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): - state = scheduler.create_state() - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): - kwargs["num_inference_steps"] = num_inference_steps - - outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) - recursive_check(outputs_tuple[0], outputs_dict.prev_sample) - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler_config = self.get_scheduler_config() - scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) - sample_1 = self.full_loop(scheduler=scheduler_1) - result_mean_1 = jnp.mean(jnp.abs(sample_1)) - - assert abs(result_mean_1.item() - 0.2464) < 1e-3 - - scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance - sample_2 = self.full_loop(scheduler=scheduler_2) - result_mean_2 = jnp.mean(jnp.abs(sample_2)) - - self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency - - assert abs(result_mean_2.item() - 0.2464) < 1e-3 - - def test_timesteps(self): - for timesteps in [25, 50, 100, 999, 1000]: - self.check_over_configs(num_train_timesteps=timesteps) - - def test_thresholding(self): - self.check_over_configs(thresholding=False) - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for threshold in [0.5, 1.0, 2.0]: - for prediction_type in ["epsilon", "sample"]: - with self.assertRaises(NotImplementedError): - self.check_over_configs( - thresholding=True, - prediction_type=prediction_type, - sample_max_value=threshold, - solver_order=order, - solver_type=solver_type, - ) - - def test_prediction_type(self): - for prediction_type in ["epsilon", "v_prediction"]: - self.check_over_configs(prediction_type=prediction_type) - - def test_rescale_betas_zero_snr(self): - for rescale_zero_terminal_snr in [True, False]: - self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) - - 
def test_solver_order_and_type(self): - for solver_type in ["bh1", "bh2"]: - for order in [1, 2, 3]: - for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - sample = self.full_loop( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - ) - assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" - - - def test_lower_order_final(self): - self.check_over_configs(lower_order_final=True) - self.check_over_configs(lower_order_final=False) - - def test_inference_steps(self): - for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: - self.check_over_forward(time_step = 0, num_inference_steps=num_inference_steps) - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2464) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2925) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1966) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_fp16_support(self): - scheduler_class = self.scheduler_classes[0] - for order in [1, 2, 3]: - for solver_type in ["bh1", "bh2"]: - for prediction_type in ["epsilon", "sample", "v_prediction"]: - scheduler_config = self.get_scheduler_config( - thresholding=False, - dynamic_thresholding_ratio=0, - prediction_type=prediction_type, - solver_order=order, - solver_type=solver_type, - ) - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - model = self.dummy_model() - sample = self.dummy_sample_deter.astype(jnp.bfloat16) - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - for i, t in enumerate(state.timesteps): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state - # sample is casted to fp32 inside step and output should be fp32. 
- self.assertEqual(sample.dtype, jnp.float32) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = scheduler_class(**scheduler_config) - state = scheduler.create_state() + assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_step_shape(self): + kwargs = dict(self.forward_default_kwargs) + + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() # Create initial state - num_inference_steps = 10 - t_start_index = 8 + sample = self.dummy_sample # Get sample + residual = 0.1 * sample - model = self.dummy_model() - sample = self.dummy_sample_deter + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + # Copy over dummy past residuals (must be done after set_timesteps) + dummy_past_model_outputs = [ + 0.2 * sample, + 0.15 * sample, + 0.10 * sample, + ] + initial_model_outputs = jnp.stack(dummy_past_model_outputs[: scheduler.config.solver_order]) + state = state.replace(model_outputs=initial_model_outputs) + + time_step_0 = state.timesteps[5] + time_step_1 = state.timesteps[6] + + output_0 = scheduler.step(state, residual, time_step_0, sample).prev_sample + output_1 = scheduler.step(state, residual, time_step_1, sample).prev_sample + + self.assertEqual(output_0.shape, sample.shape) + self.assertEqual(output_0.shape, output_1.shape) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + return t.at[t != t].set(0) + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + jnp.allclose(set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {jnp.max(jnp.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {jnp.isnan(tuple_object).any()} and `inf`: {jnp.isinf(tuple_object)}. Dict has" + f" `nan`: {jnp.isnan(dict_object).any()} and `inf`: {jnp.isinf(dict_object)}." 
+ ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + outputs_dict = scheduler.step(state, residual, 0, sample, **kwargs) - for i, t in enumerate(timesteps_for_noise): + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + state = scheduler.create_state() + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(state, residual, 0, sample, return_dict=False, **kwargs) + recursive_check(outputs_tuple[0], outputs_dict.prev_sample) + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler_config = self.get_scheduler_config() + scheduler_1 = FlaxUniPCMultistepScheduler(**scheduler_config) + sample_1 = self.full_loop(scheduler=scheduler_1) + result_mean_1 = jnp.mean(jnp.abs(sample_1)) + + assert abs(result_mean_1.item() - 0.2464) < 1e-3 + + scheduler_2 = FlaxUniPCMultistepScheduler(**scheduler_config) # New instance + sample_2 = self.full_loop(scheduler=scheduler_2) + result_mean_2 = jnp.mean(jnp.abs(sample_2)) + + self.assertTrue(jnp.allclose(result_mean_1, result_mean_2, atol=1e-3)) # Check consistency + + assert abs(result_mean_2.item() - 0.2464) < 1e-3 + + def test_timesteps(self): + for timesteps in [25, 50, 100, 999, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_thresholding(self): + self.check_over_configs(thresholding=False) + for order in [1, 2, 3]: + for solver_type in ["bh1", "bh2"]: + for threshold in [0.5, 1.0, 2.0]: + for prediction_type in ["epsilon", "sample"]: + with self.assertRaises(NotImplementedError): + self.check_over_configs( + thresholding=True, + prediction_type=prediction_type, + sample_max_value=threshold, + solver_order=order, + solver_type=solver_type, + ) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_rescale_betas_zero_snr(self): + for rescale_zero_terminal_snr in [True, False]: + self.check_over_configs(rescale_zero_terminal_snr=rescale_zero_terminal_snr) + + def test_solver_order_and_type(self): + for solver_type in ["bh1", "bh2"]: + for order in [1, 2, 3]: + for prediction_type in ["epsilon", "sample"]: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + sample = self.full_loop( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + ) + assert not jnp.any(jnp.isnan(sample)), "Samples have nan numbers" + + def test_lower_order_final(self): + 
self.check_over_configs(lower_order_final=True)
+    self.check_over_configs(lower_order_final=False)
+
+  def test_inference_steps(self):
+    for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]:
+      self.check_over_forward(time_step=0, num_inference_steps=num_inference_steps)
+
+  def test_full_loop_no_noise(self):
+    sample = self.full_loop()
+    result_mean = jnp.mean(jnp.abs(sample))
+
+    assert abs(result_mean.item() - 0.2464) < 1e-3
+
+  def test_full_loop_with_karras(self):
+    # sample = self.full_loop(use_karras_sigmas=True)
+    # result_mean = jnp.mean(jnp.abs(sample))
+
+    # assert abs(result_mean.item() - 0.2925) < 1e-3
+    with self.assertRaises(NotImplementedError):
+      self.full_loop(use_karras_sigmas=True)
+
+  def test_full_loop_with_v_prediction(self):
+    sample = self.full_loop(prediction_type="v_prediction")
+    result_mean = jnp.mean(jnp.abs(sample))
+
+    assert abs(result_mean.item() - 0.1014) < 1e-3
+
+  def test_full_loop_with_karras_and_v_prediction(self):
+    # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+    # result_mean = jnp.mean(jnp.abs(sample))
+
+    # assert abs(result_mean.item() - 0.1966) < 1e-3
+    with self.assertRaises(NotImplementedError):
+      self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
+
+  def test_fp16_support(self):
+    scheduler_class = self.scheduler_classes[0]
+    for order in [1, 2, 3]:
+      for solver_type in ["bh1", "bh2"]:
+        for prediction_type in ["epsilon", "sample", "v_prediction"]:
+          scheduler_config = self.get_scheduler_config(
+              thresholding=False,
+              dynamic_thresholding_ratio=0,
+              prediction_type=prediction_type,
+              solver_order=order,
+              solver_type=solver_type,
+          )
+          scheduler = scheduler_class(**scheduler_config)
+          state = scheduler.create_state()
+
+          num_inference_steps = 10
+          model = self.dummy_model()
+          sample = self.dummy_sample_deter.astype(jnp.bfloat16)
+          state = scheduler.set_timesteps(state, num_inference_steps, sample.shape)
+
+          for i, t in enumerate(state.timesteps):
+            residual = model(sample, t)
+            step_output = scheduler.step(state, residual, t, sample)
+            sample = step_output.prev_sample
+            state = step_output.state
+          # sample is cast to fp32 inside step and output should be fp32.
+ self.assertEqual(sample.dtype, jnp.float32) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() - assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" - assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" + num_inference_steps = 10 + t_start_index = 8 + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) -class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): - @property - def dummy_sample(self): - batch_size = 4 - num_channels = 3 - width = 8 - - torch_sample = torch.rand((batch_size, num_channels, width)) - jax_sample= jnp.asarray(torch_sample) - return jax_sample - - @property - def dummy_noise_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems).flip(-1) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - - jax_sample= jnp.asarray(sample) - return jax_sample - - @property - def dummy_sample_deter(self): - batch_size = 4 - num_channels = 3 - width = 8 - - num_elems = batch_size * num_channels * width - sample = torch.arange(num_elems) - sample = sample.reshape(num_channels, width, batch_size) - sample = sample / num_elems - sample = sample.permute(2, 0, 1) - jax_sample= jnp.asarray(sample) - return jax_sample - - def test_switch(self): - # make sure that iterating over schedulers with same config names gives same results - # for defaults - scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) - scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) - - sample = self.full_loop(scheduler=scheduler) - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_no_noise(self): - sample = self.full_loop() - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.2441) < 1e-3 - - def test_full_loop_with_karras(self): - # sample = self.full_loop(use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.2898) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(use_karras_sigmas=True) - - - def test_full_loop_with_v_prediction(self): - sample = self.full_loop(prediction_type="v_prediction") - result_mean = jnp.mean(jnp.abs(sample)) - - assert abs(result_mean.item() - 0.1014) < 1e-3 - - def test_full_loop_with_karras_and_v_prediction(self): - # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - # result_mean = jnp.mean(jnp.abs(sample)) - - # assert abs(result_mean.item() - 0.1944) < 1e-3 - with self.assertRaises(NotImplementedError): - self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) - - def test_full_loop_with_noise(self): - scheduler_class = self.scheduler_classes[0] - scheduler_config = self.get_scheduler_config() - scheduler = 
scheduler_class(**scheduler_config) - state = scheduler.create_state() - - num_inference_steps = 10 - t_start_index = 8 - - model = self.dummy_model() - sample = self.dummy_sample_deter - state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) - - # add noise - noise = self.dummy_noise_deter - timesteps_for_noise = state.timesteps[t_start_index :] - sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) - for i, t in enumerate(timesteps_for_noise): - residual = model(sample, t) - step_output = scheduler.step(state, residual, t, sample) - sample = step_output.prev_sample - state = step_output.state + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) - result_sum = jnp.sum(jnp.abs(sample)) - result_mean = jnp.mean(jnp.abs(sample)) + assert abs(result_sum.item() - 315.5757) < 1e-2, f" expected result sum 315.5757, but get {result_sum}" + assert abs(result_mean.item() - 0.4109) < 1e-3, f" expected result mean 0.4109, but get {result_mean}" - assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" - assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" - def test_beta_sigmas(self): - # self.check_over_configs(use_beta_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_beta_sigmas=True) +class FlaxUniPCMultistepScheduler1DTest(FlaxUniPCMultistepSchedulerTest): - def test_exponential_sigmas(self): - #self.check_over_configs(use_exponential_sigmas=True) - with self.assertRaises(NotImplementedError): - self.full_loop(use_exponential_sigmas=True) + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + width = 8 + + torch_sample = torch.rand((batch_size, num_channels, width)) + jax_sample = jnp.asarray(torch_sample) + return jax_sample + + @property + def dummy_noise_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems).flip(-1) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + + jax_sample = jnp.asarray(sample) + return jax_sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + width = 8 + + num_elems = batch_size * num_channels * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, width, batch_size) + sample = sample / num_elems + sample = sample.permute(2, 0, 1) + jax_sample = jnp.asarray(sample) + return jax_sample + + def test_switch(self): + # make sure that iterating over schedulers with same config names gives same results + # for defaults + scheduler = FlaxUniPCMultistepScheduler(**self.get_scheduler_config()) + sample = self.full_loop(scheduler=scheduler) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + scheduler = FlaxDPMSolverMultistepScheduler.from_config(scheduler.config) + scheduler = FlaxUniPCMultistepScheduler.from_config(scheduler.config) + + sample = self.full_loop(scheduler=scheduler) + result_mean = 
jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_no_noise(self): + sample = self.full_loop() + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.2441) < 1e-3 + + def test_full_loop_with_karras(self): + # sample = self.full_loop(use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.2898) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(use_karras_sigmas=True) + + def test_full_loop_with_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction") + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_mean.item() - 0.1014) < 1e-3 + + def test_full_loop_with_karras_and_v_prediction(self): + # sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + # result_mean = jnp.mean(jnp.abs(sample)) + + # assert abs(result_mean.item() - 0.1944) < 1e-3 + with self.assertRaises(NotImplementedError): + self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + + def test_full_loop_with_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + state = scheduler.create_state() + + num_inference_steps = 10 + t_start_index = 8 + + model = self.dummy_model() + sample = self.dummy_sample_deter + state = scheduler.set_timesteps(state, num_inference_steps, sample.shape) + + # add noise + noise = self.dummy_noise_deter + timesteps_for_noise = state.timesteps[t_start_index:] + sample = scheduler.add_noise(state, sample, noise, timesteps_for_noise[:1]) + + for i, t in enumerate(timesteps_for_noise): + residual = model(sample, t) + step_output = scheduler.step(state, residual, t, sample) + sample = step_output.prev_sample + state = step_output.state + + result_sum = jnp.sum(jnp.abs(sample)) + result_mean = jnp.mean(jnp.abs(sample)) + + assert abs(result_sum.item() - 39.0870) < 1e-2, f" expected result sum 39.0870, but get {result_sum}" + assert abs(result_mean.item() - 0.4072) < 1e-3, f" expected result mean 0.4072, but get {result_mean}" + + def test_beta_sigmas(self): + # self.check_over_configs(use_beta_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_beta_sigmas=True) + + def test_exponential_sigmas(self): + # self.check_over_configs(use_exponential_sigmas=True) + with self.assertRaises(NotImplementedError): + self.full_loop(use_exponential_sigmas=True)