diff --git a/docs/examples/as_generic_chat_history.py b/docs/examples/as_generic_chat_history.py new file mode 100644 index 000000000..12d731ca8 --- /dev/null +++ b/docs/examples/as_generic_chat_history.py @@ -0,0 +1,85 @@ +# pytest: unit +"""Convert a heterogeneous context to a generic chat history. + +The as_generic_chat_history() function converts any Context into a list of +Messages, gracefully handling unknown component types by converting them to +strings. This is useful for working with mixed-type contexts or when you need +a more flexible interface than as_chat_history(). +""" + +from mellea.core import CBlock, ModelOutputThunk +from mellea.stdlib.components import Message, as_generic_chat_history +from mellea.stdlib.context import ChatContext + + +def basic_example() -> list[Message]: + """Convert a standard Message-based context to chat history.""" + ctx = ChatContext() + ctx = ctx.add(Message("user", "What is 2+2?")) + ctx = ctx.add(Message("assistant", "2+2 equals 4.")) + + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[0].content == "What is 2+2?" + assert history[1].content == "2+2 equals 4." + return history + + +def with_heterogeneous_components() -> list[Message]: + """Handle mixed component types gracefully. + + Unlike as_chat_history(), as_generic_chat_history() can handle any + component type by converting unknown types to strings. + """ + ctx = ChatContext() + ctx = ctx.add(Message("user", "Summarize this")) + ctx = ctx.add(CBlock("Some inline content to process")) + mot = ModelOutputThunk(value="The summary is...") + ctx = ctx.add(mot) + + history = as_generic_chat_history(ctx) + assert len(history) == 3 + assert history[0].role == "user" + assert history[1].role == "user" # CBlock defaults to 'user' + assert history[2].role == "assistant" # MOT defaults to 'assistant' + return history + + +def with_custom_formatter() -> list[Message]: + """Use a custom formatter for ModelOutputThunk with unparsed content. + + You can provide a formatter function to customize how unparsed outputs + or other unknown types are converted to strings. + """ + + def my_formatter(obj: object) -> str: + return f"[Formatted: {type(obj).__name__}]" + + ctx = ChatContext() + ctx = ctx.add(Message("user", "Process this")) + # Add a ModelOutputThunk with a non-Message parsed_repr + mot = ModelOutputThunk(value="raw data") + mot.parsed_repr = {"type": "dict", "data": "structured"} + ctx = ctx.add(mot) + + history = as_generic_chat_history(ctx, formatter=my_formatter) + assert len(history) == 2 + assert "[Formatted:" in history[1].content + return history + + +if __name__ == "__main__": + basic = basic_example() + print("Basic example:") + for msg in basic: + print(f" {msg.role}: {msg.content}") + + heterogeneous = with_heterogeneous_components() + print("\nHeterogeneous example:") + for msg in heterogeneous: + print(f" {msg.role}: {msg.content}") + + custom = with_custom_formatter() + print("\nCustom formatter example:") + for msg in custom: + print(f" {msg.role}: {msg.content}") diff --git a/mellea/stdlib/components/__init__.py b/mellea/stdlib/components/__init__.py index ea8c0f371..39aa559e3 100644 --- a/mellea/stdlib/components/__init__.py +++ b/mellea/stdlib/components/__init__.py @@ -10,7 +10,7 @@ TemplateRepresentation, blockify, ) -from .chat import Message, ToolMessage, as_chat_history +from .chat import Message, ToolMessage, as_chat_history, as_generic_chat_history from .docs.document import Document from .instruction import Instruction from .intrinsic import Intrinsic @@ -36,6 +36,7 @@ "ToolMessage", "Transform", "as_chat_history", + "as_generic_chat_history", "blockify", "mify", ] diff --git a/mellea/stdlib/components/chat.py b/mellea/stdlib/components/chat.py index 41d397f95..5c1429759 100644 --- a/mellea/stdlib/components/chat.py +++ b/mellea/stdlib/components/chat.py @@ -3,12 +3,14 @@ Defines ``Message``, the ``Component`` subtype used to represent a single turn in a chat history with a ``role`` (``user``, ``assistant``, ``system``, or ``tool``), text ``content``, and optional ``images`` and ``documents`` attachments. Also provides -``ToolMessage`` (a ``Message`` subclass that carries the tool name and arguments) and -the ``as_chat_history`` utility for converting a ``Context`` into a flat list of -``Message`` objects. +``ToolMessage`` (a ``Message`` subclass that carries the tool name and arguments), and +utilities for converting a ``Context`` into a flat list of ``Message`` objects: +``as_chat_history`` (strict typing) and ``as_generic_chat_history`` (flexible with +configurable formatter). """ -from collections.abc import Iterable, Mapping +import logging +from collections.abc import Callable, Iterable, Mapping from typing import Any, Literal from ...core import ( @@ -22,6 +24,8 @@ ) from .docs.document import Document, _coerce_to_documents +_logger = logging.getLogger(__name__) + class Message(Component["Message"]): """A single Message in a Chat history. @@ -276,3 +280,67 @@ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None: history = [_to_msg(c) for c in all_ctx_events] assert None not in history, "Could not render this context as a chat history." return history # type: ignore + + +def _default_formatter(obj: Any) -> str: + """Default formatter for unknown component types. + + Logs a warning and converts the object to a string representation. + """ + _logger.warning( + f"Unknown component type {type(obj).__name__} in as_generic_chat_history; " + f"converting to string representation." + ) + return str(obj) + + +def as_generic_chat_history( + ctx: Context, formatter: Callable[[Any], str] | None = None +) -> list[Message]: + """Returns a list of Messages corresponding to a Context, with flexible type handling. + + This function is more permissive than ``as_chat_history()``, allowing arbitrary + component types. Unknown types are converted to strings using a configurable + formatter, making it suitable for general-purpose use where context composition + may be heterogeneous. + + Args: + ctx: A linear ``Context`` that may contain ``Message``, ``ModelOutputThunk``, + or other ``Component`` types. + formatter: Optional callable that converts unknown types to strings. + Defaults to ``_default_formatter`` which logs a warning and stringifies. + + Returns: + List of ``Message`` objects in conversation order. + + Raises: + Exception: If the context history is non-linear and cannot be cast to a + flat list. + """ + if formatter is None: + formatter = _default_formatter + + def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message: + match c: + case Message(): + return c + case ModelOutputThunk(): + if isinstance(c.parsed_repr, Message): + return c.parsed_repr + if isinstance(c.parsed_repr, str): + return Message(role="assistant", content=c.parsed_repr) + # Use value if parsed_repr is None or some other type + content = ( + str(c.value) if c.parsed_repr is None else formatter(c.parsed_repr) + ) + return Message(role="assistant", content=content) + case CBlock(): + return Message(role="user", content=str(c)) + case _: + content = formatter(c) + return Message(role="user", content=content) + + all_ctx_events = ctx.as_list() + if all_ctx_events is None: + raise Exception("Trying to cast a non-linear history into a chat history.") + return [_to_msg(c) for c in all_ctx_events] diff --git a/test/stdlib/components/test_chat.py b/test/stdlib/components/test_chat.py index dae7a787c..d4d5e63e2 100644 --- a/test/stdlib/components/test_chat.py +++ b/test/stdlib/components/test_chat.py @@ -1,10 +1,16 @@ +import logging + import pytest from mellea.core import CBlock, ModelOutputThunk, TemplateRepresentation from mellea.formatters.template_formatter import TemplateFormatter from mellea.helpers import message_to_openai_message, messages_to_docs from mellea.stdlib.components import Document, Message -from mellea.stdlib.components.chat import ToolMessage, as_chat_history +from mellea.stdlib.components.chat import ( + ToolMessage, + as_chat_history, + as_generic_chat_history, +) from mellea.stdlib.context import ChatContext @@ -270,6 +276,130 @@ def test_as_chat_history_with_parsed_mot(): assert history[1].content == "reply" +# --- as_generic_chat_history --- + + +def test_as_generic_chat_history_messages_only(): + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + ctx = ctx.add(Message("assistant", "hi")) + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[0].role == "user" + assert history[0].content == "hello" + assert history[1].role == "assistant" + assert history[1].content == "hi" + + +def test_as_generic_chat_history_empty(): + ctx = ChatContext() + history = as_generic_chat_history(ctx) + assert history == [] + + +def test_as_generic_chat_history_with_parsed_mot(): + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + mot = ModelOutputThunk(value="reply") + mot.parsed_repr = Message("assistant", "reply") + ctx = ctx.add(mot) + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[1].role == "assistant" + assert history[1].content == "reply" + + +def test_as_generic_chat_history_with_unparsed_mot(): + """Unresolved ModelOutputThunk gets converted to string.""" + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + mot = ModelOutputThunk(value="raw output") + ctx = ctx.add(mot) + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[1].role == "assistant" + assert "raw output" in history[1].content + + +def test_as_generic_chat_history_with_string_parsed_repr(): + """ModelOutputThunk with string parsed_repr (e.g., from CBlock action).""" + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + # Simulate a ModelOutputThunk with a string parsed_repr, + # as would result from a CBlock action completing + mot = ModelOutputThunk(value="reply text", parsed_repr="reply text") + ctx = ctx.add(mot) + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[1].role == "assistant" + assert history[1].content == "reply text" + + +def test_as_generic_chat_history_with_non_message_parsed_repr(): + """ModelOutputThunk with non-Message, non-string parsed_repr uses formatter.""" + + def custom_formatter(obj: object) -> str: + if isinstance(obj, dict): + return f"dict:{obj}" + return str(obj) + + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + # parsed_repr is a dict (could be structured data from a model) + mot = ModelOutputThunk(value="raw", parsed_repr={"key": "value"}) + ctx = ctx.add(mot) + history = as_generic_chat_history(ctx, formatter=custom_formatter) + assert len(history) == 2 + assert history[1].role == "assistant" + assert "dict:" in history[1].content + + +def test_as_generic_chat_history_with_cblock(): + """CBlocks are converted to Messages with 'user' role.""" + ctx = ChatContext() + ctx = ctx.add(CBlock("inline content")) + ctx = ctx.add(Message("assistant", "response")) + history = as_generic_chat_history(ctx) + assert len(history) == 2 + assert history[0].role == "user" + assert history[0].content == "inline content" + + +def test_as_generic_chat_history_custom_formatter(): + """Custom formatter handles unknown types.""" + + def custom_formatter(obj: object) -> str: + return f"" + + class CustomComponent: + def __str__(self): + return "original" + + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + ctx = ctx.add(CustomComponent()) + history = as_generic_chat_history(ctx, formatter=custom_formatter) + assert len(history) == 2 + assert "" in history[1].content + + +def test_as_generic_chat_history_default_formatter_logs_warning(caplog): + """Default formatter logs a warning for unknown types.""" + + class UnknownComponent: + pass + + ctx = ChatContext() + ctx = ctx.add(Message("user", "hello")) + ctx = ctx.add(UnknownComponent()) + + with caplog.at_level(logging.WARNING): + history = as_generic_chat_history(ctx) + + assert len(history) == 2 + assert any("Unknown component type" in record.message for record in caplog.records) + + # --- Formatter rendering of Message documents ---