|
112 | 112 | for library in LOADABLE_CLASSES: |
113 | 113 | LIBRARIES.append(library) |
114 | 114 |
|
115 | | -SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()] |
| 115 | +SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"] |
116 | 116 |
|
117 | 117 | logger = logging.get_logger(__name__) |
118 | 118 |
|
@@ -468,8 +468,7 @@ def module_is_offloaded(module): |
468 | 468 | pipeline_is_sequentially_offloaded = any( |
469 | 469 | module_is_sequentially_offloaded(module) for _, module in self.components.items() |
470 | 470 | ) |
471 | | - |
472 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 471 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
473 | 472 | if is_pipeline_device_mapped: |
474 | 473 | raise ValueError( |
475 | 474 | "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." |
@@ -1188,7 +1187,7 @@ def enable_model_cpu_offload(self, gpu_id: int | None = None, device: torch.devi |
1188 | 1187 | """ |
1189 | 1188 | self._maybe_raise_error_if_group_offload_active(raise_error=True) |
1190 | 1189 |
|
1191 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 1190 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
1192 | 1191 | if is_pipeline_device_mapped: |
1193 | 1192 | raise ValueError( |
1194 | 1193 | "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`." |
@@ -1312,7 +1311,7 @@ def enable_sequential_cpu_offload(self, gpu_id: int | None = None, device: torch |
1312 | 1311 | raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") |
1313 | 1312 | self.remove_all_hooks() |
1314 | 1313 |
|
1315 | | - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 |
| 1314 | + is_pipeline_device_mapped = self._is_pipeline_device_mapped() |
1316 | 1315 | if is_pipeline_device_mapped: |
1317 | 1316 | raise ValueError( |
1318 | 1317 | "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`." |
@@ -2228,6 +2227,21 @@ def _maybe_raise_error_if_group_offload_active( |
2228 | 2227 | return True |
2229 | 2228 | return False |
2230 | 2229 |
|
| 2230 | + def _is_pipeline_device_mapped(self): |
| 2231 | + # Report whether a multi-entry device map (a dict with more than one entry) is active on the pipeline. A plain device string such as `device_map="cuda"` or `device_map="cpu"` is supported and never counts as device-mapped. |
| 2232 | + # Declaring `device_map="cpu"` explicitly when initializing a pipeline is desirable in limited-VRAM environments, because quantized models often initialize directly on the accelerator. |
| 2233 | + # NOTE(review): the `isinstance(device_map, dict)` check in the return below already excludes strings, so the `torch.device` probe does not change the result -- confirm whether it is still needed. |
| 2234 | + device_map = self.hf_device_map |
| 2235 | + is_device_type_map = False |
| 2236 | + if isinstance(device_map, str): |
| 2237 | + try: |
| 2238 | + torch.device(device_map) |
| 2239 | + is_device_type_map = True |
| 2240 | + except RuntimeError: |
| 2241 | + pass |
| 2242 | + |
| 2243 | + return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1 |
| 2244 | + |
2231 | 2245 |
|
2232 | 2246 | class StableDiffusionMixin: |
2233 | 2247 | r""" |
|
0 commit comments