Enhance chat completion functionality to support OpenAI-style message history #1674
base: develop
`packages/nvidia_nat_core/src/nat/tool/chat_completion.py` after the change (unchanged hunks elided):

```python
"""
This module provides a simple completion function that can handle
natural language queries and perform basic text completion tasks.
Supports OpenAI-style message history when used with the chat completions API.
"""

from pydantic import Field

from nat.builder.builder import Builder
from nat.builder.framework_enum import LLMFrameworkEnum
from nat.builder.function_info import FunctionInfo
from nat.cli.register_workflow import register_function
from nat.data_models.api_server import ChatRequest
from nat.data_models.api_server import ChatRequestOrMessage
from nat.data_models.api_server import ChatResponse
from nat.data_models.api_server import Usage
from nat.data_models.component_ref import LLMRef
from nat.data_models.function import FunctionBaseConfig
from nat.utils.type_converter import GlobalTypeConverter


class ChatCompletionConfig(FunctionBaseConfig, name="chat_completion"):
    # ... (other config fields unchanged)
    llm_name: LLMRef = Field(description="The LLM to use for generating responses.")


def _messages_to_langchain_messages(
    nat_messages: list,
    system_prompt: str,
):
    """Convert a NAT Message list to a LangChain BaseMessage list, prepending the system prompt if needed."""
    from langchain_core.messages.utils import convert_to_messages

    message_dicts = [m.model_dump() for m in nat_messages]
    has_system = any(d.get("role") == "system" for d in message_dicts)
    if not has_system and system_prompt:
        message_dicts = [{"role": "system", "content": system_prompt}] + message_dicts
    return convert_to_messages(message_dicts)


@register_function(config_type=ChatCompletionConfig)
async def register_chat_completion(config: ChatCompletionConfig, builder: Builder):
    """Registers a chat completion function that can handle natural language queries and full message history."""

    # Get the LLM from the builder context using the configured LLM reference.
    # Use the LangChain framework wrapper since we're invoking a LangChain-based LLM.
    llm = await builder.get_llm(config.llm_name, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    async def _chat_completion(chat_request_or_message: ChatRequestOrMessage) -> ChatResponse | str:
        """Chat completion that supports OpenAI-style message history.

        Accepts either a single input_message (string) or a full conversation
        (messages array). When messages are provided, the full history is sent
        to the LLM for context-aware responses.

        Args:
            chat_request_or_message: Either a string input or an OpenAI-style messages array.

        Returns:
            ChatResponse when the input is a conversation; str when it is a single message.
        """
        try:
            message = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)

            # Build the LangChain message list from the full conversation (OpenAI message history)
            lc_messages = _messages_to_langchain_messages(
                message.messages,
                config.system_prompt,
            )

            # Generate a response using the LLM with the full message history
            response = await llm.ainvoke(lc_messages)

            if isinstance(response, str):
                output_text = response
            else:
                output_text = response.text() if hasattr(response, "text") else str(response.content)

            # Approximate usage for API compatibility
            prompt_tokens = sum(len(str(m.content).split()) for m in message.messages)
            completion_tokens = len(output_text.split()) if output_text else 0
            total_tokens = prompt_tokens + completion_tokens
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )
            chat_response = ChatResponse.from_string(output_text, usage=usage)

            if chat_request_or_message.is_string:
                return GlobalTypeConverter.get().convert(chat_response, to_type=str)
            return chat_response

        except Exception as e:
            # Fallback response if the LLM call fails
            last_content = ""
            try:
                msg = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
                if msg.messages:
                    last = msg.messages[-1].content
                    last_content = last if isinstance(last, str) else str(last)
            except Exception:
                pass
            return (
                f"I apologize, but I encountered an error while processing your "
                f"query: '{last_content}'. Please try rephrasing your question or try "
                f"again later. Error: {str(e)}"
            )
```
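The system-prompt injection rule used by `_messages_to_langchain_messages` can be sketched without the LangChain dependency. The helper name below is hypothetical, and plain dicts stand in for NAT Message models:

```python
def prepend_system_prompt(message_dicts: list[dict], system_prompt: str) -> list[dict]:
    """Prepend a system message only if the caller did not supply one."""
    has_system = any(d.get("role") == "system" for d in message_dicts)
    if not has_system and system_prompt:
        return [{"role": "system", "content": system_prompt}] + message_dicts
    return message_dicts


history = [
    {"role": "user", "content": "What does this PR change?"},
    {"role": "assistant", "content": "It adds message history support."},
    {"role": "user", "content": "Which endpoint uses it?"},
]
merged = prepend_system_prompt(history, "You are a helpful assistant.")
# The configured prompt is injected once; a caller-supplied system message wins.
```

A request that already carries its own `system` message passes through unchanged, matching the `has_system` guard in the diff.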
|
Comment on lines
118
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add logging for caught exceptions and avoid exposing raw error details to users. The exception handling has several issues:
Proposed fix+import logging
+
+logger = logging.getLogger(__name__)
+
except Exception as e:
+ logger.exception("Error processing chat completion request")
last_content = ""
try:
msg = GlobalTypeConverter.get().convert(chat_request_or_message, to_type=ChatRequest)
if msg.messages:
last = msg.messages[-1].content
last_content = last if isinstance(last, str) else str(last)
- except Exception:
- pass
+ except Exception:
+ logger.debug("Could not extract last message content for error reporting")
return (
f"I apologize, but I encountered an error while processing your "
- f"query: '{last_content}'. Please try rephrasing your question or try "
- f"again later. Error: {str(e)}"
+ f"query: '{last_content}'. Please try rephrasing your question or try again later."
)Removing the raw exception from the user-facing message prevents potential information leakage while the logged exception preserves full debugging context. 🧰 Tools🪛 Ruff (0.15.2)[warning] 118-118: Do not catch blind exception: (BLE001) [error] 125-126: (S110) [warning] 125-125: Do not catch blind exception: (BLE001) [warning] 130-130: Use explicit conversion flag Replace with conversion flag (RUF010) 🤖 Prompt for AI Agents |
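The reviewer's pattern, full traceback to the log and sanitized text to the user, can be exercised in isolation. `handle_request` below is a hypothetical stand-in for the real LLM call:

```python
import logging

logger = logging.getLogger("chat_completion_demo")


def handle_request(raw) -> str:
    try:
        return raw.upper()  # stand-in for the real LLM invocation
    except Exception:
        # Full traceback goes to the log; the user never sees raw error details.
        logger.exception("Error processing chat completion request")
        return ("I apologize, but I encountered an error while processing your "
                "query. Please try rephrasing your question or try again later.")


print(handle_request("hello"))  # normal path
print(handle_request(None))     # failure path returns the sanitized reply
```

Passing `None` triggers an `AttributeError` inside the handler, which is logged with its traceback while the caller receives only the generic apology string.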
The change also registers the function through `FunctionInfo` instead of yielding the bare coroutine:

```python
    yield FunctionInfo.from_fn(
        _chat_completion,
        description=getattr(config, "description", "Chat completion"),
    )
```
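The whitespace-split usage approximation from the diff can be mirrored as a standalone function (hypothetical name, plain dicts in place of NAT messages) to show what the `Usage` fields will contain:

```python
def approximate_usage(messages: list[dict], output_text: str) -> dict:
    """Whitespace-token approximation matching the diff's Usage computation."""
    prompt_tokens = sum(len(str(m.get("content", "")).split()) for m in messages)
    completion_tokens = len(output_text.split()) if output_text else 0
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


usage = approximate_usage(
    [{"role": "system", "content": "You are helpful."},
     {"role": "user", "content": "hello there"}],
    "hi there, how can I help?",
)
# "You are helpful." -> 3 tokens, "hello there" -> 2, completion -> 6
```

This is only an approximation; real token counts from the provider will differ, but it keeps the OpenAI-compatible `usage` object populated.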
A second review comment, on the response handling:

**Simplify message content extraction to use LangChain's `content` property directly.**

Per the `langchain_core.messages.base.BaseMessage` API docs, `content` is a required field (`str | list[str | dict]`) holding the message payload; `BaseMessage` is the abstract base for `HumanMessage`, `AIMessage`, etc. Depending on the langchain-core version, a `text()` method may not be available on the returned message, so the `hasattr(response, "text")` branch is fragile for LangChain messages. Line 99 can be simplified to:

```python
output_text = response.content if isinstance(response.content, str) else str(response.content)
```

(or simply `output_text = response.content` if the content is guaranteed to be a string).
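Since `content` may be a plain string or a list of content blocks, a defensive extraction helper (hypothetical, not part of the PR) is one way to follow the reviewer's advice while still handling multimodal content lists:

```python
def extract_text(content) -> str:
    """Flatten a LangChain-style content field (str or list of blocks) to text."""
    if isinstance(content, str):
        return content
    parts = []
    for block in content:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text", ""))
    return "".join(parts)


text = extract_text([{"type": "text", "text": "Hello "}, "world"])
```

Non-text blocks (for example image blocks) are skipped rather than stringified, which avoids leaking raw dict reprs into the chat response.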