Skip to content

Commit 2327fa1

Browse files
rattus128guill
andauthored
execution: Add anti-cycle validation (Comfy-Org#13169)
Currently if the graph contains a cycle, the just inifitiate recursions, hits a catch all then throws a generic error against the output node that seeded the validation. Instead, fail the offending cycling mode chain and handlng it as an error in its own right. Co-authored-by: guill <jacob.e.segal@gmail.com>
1 parent 084e08c commit 2327fa1

1 file changed

Lines changed: 32 additions & 6 deletions

File tree

execution.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -811,11 +811,30 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=
811811
self._notify_prompt_lifecycle("end", prompt_id)
812812

813813

814-
async def validate_inputs(prompt_id, prompt, item, validated):
814+
async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
815+
if visiting is None:
816+
visiting = []
817+
815818
unique_id = item
816819
if unique_id in validated:
817820
return validated[unique_id]
818821

822+
if unique_id in visiting:
823+
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
824+
cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
825+
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
826+
for node_id in cycle_nodes:
827+
validated[node_id] = (False, [{
828+
"type": "dependency_cycle",
829+
"message": "Dependency cycle detected",
830+
"details": cycle_path,
831+
"extra_info": {
832+
"node_id": node_id,
833+
"cycle_nodes": cycle_nodes,
834+
}
835+
}], node_id)
836+
return validated[unique_id]
837+
819838
inputs = prompt[unique_id]['inputs']
820839
class_type = prompt[unique_id]['class_type']
821840
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
899918
errors.append(error)
900919
continue
901920
try:
902-
r = await validate_inputs(prompt_id, prompt, o_id, validated)
921+
visiting.append(unique_id)
922+
try:
923+
r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting)
924+
finally:
925+
visiting.pop()
903926
if r[0] is False:
904927
# `r` will be set in `validated[o_id]` already
905928
valid = False
@@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
10481071
errors.append(error)
10491072
continue
10501073

1051-
if len(errors) > 0 or valid is not True:
1052-
ret = (False, errors, unique_id)
1053-
else:
1054-
ret = (True, [], unique_id)
1074+
ret = validated.get(unique_id, (True, [], unique_id))
1075+
# Recursive cycle detection may have already populated an error on us. Join it.
1076+
ret = (
1077+
ret[0] and valid is True and not errors,
1078+
ret[1] + [error for error in errors if error not in ret[1]],
1079+
unique_id,
1080+
)
10551081

10561082
validated[unique_id] = ret
10571083
return ret

0 commit comments

Comments
 (0)