Skip to content
189 changes: 150 additions & 39 deletions invokeai/app/services/shared/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,23 @@ def is_any(t: Any) -> bool:
return t == Any or Any in get_args(t)


def extract_collection_item_types(t: Any) -> set[Any]:
"""Extracts list item types from a collection annotation, including unions containing list branches."""
if is_any(t):
return {Any}

if get_origin(t) is list:
return {arg for arg in get_args(t) if arg != NoneType}

item_types: set[Any] = set()
for arg in get_args(t):
if is_any(arg):
item_types.add(Any)
elif get_origin(arg) is list:
item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType)
return item_types


def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
if not from_type or not to_type:
return False
Expand Down Expand Up @@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
)


@invocation("collect", version="1.0.0")
@invocation("collect", version="1.1.0")
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""

Expand All @@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation):
input=Input.Connection,
)
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default=[], ui_hidden=True
description="An optional collection to append to",
default=[],
ui_type=UIType._Collection,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
Expand Down Expand Up @@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge):

# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
if len(input_edges) > 0 and (
not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD
):
raise InvalidEdgeError(f"Edge already exists ({edge})")

# Validate that no cycles would be created
Expand All @@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge):
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")

# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD):
err = self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field
)
if err is not None:
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")

Expand Down Expand Up @@ -676,76 +700,152 @@ def _is_iterator_connection_valid(

# Collector input type must match all iterator output types
if isinstance(input_node, CollectInvocation):
collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
if len(collector_inputs) == 0:
return "Iterator input collector must have at least one item input edge"

# Traverse the graph to find the first collector input edge. Collectors validate that their collection
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
first_collector_input_edge = collector_inputs[0]
first_collector_input_type = get_output_field_type(
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
)
resolved_collector_type = (
first_collector_input_type
if get_origin(first_collector_input_type) is None
else get_args(first_collector_input_type)
)
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
input_root_type = self._get_collector_input_root_type(input_node.id)
if input_root_type is None:
return "Iterator input collector must have at least one item or collection input edge"
if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)):
return "Iterator collection type must match all iterator output types"

return None

def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]:
"""Resolves possible item types for a collector's inputs, recursively following chained collectors."""
visited = visited or set()
if node_id in visited:
return set()
visited.add(node_id)

input_types: set[Any] = set()

for edge in self._get_input_edges(node_id, ITEM_FIELD):
input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field)
resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
input_types.update(t for t in resolved_types if t != NoneType)

for edge in self._get_input_edges(node_id, COLLECTION_FIELD):
source_node = self.get_node(edge.source.node_id)
if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD:
input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy()))
continue

input_field_type = get_output_field_type(source_node, edge.source.field)
input_types.update(extract_collection_item_types(input_field_type))

return input_types

def _get_collector_input_root_type(self, node_id: str) -> Any | None:
input_types = self._resolve_collector_input_types(node_id)
non_any_input_types = {t for t in input_types if t != Any}
if len(non_any_input_types) == 0 and Any in input_types:
return Any
if len(non_any_input_types) == 0:
return None

type_tree = nx.DiGraph()
type_tree.add_nodes_from(non_any_input_types)
type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])])
type_degrees = type_tree.in_degree(type_tree.nodes)
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
if len(root_types) != 1:
return Any
return root_types[0]

def _is_collector_connection_valid(
self,
node_id: str,
new_input: Optional[EdgeConnection] = None,
new_input_field: Optional[str] = None,
new_output: Optional[EdgeConnection] = None,
) -> str | None:
inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]

if new_input is not None:
inputs.append(new_input)
field = new_input_field or ITEM_FIELD
if field == ITEM_FIELD:
item_inputs.append(new_input)
elif field == COLLECTION_FIELD:
collection_inputs.append(new_input)
if new_output is not None:
outputs.append(new_output)

# Get input and output fields (the fields linked to the iterator's input/output)
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
if len(item_inputs) == 0 and len(collection_inputs) == 0:
return "Collector must have at least one item or collection input edge"

# Get input and output fields (the fields linked to the collector's input/output)
item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs]
collection_input_field_types = [
get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs
]
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]

if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)):
return "Collector collection input must be a collection"

# Validate that all inputs are derived from or match a single type
input_field_types = {
resolved_type
for input_field_type in input_field_types
for input_field_type in item_input_field_types
for resolved_type in (
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
)
if resolved_type != NoneType
} # Get unique types

for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False):
source_node = self.get_node(input_conn.node_id)
if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD:
input_field_types.update(self._resolve_collector_input_types(source_node.id))
continue
input_field_types.update(extract_collection_item_types(input_field_type))

non_any_input_field_types = {t for t in input_field_types if t != Any}
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_tree.add_nodes_from(non_any_input_field_types)
type_tree.add_edges_from(
[e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])]
)
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
if len(root_types) > 1:
return "Collector input collection items must be of a single type"

# Get the input root type
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
# Get the input root type (if known)
input_root_type = root_types[0] if len(root_types) == 1 else None

# Verify that all outputs are lists
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
return "Collector output must connect to a collection input"

# Verify that all outputs match the input type (are a base class or the same class)
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
):
if input_root_type is not None:
if not all(
is_any(t)
or is_union_subtype(input_root_type, get_args(t)[0])
or issubclass(input_root_type, get_args(t)[0])
for t in output_field_types
):
return "Collector outputs must connect to a collection input with a matching type"
elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types):
return "Collector outputs must connect to a collection input with a matching type"

# If this collector outputs to another collector's collection input, validate against the downstream
# collector's resolved input type (if available).
for output in outputs:
output_node = self.get_node(output.node_id)
if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD:
continue
output_root_type = self._get_collector_input_root_type(output_node.id)
if output_root_type is None:
continue
if input_root_type is None:
if output_root_type != Any:
return "Collector outputs must connect to a collection input with a matching type"
continue
if not are_connection_types_compatible(input_root_type, output_root_type):
return "Collector outputs must connect to a collection input with a matching type"

return None

def nx_graph(self) -> nx.DiGraph:
Expand Down Expand Up @@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation):
if isinstance(node, CollectInvocation):
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))

output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD]
collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))

output_collection = []
for edge in collection_edges:
source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
if isinstance(source_value, list):
output_collection.extend(source_value)
else:
output_collection.append(source_value)
output_collection.extend(
copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges
)
node.collection = output_collection
else:
for edge in input_edges:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,14 @@ describe(getCollectItemType.name, () => {
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
expect(result).toBeNull();
});

it('should return the upstream collect item type for chained collects', () => {
const n1 = buildNode(collect);
const n2 = buildNode(collect);
const n3 = buildNode(add);
const e1 = buildEdge(n3.id, 'value', n1.id, 'item');
const e2 = buildEdge(n1.id, 'collection', n2.id, 'collection');
const result = getCollectItemType(templates, [n1, n2, n3], [e1, e2], n2.id);
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@ import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';

const toItemType = (fieldType: FieldType): FieldType | null => {
if (fieldType.name === 'CollectionField') {
return null;
}
if (fieldType.cardinality === 'COLLECTION' || fieldType.cardinality === 'SINGLE_OR_COLLECTION') {
return { ...fieldType, cardinality: 'SINGLE' };
}
return fieldType;
};

/**
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
Expand All @@ -18,21 +28,56 @@ export const getCollectItemType = (
edges: AnyEdge[],
nodeId: string
): FieldType | null => {
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
if (!firstEdgeToCollect?.sourceHandle) {
return null;
}
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
const getCollectItemTypeInternal = (currentNodeId: string, visited: Set<string>): FieldType | null => {
if (visited.has(currentNodeId)) {
return null;
}
visited.add(currentNodeId);

const firstItemEdgeToCollect = edges.find((edge) => edge.target === currentNodeId && edge.targetHandle === 'item');
if (firstItemEdgeToCollect?.sourceHandle) {
const node = nodes.find((n) => n.id === firstItemEdgeToCollect.source);
if (!node) {
return null;
}
const template = templates[node.data.type];
if (!template) {
return null;
}
const fieldTemplate = template.outputs[firstItemEdgeToCollect.sourceHandle];
if (!fieldTemplate) {
return null;
}
return toItemType(fieldTemplate.type);
}

const firstCollectionEdgeToCollect = edges.find(
(edge) => edge.target === currentNodeId && edge.targetHandle === 'collection'
);
if (!firstCollectionEdgeToCollect?.sourceHandle) {
return null;
}
const sourceNode = nodes.find((n) => n.id === firstCollectionEdgeToCollect.source);
if (!sourceNode) {
return null;
}
if (sourceNode.data.type === 'collect' && firstCollectionEdgeToCollect.sourceHandle === 'collection') {
return getCollectItemTypeInternal(sourceNode.id, visited);
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return null;
}
const sourceFieldTemplate = sourceTemplate.outputs[firstCollectionEdgeToCollect.sourceHandle];
if (!sourceFieldTemplate) {
return null;
}
return toItemType(sourceFieldTemplate.type);
};

const itemType = getCollectItemTypeInternal(nodeId, new Set());
if (!itemType) {
return null;
}
return fieldTemplate.type;
return itemType;
};
Loading
Loading