Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions comfy_extras/nodes_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MathExpressionNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int]),
input=io.MultiType.Input("value", [io.Float, io.Int, io.Boolean]),
names=list(string.ascii_lowercase),
min=1,
)
Expand All @@ -82,6 +82,7 @@ def define_schema(cls) -> io.Schema:
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
io.Boolean.Output(display_name="BOOL"),
],
)

Expand All @@ -97,7 +98,7 @@ def execute(

result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if isinstance(result, bool) or not isinstance(result, (int, float)):
if not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
Expand All @@ -106,7 +107,7 @@ def execute(
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result))
return io.NodeOutput(float(result), int(result), bool(result))


class MathExtension(ComfyExtension):
Expand Down
8 changes: 5 additions & 3 deletions tests-unit/comfy_extras_test/nodes_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ def test_undefined_function_raises(self):
with pytest.raises(Exception, match="not defined"):
self._exec("str(a)", a=42)

def test_boolean_result_raises(self):
with pytest.raises(ValueError, match="got bool"):
self._exec("a > b", a=5, b=3)
def test_boolean_result(self):
result = self._exec("a > b", a=5, b=3)
assert result[2] is True
result = self._exec("a > b", a=3, b=5)
assert result[2] is False

def test_empty_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):
Expand Down
Loading