Skip to content

Commit 5aa5ccc

Browse files
authored
Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (Comfy-Org#13802)
* model_management: disable non-dynamic smart memory Disable smart memory outright for non dynamic models. This is a minor step towards deprecation of --disable-dynamic-vram and the legacy ModelPatcher. This is needed for estimate-free model development, where new models can opt-out of supplying a memory estimate and not have to worry about hard VRAM allocations due to legacy non-dynamic model patchers This is also a general stability increase for a lot of stray use cases where estimates may still be off and going forward we are not going to accurately maintain such estimates. * pinned_memory: implement with aimdo growable buffer Use a single growable buffer so we can do threaded pre-warming on pinned memory. * mm: use aimdo to do transfer from disk to pin Aimdo implements a faster threaded loader. * Add stream host pin buffer for AIMDO casts Introduce per-offload-stream HostBuffer reuse for pinned staging, include it in cast buffer reset synchronization. Defer actual casts that go via this pin path to a separate pass such that the buffer can be allocated monolithically (to avoid cudaHostRegister thrash). * remove old pin path * Implement JIT pinned memory pressure Replace the predictive pin pressure mechanism with JIT PIN memory pressure. * LowVRAMPatch: change to two-phase visit * lora: re-implement as inplace swiss-army-knife operation * prepare for multiple pin sets * implement pinned loras * requirements: comfy-aimdo 0.4.0 * ops: remove unused arg This was defeatured in aimdo iteration * ops: sync the CPU with only the offload stream activity This was syncing with the offload stream which itself is synced with the compute stream, so this was syncing CPU with compute transitively. Define the event to sync it more gently. * pins: implement freeing intermediate for pinned memory Pinning is more important than inactive intermediates and the stream pin buffer is more important than even active intermediates. * execution: implement pin eviction on RAM presure Add back proper pin freeing on RAM pressure * implement pin registration swaps Uncap the windows pins from 50% by extending the pool and have a pressure mechanism to move the pin reservations om demand. This unfortunately implies a GPU sync to do the freeing so significant hysterisis needs to be added to consolidate these pressure events. * cli_args/execution: Implement lower background cache-ram threshold Limit the amount of RAM background intermediates can use, so that switching workflows doesn't degrade performance too much. * make default * bump aimdo * model-patcher: force-cast tiny weights Flux 2 gets crazy stalls due to a mix of tiny and giant weights creating lopsided steam buffer rotations which creates stalls. * ops: refactor in prep for chunking * mm: delegate pin-on-the-way to aimdo Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between previous weight transfer and the next weights vbar fault. * bump aimdo * pinning updates * specify hostbuf max allocation size There a signs of virtual memory exhaustion on some linux systems when throwing 128GB for every little piece. Pass the actual to save aimdo from over-estimates * tests: update execution tests for caching The default caching changed to ram-cache so update these tests accordingly. Remove the LRU 0 test as this also falls through to RAM cache.
1 parent 4d6a058 commit 5aa5ccc

14 files changed

Lines changed: 408 additions & 219 deletions

comfy/cli_args.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,11 @@ def from_string(cls, value: str):
110110

111111
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
112112

113-
CACHE_RAM_AUTO_GB = -1.0
114-
115113
cache_group = parser.add_mutually_exclusive_group()
114+
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).")
116115
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
117116
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
118117
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
119-
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
120118

121119
attn_group = parser.add_mutually_exclusive_group()
122120
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -245,6 +243,9 @@ def is_valid_directory(path: str) -> str:
245243
else:
246244
args = parser.parse_args([])
247245

246+
if args.cache_ram is not None and len(args.cache_ram) > 2:
247+
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
248+
248249
if args.windows_standalone_build:
249250
args.auto_launch = True
250251

comfy/lora.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
484484

485485
return weight
486486

487-
def prefetch_prepared_value(value, allocate_buffer, stream):
487+
def prefetch_prepared_value(value, counter, destination, stream, copy):
488488
if isinstance(value, torch.Tensor):
489-
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
490-
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
489+
size = comfy.memory_management.vram_aligned_size(value)
490+
offset = counter[0]
491+
counter[0] += size
492+
if destination is None:
493+
return value
494+
495+
dest = destination[offset:offset + size]
496+
if copy:
497+
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
491498
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
492499
elif isinstance(value, weight_adapter.WeightAdapterBase):
493-
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
500+
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
494501
elif isinstance(value, tuple):
495-
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
502+
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
496503
elif isinstance(value, list):
497-
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
504+
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
498505

499506
return value

comfy/memory_management.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,25 @@ class TensorFileSlice(NamedTuple):
1515
size: int
1616

1717

18-
def read_tensor_file_slice_into(tensor, destination):
18+
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
1919

2020
if isinstance(tensor, QuantizedTensor):
2121
if not isinstance(destination, QuantizedTensor):
2222
return False
2323
if tensor._layout_cls != destination._layout_cls:
2424
return False
2525

26-
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
26+
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
27+
destination2=(destination2._qdata if destination2 is not None else None)):
2728
return False
2829

2930
dst_orig_dtype = destination._params.orig_dtype
3031
destination._params.copy_from(tensor._params, non_blocking=False)
3132
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
33+
if destination2 is not None:
34+
dst_orig_dtype = destination2._params.orig_dtype
35+
destination2._params.copy_from(destination._params, non_blocking=True)
36+
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
3237
return True
3338

3439
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
@@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination):
4853
if info.size == 0:
4954
return True
5055

56+
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
57+
if hostbuf is not None:
58+
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
59+
device_ptr = destination2.data_ptr() if destination2 is not None else 0
60+
hostbuf.read_file_slice(file_obj, info.offset, info.size,
61+
offset=destination.data_ptr() - hostbuf.get_raw_address(),
62+
stream=stream_ptr,
63+
device_ptr=device_ptr,
64+
device=None if destination2 is None else destination2.device.index)
65+
return True
66+
5167
buf_type = ctypes.c_ubyte * info.size
5268
view = memoryview(buf_type.from_address(destination.data_ptr()))
5369

@@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom):
151167
extra_ram_release_callback = callback
152168
RAM_CACHE_HEADROOM = max(0, int(headroom))
153169

154-
def extra_ram_release(target):
170+
def extra_ram_release(target, free_active=False):
155171
if extra_ram_release_callback is None:
156172
return 0
157-
return extra_ram_release_callback(target)
173+
return extra_ram_release_callback(target, free_active=free_active)

0 commit comments

Comments
 (0)