From 735a0465e5daf1f77909b553b02a9d16d1671be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Wed, 18 Mar 2026 02:20:49 +0200 Subject: [PATCH 1/4] Inplace VAE output processing to reduce peak RAM consumption. (#13028) --- comfy/sd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd.py b/comfy/sd.py index 4d427bb9aca5..652e76d3ef3b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -455,7 +455,7 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None) self.output_channels = 3 self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 - self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False From 68d542cc0602132d3d2fe624ee7077e44b0fb0ab Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:46:22 -0700 Subject: [PATCH 2/4] Fix case where pixel space VAE could cause issues. (#13030) --- comfy/sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 652e76d3ef3b..df0c4d1d1fdc 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -952,8 +952,8 @@ def decode(self, samples_in, vae_options={}): batch_number = max(1, batch_number) for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype())) + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) if pixel_samples is None: pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) pixel_samples[x:x+batch_number] = out From cad24ce26278a72095d33a2b4391572573201542 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 17 Mar 2026 17:59:10 -0700 Subject: [PATCH 3/4] cascade: remove dead weight init code (#13026) This weight init process is fully shadowed be the weight load and doesnt work in dynamic_vram were the weight allocation is deferred. --- comfy/ldm/cascade/stage_a.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 145e6e69a7c5..e4e30cacdbed 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -136,16 +136,7 @@ def __init__(self, c, c_hidden): ops.Linear(c_hidden, c), ) - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) From b941913f1d2d11dc69c098a375309b13c13bca23 Mon Sep 17 00:00:00 2001 From: Anton Bukov Date: Wed, 18 Mar 2026 05:21:32 +0400 Subject: [PATCH 4/4] fix: run text encoders on MPS GPU instead of CPU for Apple Silicon (#12809) On Apple Silicon, `vram_state` is set to `VRAMState.SHARED` because CPU and GPU share unified memory. However, `text_encoder_device()` only checked for `HIGH_VRAM` and `NORMAL_VRAM`, causing all text encoders to fall back to CPU on MPS devices. Adding `VRAMState.SHARED` to the condition allows non-quantized text encoders (e.g. bf16 Gemma 3 12B) to run on the MPS GPU, providing significant speedup for text encoding and prompt generation. Note: quantized models (fp4/fp8) that use float8_e4m3fn internally will still fall back to CPU via the `supports_cast()` check in `CLIP.__init__()`, since MPS does not support fp8 dtypes. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2c250dacc178..5f2e6ef67f1e 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1003,7 +1003,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: