diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py index 6ae919a795ab..12d3a01abf92 100644 --- a/comfy/ldm/sam3/detector.py +++ b/comfy/ldm/sam3/detector.py @@ -54,7 +54,7 @@ def forward(self, q_input, k_input=None, v_input=None, mask=None): if mask is not None and mask.ndim == 2: mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast dtype = q.dtype # manual_cast may produce mixed dtypes - out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask) + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False) return self.out_proj(out) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py index 272781d45774..75cb457cff53 100644 --- a/comfy/ldm/sam3/sam.py +++ b/comfy/ldm/sam3/sam.py @@ -40,7 +40,7 @@ def forward(self, q, k, v): q = self.q_proj(q) k = self.k_proj(k) v = self.v_proj(v) - return self.out_proj(optimized_attention(q, k, v, self.num_heads)) + return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) class TwoWayAttentionBlock(nn.Module): @@ -179,7 +179,7 @@ def forward(self, x, freqs_cis=None): q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) if self.use_rope and freqs_cis is not None: q, k = apply_rope(q, k, freqs_cis) - return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True)) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False)) class Block(nn.Module): diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py index 6ff6369d1339..8f7481003cf0 100644 --- a/comfy/ldm/sam3/tracker.py +++ b/comfy/ldm/sam3/tracker.py @@ -364,7 +364,7 @@ def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0): v = self.v_proj(v) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) - out = optimized_attention(q, k, v, self.num_heads) + out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False) return self.out_proj(out) @@ -657,7 +657,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None, v = self.self_attn_v_proj(normed) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) - x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) # Decoupled cross-attention: fuse image and memory projections normed = self.norm2(x) @@ -668,7 +668,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None, v = self.cross_attn_v_proj(memory) if rope is not None: q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) - x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) # FFN x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d7c2e874470f..19d8a387f07b 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,6 +1,7 @@ import nodes import node_helpers import torch +import torchaudio import comfy.model_management import comfy.model_sampling import comfy.samplers @@ -711,7 +712,14 @@ def define_schema(cls) -> io.Schema: @classmethod def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: # Encode reference audio to latents and patchify - audio_latents = audio_vae.encode(reference_audio) + sample_rate = reference_audio["sample_rate"] + vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate) + else: + waveform = reference_audio["waveform"] + + audio_latents = audio_vae.encode(waveform.movedim(1, -1)) b, c, t, f = audio_latents.shape ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_audio = {"tokens": ref_tokens} diff --git a/execution.py b/execution.py index 5e02dffb204f..e15eb4bda008 100644 --- a/execution.py +++ b/execution.py @@ -811,11 +811,30 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= self._notify_prompt_lifecycle("end", prompt_id) -async def validate_inputs(prompt_id, prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): + if visiting is None: + visiting = [] + unique_id = item if unique_id in validated: return validated[unique_id] + if unique_id in visiting: + cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + for node_id in cycle_nodes: + validated[node_id] = (False, [{ + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + } + }], node_id) + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue try: - r = await validate_inputs(prompt_id, prompt, o_id, validated) + visiting.append(unique_id) + try: + r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + finally: + visiting.pop() if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if len(errors) > 0 or valid is not True: - ret = (False, errors, unique_id) - else: - ret = (True, [], unique_id) + ret = validated.get(unique_id, (True, [], unique_id)) + # Recursive cycle detection may have already populated an error on us. Join it. + ret = ( + ret[0] and valid is True and not errors, + ret[1] + [error for error in errors if error not in ret[1]], + unique_id, + ) validated[unique_id] = ret return ret diff --git a/requirements.txt b/requirements.txt index 419124f482ab..7a2e4e0a2d7d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.12 +comfy-aimdo==0.2.14 requests simpleeval>=1.0.0 blake3