15 changes: 15 additions & 0 deletions src/maxdiffusion/tests/legacy_hf_tests/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
@@ -31,14 +31,14 @@


def pytest_addoption(parser):
  from maxdiffusion.utils.testing_utils import pytest_addoption_shared

  pytest_addoption_shared(parser)


def pytest_terminal_summary(terminalreporter):
  from maxdiffusion.utils.testing_utils import pytest_terminal_summary_main

  make_reports = terminalreporter.config.getoption("--make-reports")
  if make_reports:
    pytest_terminal_summary_main(terminalreporter, id=make_reports)
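These are standard pytest hooks: pytest_addoption registers extra command-line flags, and pytest_terminal_summary runs at the end of the session to emit reports. As a rough sketch of what the shared helper presumably wires up, assuming only the --make-reports flag visible in the code above (the action, default, and help text are illustrative):

def pytest_addoption_shared(parser):
  # Hypothetical sketch; only the flag name comes from the conftest above.
  parser.addoption(
      "--make-reports",
      action="store",
      default=False,
      help="generate report files; the value passed is used as the report id",
  )

With this wiring, running pytest --make-reports=flax_tests would hand id='flax_tests' to pytest_terminal_summary_main.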
15 changes: 15 additions & 0 deletions src/maxdiffusion/tests/legacy_hf_tests/models/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
@@ -0,0 +1,83 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import inspect

from maxdiffusion.utils import is_flax_available
from maxdiffusion.utils.testing_utils import require_flax


if is_flax_available():
import jax


@require_flax
class FlaxModelTesterMixin:

def test_output(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
    variables = jax.lax.stop_gradient(variables)  # stop_gradient returns a new pytree; assign it so the call has an effect

output = model.apply(variables, inputs_dict["sample"])

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)

model = self.model_class(**init_dict)
variables = model.init(inputs_dict["prng_key"], inputs_dict["sample"])
    variables = jax.lax.stop_gradient(variables)  # stop_gradient returns a new pytree; assign it so the call has an effect

output = model.apply(variables, inputs_dict["sample"])

if isinstance(output, dict):
output = output.sample

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_deprecated_kwargs(self):
has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0

if has_kwarg_in_model_class and not has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} has `**kwargs` in its __init__ method but has not defined any deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either remove `**kwargs` if there are"
" no deprecated arguments or add the deprecated argument with `_deprecated_kwargs ="
" [<deprecated_argument>]`"
)

if not has_kwarg_in_model_class and has_deprecated_kwarg:
raise ValueError(
f"{self.model_class} doesn't have `**kwargs` in its __init__ method but has defined deprecated kwargs"
" under the `_deprecated_kwargs` class attribute. Make sure to either add the `**kwargs` argument to"
f" {self.model_class}.__init__ if there are deprecated arguments or remove the deprecated argument"
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
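To make the last check concrete, here is a hypothetical model class that would satisfy test_deprecated_kwargs: **kwargs in __init__ is paired with a non-empty _deprecated_kwargs list (the class and argument names are invented for illustration):

class HypotheticalFlaxModel:
  # Keyword arguments still accepted during a deprecation window must be
  # listed here, or test_deprecated_kwargs raises a ValueError for the class.
  _deprecated_kwargs = ["scale"]

  def __init__(self, in_channels=3, **kwargs):
    if "scale" in kwargs:
      kwargs.pop("scale")  # consume the deprecated argument explicitly
    self.in_channels = in_channels

A class with **kwargs but an empty _deprecated_kwargs, or the reverse, fails the test with the corresponding error message.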
@@ -0,0 +1,118 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import gc
import unittest

from maxdiffusion import FlaxUNet2DConditionModel
from maxdiffusion.utils import is_flax_available
from maxdiffusion.utils.testing_utils import load_hf_numpy, require_flax, slow
from parameterized import parameterized


if is_flax_available():
import jax
import jax.numpy as jnp


@slow
@require_flax
class FlaxUNet2DConditionModelIntegrationTests(unittest.TestCase):

def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

def tearDown(self):
    # clean up device memory after each test
super().tearDown()
gc.collect()

def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
image = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return image

def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
dtype = jnp.bfloat16 if fp16 else jnp.float32
revision = "bf16" if fp16 else None

model, params = FlaxUNet2DConditionModel.from_pretrained(model_id, subfolder="unet", dtype=dtype, revision=revision)
return model, params

def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = jnp.bfloat16 if fp16 else jnp.float32
hidden_states = jnp.array(load_hf_numpy(self.get_file_format(seed, shape)), dtype=dtype)
return hidden_states

@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
def test_compvis_sd_v1_4_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample

assert sample.shape == latents.shape

    output_slice = jnp.asarray(jax.device_get(sample[-1, -2:, -2:, :2].flatten()), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

    # Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)

@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
def test_stabilityai_sd_v2_flax_vs_torch_fp16(self, seed, timestep, expected_slice):
model, params = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

sample = model.apply(
{"params": params},
latents,
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=encoder_hidden_states,
).sample

assert sample.shape == latents.shape

    output_slice = jnp.asarray(jax.device_get(sample[-1, -2:, -2:, :2].flatten()), dtype=jnp.float32)
expected_output_slice = jnp.array(expected_slice, dtype=jnp.float32)

# Found torch (float16) and flax (bfloat16) outputs to be within this tolerance, on the same hardware
assert jnp.allclose(output_slice, expected_output_slice, atol=1e-2)
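For reference, the fixture filenames fetched through load_hf_numpy are fully determined by get_file_format; two sample values derived from the parameters used in the tests above:

# Standalone copy of get_file_format from the class above, with worked examples.
def get_file_format(seed, shape):
  return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

assert get_file_format(83, (4, 4, 64, 64)) == "gaussian_noise_s=83_shape=4_4_64_64.npy"
assert get_file_format(17, (4, 77, 1024)) == "gaussian_noise_s=17_shape=4_77_1024.npy"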
@@ -0,0 +1,55 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import unittest

from maxdiffusion import FlaxAutoencoderKL
from maxdiffusion.utils import is_flax_available
from maxdiffusion.utils.testing_utils import require_flax

from .test_modeling_common_flax import FlaxModelTesterMixin


if is_flax_available():
import jax


@require_flax
class FlaxAutoencoderKLTests(FlaxModelTesterMixin, unittest.TestCase):
model_class = FlaxAutoencoderKL

@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)

prng_key = jax.random.PRNGKey(0)
image = jax.random.uniform(prng_key, ((batch_size, num_channels) + sizes))

return {"sample": image, "prng_key": prng_key}

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
"latent_channels": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
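When the inherited test_forward_with_norm_groups runs against this class, it overrides two entries of the init args above; the effective configuration (assembled from prepare_init_args_and_inputs_for_common plus the mixin's overrides) is:

init_dict = {
    "block_out_channels": (16, 32),  # overridden by the mixin (was [32, 64])
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
    "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
    "latent_channels": 4,
    "norm_num_groups": 16,  # added by the mixin
}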
15 changes: 15 additions & 0 deletions src/maxdiffusion/tests/legacy_hf_tests/schedulers/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""