diff --git a/willa/chatbot/graph_manager.py b/willa/chatbot/graph_manager.py
index 13dea78..b4b5809 100644
--- a/willa/chatbot/graph_manager.py
+++ b/willa/chatbot/graph_manager.py
@@ -1,7 +1,8 @@
 """Manages the shared state and workflow for Willa chatbots."""
-from typing import Any, Optional, Annotated, NotRequired
+from typing import Optional, Annotated, NotRequired
 from typing_extensions import TypedDict
+from langchain_core.documents import Document
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import ChatMessage, HumanMessage, AIMessage
 from langchain_core.vectorstores.base import VectorStore
@@ -19,10 +20,10 @@ class WillaChatbotState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     filtered_messages: NotRequired[list[AnyMessage]]
     summarized_messages: NotRequired[list[AnyMessage]]
-    docs_context: NotRequired[str]
+    messages_for_generation: NotRequired[list[AnyMessage]]
     search_query: NotRequired[str]
     tind_metadata: NotRequired[str]
-    context: NotRequired[dict[str, Any]]
+    documents: NotRequired[list[dict[str, str]]]
 
 
 class GraphManager:  # pylint: disable=too-few-public-methods
@@ -51,13 +52,15 @@ def _create_workflow(self) -> CompiledStateGraph:
         workflow.add_node("summarize", summarization_node)
         workflow.add_node("prepare_search", self._prepare_search_query)
         workflow.add_node("retrieve_context", self._retrieve_context)
+        workflow.add_node("prepare_for_generation", self._prepare_for_generation)
         workflow.add_node("generate_response", self._generate_response)
 
         # Define edges
         workflow.add_edge("filter_messages", "summarize")
         workflow.add_edge("summarize", "prepare_search")
         workflow.add_edge("prepare_search", "retrieve_context")
-        workflow.add_edge("retrieve_context", "generate_response")
+        workflow.add_edge("retrieve_context", "prepare_for_generation")
+        workflow.add_edge("prepare_for_generation", "generate_response")
 
         workflow.set_entry_point("filter_messages")
         workflow.set_finish_point("generate_response")
@@ -68,7 +71,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
         """Filter out TIND messages from the conversation history."""
         messages = state["messages"]
-        filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
+        filtered: list[AnyMessage] = [
+            msg for msg in messages
+            if "tind" not in getattr(msg, "response_metadata", {}) and msg.type != "system"
+        ]
         return {"filtered_messages": filtered}
 
     def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
@@ -79,60 +85,83 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
         # summarization may include a system message as well as any human or ai messages
         search_query = '\n'.join(str(msg.content) for msg in messages if hasattr(msg, 'content'))
+
+        # If summarization failed or otherwise produced an oversized query,
+        # keep only the last 2048 characters
+        if len(search_query) > 2048:
+            search_query = search_query[-2048:]
+
         return {"search_query": search_query}
 
-    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
+    def _format_retrieved_documents(self, matching_docs: list[Document]) -> list[dict[str, str]]:
+        """Format documents from the vector store into a list of dictionaries."""
+        formatted_documents: list[dict[str, str]] = []
+        for i, doc in enumerate(matching_docs, 1):
+            tind_metadata = doc.metadata.get('tind_metadata', {})
+            tind_id = tind_metadata.get('tind_id', [''])[0]
+            formatted_documents.append({
+                "id": f"{i}_{tind_id}",
+                "page_content": doc.page_content,
+                "title": tind_metadata.get('title', [''])[0],
+                "project": tind_metadata.get('isPartOf', [''])[0],
+                "tind_link": format_tind_context.get_tind_url(tind_id)
+            })
+        return formatted_documents
+
+    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str | list[dict[str, str]]]:
         """Retrieve relevant context from vector store."""
         search_query = state.get("search_query", "")
         vector_store = self._vector_store
 
         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": ""}
+            return {"tind_metadata": "", "documents": []}
 
         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)
+        formatted_documents = self._format_retrieved_documents(matching_docs)
 
-        # Format context and metadata
-        docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
+        # Format TIND metadata
         tind_metadata = format_tind_context.get_tind_context(matching_docs)
 
-        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
+        return {"tind_metadata": tind_metadata, "documents": formatted_documents}
 
-    # This should be refactored probably. Very bulky
-    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
-        """Generate response using the model."""
+    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Prepare the current and past messages for response generation."""
         messages = state["messages"]
         summarized_conversation = state.get("summarized_messages", messages)
-        docs_context = state.get("docs_context", "")
-        tind_metadata = state.get("tind_metadata", "")
-        model = self._model
-
-        if not model:
-            return {"messages": [AIMessage(content="Model not available.")]}
-
-        # Get the latest human message
-        latest_message = next(
-            (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
-            None
-        )
-        if not latest_message:
+        if not any(isinstance(msg, HumanMessage) for msg in messages):
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
 
         prompt = get_langfuse_prompt()
-        system_messages = prompt.invoke({'context': docs_context,
-                                         'question': latest_message.content})
+        system_messages = prompt.invoke({})
+
         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]
 
+        return {"messages_for_generation": all_messages}
+
+    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Generate response using the model."""
+        tind_metadata = state.get("tind_metadata", "")
+        model = self._model
+        documents = state.get("documents", [])
+        messages = state.get("messages_for_generation") or state.get("messages", [])
+
+        if not model:
+            return {"messages": [AIMessage(content="Model not available.")]}
+
         # Get response from model
-        response = model.invoke(all_messages)
+        response = model.invoke(
+            messages,
+            additional_model_request_fields={"documents": documents}
+        )
 
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
+        response_messages: list[AnyMessage] = [
+            AIMessage(content=response_content),
+            ChatMessage(content=tind_metadata, role='TIND',
+                        response_metadata={'tind': True})
+        ]
+        return {"messages": response_messages}
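
For reviewers, a minimal sketch of the payload shape the new _format_retrieved_documents helper builds. The sample Document, its tind_metadata values, and the stubbed URL are hypothetical, not taken from this PR; only the key layout comes from the diff.

# Standalone mirror of _format_retrieved_documents, with the
# format_tind_context.get_tind_url helper replaced by a stub URL.
from langchain_core.documents import Document

def format_retrieved_documents(matching_docs: list[Document]) -> list[dict[str, str]]:
    formatted: list[dict[str, str]] = []
    for i, doc in enumerate(matching_docs, 1):
        tind_metadata = doc.metadata.get('tind_metadata', {})
        tind_id = tind_metadata.get('tind_id', [''])[0]
        formatted.append({
            "id": f"{i}_{tind_id}",
            "page_content": doc.page_content,
            "title": tind_metadata.get('title', [''])[0],
            "project": tind_metadata.get('isPartOf', [''])[0],
            # the real code calls format_tind_context.get_tind_url(tind_id)
            "tind_link": f"https://example.tind.io/record/{tind_id}",
        })
    return formatted

sample = Document(
    page_content="Interview excerpt ...",
    metadata={"tind_metadata": {
        "tind_id": ["12345"],
        "title": ["Oral history interview with J. Doe"],
        "isPartOf": ["Example Oral History Project"],
    }},
)
print(format_retrieved_documents([sample]))
# [{'id': '1_12345', 'page_content': 'Interview excerpt ...',
#   'title': 'Oral history interview with J. Doe',
#   'project': 'Example Oral History Project',
#   'tind_link': 'https://example.tind.io/record/12345'}]

This list is forwarded as-is via additional_model_request_fields={"documents": documents} in _generate_response, so these keys are exactly what the model receives for grounding and citation.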
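The generate step now emits two messages: the assistant reply and a ChatMessage flagged with response_metadata={'tind': True}, which _filter_messages strips from history on the next turn. A sketch of a consumer under assumptions not in this diff (the manager.graph attribute name and the sample question are hypothetical):

from langchain_core.messages import HumanMessage

# invoke the compiled graph with a single user turn
result = manager.graph.invoke(
    {"messages": [HumanMessage(content="Who founded the archive?")]}
)

# the model's answer is the last AI message in the returned state
answer = next(m for m in reversed(result["messages"]) if m.type == "ai")

# TIND citation messages carry the 'tind' flag; _filter_messages
# excludes them from the history used on subsequent turns
tind_refs = [m for m in result["messages"] if m.response_metadata.get("tind")]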