81 changes: 52 additions & 29 deletions willa/chatbot/graph_manager.py
@@ -19,10 +19,10 @@ class WillaChatbotState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    filtered_messages: NotRequired[list[AnyMessage]]
    summarized_messages: NotRequired[list[AnyMessage]]
    docs_context: NotRequired[str]
    messages_for_generation: NotRequired[list[AnyMessage]]
    search_query: NotRequired[str]
    tind_metadata: NotRequired[str]
    context: NotRequired[dict[str, Any]]
    documents: NotRequired[list[Any]]
Member:
are we sure we want list[Any] here? is the intention that documents will contain either a list of dicts or an empty list?

Contributor Author:
Good point. That is indeed the expectation. I'll change it.


class GraphManager: # pylint: disable=too-few-public-methods
@@ -51,13 +51,15 @@ def _create_workflow(self) -> CompiledStateGraph:
workflow.add_node("summarize", summarization_node)
workflow.add_node("prepare_search", self._prepare_search_query)
workflow.add_node("retrieve_context", self._retrieve_context)
workflow.add_node("prepare_for_generation", self._prepare_for_generation)
workflow.add_node("generate_response", self._generate_response)

# Define edges
workflow.add_edge("filter_messages", "summarize")
workflow.add_edge("summarize", "prepare_search")
workflow.add_edge("prepare_search", "retrieve_context")
workflow.add_edge("retrieve_context", "generate_response")
workflow.add_edge("retrieve_context", "prepare_for_generation")
workflow.add_edge("prepare_for_generation", "generate_response")

workflow.set_entry_point("filter_messages")
workflow.set_finish_point("generate_response")
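For reference, the wiring above is a straight pipeline. A minimal, self-contained sketch of the same linear pattern (toy state and node names, assuming langgraph is installed; none of these names are the project's):

from typing import TypedDict
from langgraph.graph import StateGraph

class ToyState(TypedDict):
    text: str

def upper(state: ToyState) -> dict[str, str]:
    # Each node returns a partial state update, as in the PR's nodes.
    return {"text": state["text"].upper()}

def exclaim(state: ToyState) -> dict[str, str]:
    return {"text": state["text"] + "!"}

workflow = StateGraph(ToyState)
workflow.add_node("upper", upper)
workflow.add_node("exclaim", exclaim)
workflow.add_edge("upper", "exclaim")
workflow.set_entry_point("upper")
workflow.set_finish_point("exclaim")
graph = workflow.compile()
print(graph.invoke({"text": "hi"}))  # {'text': 'HI!'}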
@@ -68,7 +70,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
"""Filter out TIND messages from the conversation history."""
messages = state["messages"]

filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
filtered = [
msg for msg in messages
if 'tind' not in msg.response_metadata and msg.type != "system"
]
return {"filtered_messages": filtered}

    def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
@@ -79,60 +84,78 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:

        # summarization may include a system message as well as any human or ai messages
        search_query = '\n'.join(str(msg.content) for msg in messages if hasattr(msg, 'content'))

        # if summarization fails or some other issue, truncate to the last 2048 characters
Member:
what are the conditions under which summarization can fail?

Contributor Author:
When I added this line, summarization was failing. It can fail when we don't properly filter what we send to summarization, and possibly because of the length of our prompts. (A sketch of the truncation fallback follows this function.)
        if len(search_query) > 2048:
            search_query = search_query[-2048:]

        return {"search_query": search_query}

    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str | list[Any]]:
"""Retrieve relevant context from vector store."""
search_query = state.get("search_query", "")
vector_store = self._vector_store

if not search_query or not vector_store:
return {"docs_context": "", "tind_metadata": ""}
return {"tind_metadata": "", "documents": []}

# Search for relevant documents
retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
matching_docs = retriever.invoke(search_query)

# Format context and metadata
docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
formatted_documents = [
{
"id": f"{i}_{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}",
"page_content": doc.page_content,
"title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
"project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
"tind_link": format_tind_context.get_tind_url(
doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
}
for i, doc in enumerate(matching_docs, 1)
Comment on lines +105 to +114

Member:
This could be simplified by having something like

tind_metadata = doc.metadata.get('tind_metadata', {})

before the formatted_documents declaration. Could DRY it up. (A sketch of one way to do that follows this function.)
        ]

        # Format tind metadata
        tind_metadata = format_tind_context.get_tind_context(matching_docs)

        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
        return {"tind_metadata": tind_metadata, "documents": formatted_documents}

    # This should be refactored probably. Very bulky
    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
        """Generate response using the model."""
    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
        """Prepare the current and past messages for response generation."""
messages = state["messages"]
summarized_conversation = state.get("summarized_messages", messages)
docs_context = state.get("docs_context", "")
tind_metadata = state.get("tind_metadata", "")
model = self._model

if not model:
return {"messages": [AIMessage(content="Model not available.")]}

# Get the latest human message
latest_message = next(
(msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
None
)

if not latest_message:
if not any(isinstance(msg, HumanMessage) for msg in messages):
return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}

prompt = get_langfuse_prompt()
system_messages = prompt.invoke({'context': docs_context,
'question': latest_message.content})
system_messages = prompt.invoke({})

if hasattr(system_messages, "messages"):
all_messages = summarized_conversation + system_messages.messages
else:
all_messages = summarized_conversation + [system_messages]

return {"messages_for_generation": all_messages}

    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
        """Generate response using the model."""
        tind_metadata = state.get("tind_metadata", "")
        model = self._model
        documents = state.get("documents", [])
        messages = state["messages_for_generation"]

        if not model:
            return {"messages": [AIMessage(content="Model not available.")]}

        # Get response from model
        response = model.invoke(all_messages)
        response = model.invoke(
            messages,
Member:
To be completely clear: Cohere does the right thing here and uses the last provided HumanMessage as the question input? (is this documented anywhere?)

Contributor Author:
Not sure. I'll have to look in the docs. I would expect it to take the last HumanMessage as the query. (A defensive sketch follows this call.)

            additional_model_request_fields={"documents": documents}
        )
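If Cohere's handling of the question turns out to be undocumented, one defensive option is to make the ordering explicit before invoking the model. A hedged sketch (ensure_question_last is a hypothetical helper, not part of the PR):

from langchain_core.messages import AnyMessage, HumanMessage

def ensure_question_last(messages: list[AnyMessage]) -> list[AnyMessage]:
    # Move the most recent HumanMessage to the end of the list so the
    # model unambiguously sees it as the question.
    latest = next(
        (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
        None,
    )
    if latest is None or messages[-1] is latest:
        return messages
    return [msg for msg in messages if msg is not latest] + [latest]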

        # Create clean response content
        response_content = str(response.content) if hasattr(response, 'content') else str(response)

        response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                               ChatMessage(content=tind_metadata, role='TIND',
                                                           response_metadata={'tind': True})]