diff --git a/src/agents/_data.py b/src/agents/_data.py index 8c4be40c..fb10e03e 100644 --- a/src/agents/_data.py +++ b/src/agents/_data.py @@ -158,12 +158,14 @@ def get_data(self, id: any) -> knext.Table: f"No data item found with id: {id}. Available ids: {[data.id for data in self._data]}" ) - def get_last_tables(self, num_tables: int) -> list[knext.Table]: + def get_last_tables( + self, num_tables: int, fill_missing: bool = True + ) -> list[knext.Table]: """Returns the last `num_tables` tables added to the registry.""" if num_tables <= 0: return [] tables = [data_item.data for data_item in self._data[-num_tables:]] - if len(tables) < num_tables: + if fill_missing and len(tables) < num_tables: empty_table = _empty_table() tables = tables + [empty_table] * (num_tables - len(tables)) return tables diff --git a/src/agents/base.py b/src/agents/base.py index 83a37604..d7326230 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -1272,9 +1272,7 @@ def execute( tools_table: Optional[knext.Table], input_tables: list[knext.Table], ): - import pandas as pd from ._data_service import DataRegistry - import pyarrow as pa view_data = ctx._get_view_data() num_data_outputs = ctx.get_connected_output_port_numbers()[2] @@ -1285,36 +1283,20 @@ def execute( data_registry = DataRegistry.load( view_data["data"]["data_registry"], view_data["ports_for_ids"] ) + data_outputs = data_registry.get_last_tables( + num_data_outputs, fill_missing=False + ) return ( combined_tools_workflow, conversation_table, - data_registry.get_last_tables(num_data_outputs), + data_outputs + + [knext.InactivePort] * (num_data_outputs - len(data_outputs)), ) else: - message_type = _message_type() - columns = [ - util.OutputColumn( - self.conversation_column_name, - message_type, - message_type.to_pyarrow(), - ) - ] - if self.errors.has_error_column: - columns.append( - util.OutputColumn( - self.errors.error_column_name, - knext.string(), - pa.string(), - ) - ) - conversation_table = util.create_empty_table(None, columns) return ( combined_tools_workflow, - conversation_table, - [ - knext.Table.from_pandas(pd.DataFrame()) # empty table - ] - * num_data_outputs, + knext.InactivePort, + [knext.InactivePort] * num_data_outputs, ) def get_data_service( diff --git a/tests/test_agent_chat_widget.py b/tests/test_agent_chat_widget.py new file mode 100644 index 00000000..bc225a37 --- /dev/null +++ b/tests/test_agent_chat_widget.py @@ -0,0 +1,135 @@ +import unittest +from unittest.mock import patch +import sys +import types + +import knime.extension as knext +import knime.extension.nodes as kn + + +class _MockBackend: + def register_port_type(self, name, object_class, spec_class, id=None): + if id is None: + id = f"test.{object_class.__module__}.{object_class.__qualname__}" + return kn.PortType(id, name, object_class, spec_class) + + +knime_types_module = types.ModuleType("knime.types") +knime_types_tool_module = types.ModuleType("knime.types.tool") +knime_types_tool_module.WorkflowTool = type("WorkflowTool", (), {}) +knime_types_message_module = types.ModuleType("knime.types.message") +knime_types_message_module.MessageValue = type("MessageValue", (), {}) + + +def setUpModule(): + """Set up patched KNIME backend and types for this test module.""" + global _backend_patcher + global _logical_patcher + global _sysmodules_patcher + global _inactiveport_present + global _inactiveport_original + global AgentChatWidget + + _backend_patcher = patch.object(kn, "_backend", _MockBackend()) + _backend_patcher.start() + + _logical_patcher = patch.object( + knext, "logical", lambda value_type: knext.string() + ) + _logical_patcher.start() + + _inactiveport_present = hasattr(knext, "InactivePort") + _inactiveport_original = getattr(knext, "InactivePort", None) + knext.InactivePort = object() + + _sysmodules_patcher = patch.dict( + sys.modules, + { + "knime.types": knime_types_module, + "knime.types.tool": knime_types_tool_module, + "knime.types.message": knime_types_message_module, + }, + clear=False, + ) + _sysmodules_patcher.start() + + from agents.base import AgentChatWidget as _AgentChatWidget + + AgentChatWidget = _AgentChatWidget + + +def tearDownModule(): + """Tear down patched KNIME backend and types for this test module.""" + _backend_patcher.stop() + _logical_patcher.stop() + + if _inactiveport_present: + knext.InactivePort = _inactiveport_original + else: + delattr(knext, "InactivePort") + + _sysmodules_patcher.stop() +class _MockContext: + def __init__(self, view_data, num_data_outputs, combined_tools_workflow): + self._view_data_value = view_data + self._num_data_outputs = num_data_outputs + self._combined_tools_workflow_value = combined_tools_workflow + + def _get_view_data(self): + return self._view_data_value + + def get_connected_output_port_numbers(self): + return [1, 1, self._num_data_outputs] + + def _get_combined_tools_workflow(self): + return self._combined_tools_workflow_value + + +class AgentChatWidgetOutputTest(unittest.TestCase): + def test_execute_returns_inactive_ports_when_no_view_data_exists(self): + node = AgentChatWidget() + ctx = _MockContext( + view_data=None, + num_data_outputs=2, + combined_tools_workflow="workflow-port", + ) + + combined_tools_workflow, conversation_output, data_outputs = node.execute( + ctx, None, None, [] + ) + + self.assertEqual("workflow-port", combined_tools_workflow) + self.assertIs(knext.InactivePort, conversation_output) + self.assertEqual([knext.InactivePort, knext.InactivePort], data_outputs) + + def test_execute_marks_missing_data_outputs_inactive(self): + node = AgentChatWidget() + conversation_table = object() + data_table = object() + ctx = _MockContext( + view_data={ + "ports": [conversation_table], + "data": {"data_registry": {"ids": [], "ports": []}}, + "ports_for_ids": [], + }, + num_data_outputs=2, + combined_tools_workflow="workflow-port", + ) + + class _MockDataRegistry: + def get_last_tables(self, num_tables, fill_missing=True): + self.args = (num_tables, fill_missing) + return [data_table] + + registry = _MockDataRegistry() + with patch("agents._data_service.DataRegistry.load", return_value=registry): + combined_tools_workflow, conversation_output, data_outputs = node.execute( + ctx, None, None, [] + ) + + self.assertEqual("workflow-port", combined_tools_workflow) + self.assertIs(conversation_table, conversation_output) + self.assertEqual((2, False), registry.args) + self.assertEqual(2, len(data_outputs)) + self.assertIs(data_table, data_outputs[0]) + self.assertIs(knext.InactivePort, data_outputs[1])