
Commit 35086ac

[core] support device type device_maps to work with offloading. (#12811)
* support device type device_maps to work with offloading.
* add tests.
* fix tests
* skip tests where it's not supported.
* empty
* up
* up
* fix allegro.
1 parent: e390646

10 files changed: 83 additions, 18 deletions


src/diffusers/pipelines/pipeline_utils.py

19 additions, 5 deletions

```diff
@@ -112,7 +112,7 @@
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)
 
-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
 
 logger = logging.get_logger(__name__)
 
@@ -468,8 +468,7 @@ def module_is_offloaded(module):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1188,7 +1187,7 @@ def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.devi
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)
 
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1312,7 +1311,7 @@ def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
         self.remove_all_hooks()
 
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2228,6 +2227,21 @@ def _maybe_raise_error_if_group_offload_active(
                 return True
         return False
 
+    def _is_pipeline_device_mapped(self):
+        # We support passing `device_map="cuda"`, for example. This is helpful in case
+        # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
+        # in limited VRAM environments because quantized models often initialize directly on the accelerator.
+        device_map = self.hf_device_map
+        is_device_type_map = False
+        if isinstance(device_map, str):
+            try:
+                torch.device(device_map)
+                is_device_type_map = True
+            except RuntimeError:
+                pass
+
+        return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
 
 
 class StableDiffusionMixin:
     r"""
```

tests/models/testing_utils/quantization.py

15 additions, 0 deletions

```diff
@@ -628,6 +628,21 @@ def test_bnb_training(self):
         """Test that quantized models can be used for training with adapters."""
         self._test_quantization_training(BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"])
 
+    @pytest.mark.parametrize(
+        "config_name",
+        list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
+        ids=list(BitsAndBytesConfigMixin.BNB_CONFIGS.keys()),
+    )
+    def test_cpu_device_map(self, config_name):
+        config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS[config_name]
+        model_quantized = self._create_quantized_model(config_kwargs, device_map="cpu")
+
+        assert hasattr(model_quantized, "hf_device_map"), "Model should have hf_device_map attribute"
+        assert model_quantized.hf_device_map is not None, "hf_device_map should not be None"
+        assert model_quantized.device == torch.device("cpu"), (
+            f"Model should be on CPU, but is on {model_quantized.device}"
+        )
 
 
 @is_quantization
 @is_quanto
```
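
At the usage level, the new test corresponds roughly to this hedged sketch (the model ID, the subfolder, and a bitsandbytes build that supports CPU placement are assumptions, not confirmed by the diff):

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Quantized models often initialize directly on the accelerator, which can
# OOM in limited-VRAM environments. An explicit device_map="cpu" keeps the
# quantized weights on CPU while still populating hf_device_map.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative model ID
    subfolder="transformer",
    quantization_config=quant_config,
    device_map="cpu",
)

assert model.hf_device_map is not None
assert model.device == torch.device("cpu")
```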

tests/pipelines/allegro/test_allegro.py

4 additions, 0 deletions

```diff
@@ -158,6 +158,10 @@ def test_save_load_local(self):
     def test_save_load_optional_components(self):
         pass
 
+    @unittest.skip("Decoding without tiling is not yet implemented")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
+
     def test_inference(self):
         device = "cpu"
```

tests/pipelines/kandinsky/test_kandinsky_combined.py

13 additions, 3 deletions

```diff
@@ -34,9 +34,7 @@
 
 class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyCombinedPipeline
-    params = [
-        "prompt",
-    ]
+    params = ["prompt"]
     batch_params = ["prompt", "negative_prompt"]
     required_optional_params = [
         "generator",
@@ -148,6 +146,10 @@ def test_float16_inference(self):
     def test_dict_tuple_outputs_equivalent(self):
         super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
 
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
+
 
 class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyImg2ImgCombinedPipeline
@@ -264,6 +266,10 @@ def test_dict_tuple_outputs_equivalent(self):
     def test_save_load_optional_components(self):
         super().test_save_load_optional_components(expected_max_difference=5e-4)
 
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
+
 
 class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyInpaintCombinedPipeline
@@ -384,3 +390,7 @@ def test_save_load_optional_components(self):
 
     def test_save_load_local(self):
         super().test_save_load_local(expected_max_difference=5e-3)
+
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
```

tests/pipelines/kandinsky2_2/test_kandinsky_combined.py

18 additions, 9 deletions

```diff
@@ -36,9 +36,7 @@
 
 class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyV22CombinedPipeline
-    params = [
-        "prompt",
-    ]
+    params = ["prompt"]
     batch_params = ["prompt", "negative_prompt"]
     required_optional_params = [
         "generator",
@@ -70,12 +68,7 @@ def get_dummy_components(self):
     def get_dummy_inputs(self, device, seed=0):
         prior_dummy = PriorDummies()
         inputs = prior_dummy.get_dummy_inputs(device=device, seed=seed)
-        inputs.update(
-            {
-                "height": 64,
-                "width": 64,
-            }
-        )
+        inputs.update({"height": 64, "width": 64})
         return inputs
 
     def test_kandinsky(self):
@@ -155,12 +148,18 @@ def test_save_load_local(self):
     def test_save_load_optional_components(self):
         super().test_save_load_optional_components(expected_max_difference=5e-3)
 
+    @unittest.skip("Test not supported.")
     def test_callback_inputs(self):
         pass
 
+    @unittest.skip("Test not supported.")
     def test_callback_cfg(self):
         pass
 
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
+
 
 class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyV22Img2ImgCombinedPipeline
@@ -279,12 +278,18 @@ def test_save_load_optional_components(self):
     def save_load_local(self):
         super().test_save_load_local(expected_max_difference=5e-3)
 
+    @unittest.skip("Test not supported.")
     def test_callback_inputs(self):
         pass
 
+    @unittest.skip("Test not supported.")
     def test_callback_cfg(self):
         pass
 
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
+
 
 class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = KandinskyV22InpaintCombinedPipeline
@@ -411,3 +416,7 @@ def test_callback_inputs(self):
 
     def test_callback_cfg(self):
         pass
+
+    @unittest.skip("`device_map` is not yet supported for connected pipelines.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
```

tests/pipelines/kandinsky2_2/test_kandinsky_inpaint.py

3 additions, 0 deletions

```diff
@@ -296,6 +296,9 @@ def callback_inputs_test(pipe, i, t, callback_kwargs):
         output = pipe(**inputs)[0]
         assert output.abs().sum() == 0
 
+    def test_pipeline_with_accelerator_device_map(self):
+        super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
+
 
 @slow
 @require_torch_accelerator
```

tests/pipelines/kandinsky3/test_kandinsky3_img2img.py

3 additions, 0 deletions

```diff
@@ -194,6 +194,9 @@ def test_inference_batch_single_identical(self):
     def test_save_load_dduf(self):
         super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
 
+    def test_pipeline_with_accelerator_device_map(self):
+        super().test_pipeline_with_accelerator_device_map(expected_max_difference=5e-3)
+
 
 @slow
 @require_torch_accelerator
```

tests/pipelines/test_pipelines_common.py

0 additions, 1 deletion

```diff
@@ -2355,7 +2355,6 @@ def test_torch_dtype_dict(self):
             f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
         )
 
-    @require_torch_accelerator
     def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
```
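
The test body continues past the diff context shown above; dropping `@require_torch_accelerator` is consistent with `device_map` now accepting "cpu". A hedged sketch of the comparison such a test plausibly performs (everything past the three context lines, including `torch_device` and the save/reload round trip, is an assumption):

```python
import tempfile

import numpy as np

def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4):
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.to(torch_device)

    # Baseline: outputs from a pipeline placed with `.to()`.
    output = pipe(**self.get_dummy_inputs(torch_device))[0]

    # Reload the same weights with a device-type device map and compare.
    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir)
        pipe_mapped = self.pipeline_class.from_pretrained(tmpdir, device_map=torch_device)

    output_mapped = pipe_mapped(**self.get_dummy_inputs(torch_device))[0]
    max_diff = np.abs(np.asarray(output) - np.asarray(output_mapped)).max()
    assert max_diff < expected_max_difference
```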

tests/pipelines/visualcloze/test_pipeline_visualcloze_combined.py

4 additions, 0 deletions

```diff
@@ -342,3 +342,7 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
         self.assertLess(
             max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
         )
+
+    @unittest.skip("Test not supported.")
+    def test_pipeline_with_accelerator_device_map(self):
+        pass
```

tests/pipelines/visualcloze/test_pipeline_visualcloze_generation.py

4 additions, 0 deletions

```diff
@@ -310,3 +310,7 @@ def test_save_load_float16(self, expected_max_diff=1e-2):
     @unittest.skip("Skipped due to missing layout_prompt. Needs further investigation.")
     def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=0.0001, rtol=0.0001):
         pass
+
+    @unittest.skip("Needs to be revisited later.")
+    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=0.0001):
+        pass
```
