# investigating separating out documents from the rest of the message h… #95
Base: `main`. Changes from all commits.
```diff
@@ -19,10 +19,10 @@ class WillaChatbotState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
     filtered_messages: NotRequired[list[AnyMessage]]
     summarized_messages: NotRequired[list[AnyMessage]]
-    docs_context: NotRequired[str]
+    messages_for_generation: NotRequired[list[AnyMessage]]
     search_query: NotRequired[str]
    tind_metadata: NotRequired[str]
+    context: NotRequired[dict[str, Any]]
+    documents: NotRequired[list[Any]]


 class GraphManager: # pylint: disable=too-few-public-methods
```

*jason-raitz marked this conversation as resolved.*
```diff
@@ -51,13 +51,15 @@ def _create_workflow(self) -> CompiledStateGraph:
         workflow.add_node("summarize", summarization_node)
         workflow.add_node("prepare_search", self._prepare_search_query)
         workflow.add_node("retrieve_context", self._retrieve_context)
+        workflow.add_node("prepare_for_generation", self._prepare_for_generation)
         workflow.add_node("generate_response", self._generate_response)
 
         # Define edges
         workflow.add_edge("filter_messages", "summarize")
         workflow.add_edge("summarize", "prepare_search")
         workflow.add_edge("prepare_search", "retrieve_context")
-        workflow.add_edge("retrieve_context", "generate_response")
+        workflow.add_edge("retrieve_context", "prepare_for_generation")
+        workflow.add_edge("prepare_for_generation", "generate_response")
 
         workflow.set_entry_point("filter_messages")
         workflow.set_finish_point("generate_response")
```
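The new wiring inserts `prepare_for_generation` between retrieval and generation. As a toy illustration of how langgraph threads each node's partial state update through a linear pipeline like this one, here is a runnable miniature with stub nodes; `DemoState` and the stub functions are invented stand-ins, not code from this PR:

```python
from typing import NotRequired, TypedDict

from langgraph.graph import StateGraph


class DemoState(TypedDict):
    """Invented stand-in for WillaChatbotState."""
    messages: list[str]
    search_query: NotRequired[str]
    messages_for_generation: NotRequired[list[str]]


def prepare_search(state: DemoState) -> dict:
    # Each node returns only the keys it updates; langgraph merges them in.
    return {"search_query": "\n".join(state["messages"])}


def prepare_for_generation(state: DemoState) -> dict:
    return {"messages_for_generation": state["messages"] + ["<system prompt>"]}


def generate_response(state: DemoState) -> dict:
    return {"messages": state["messages"] + [f"answer to: {state['search_query']}"]}


workflow = StateGraph(DemoState)
workflow.add_node("prepare_search", prepare_search)
workflow.add_node("prepare_for_generation", prepare_for_generation)
workflow.add_node("generate_response", generate_response)
workflow.add_edge("prepare_search", "prepare_for_generation")
workflow.add_edge("prepare_for_generation", "generate_response")
workflow.set_entry_point("prepare_search")
workflow.set_finish_point("generate_response")

graph = workflow.compile()
print(graph.invoke({"messages": ["hello"]}))
```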
```diff
@@ -68,7 +70,10 @@ def _filter_messages(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
         """Filter out TIND messages from the conversation history."""
         messages = state["messages"]
 
-        filtered = [msg for msg in messages if 'tind' not in msg.response_metadata]
+        filtered = [
+            msg for msg in messages
+            if 'tind' not in msg.response_metadata and msg.type != "system"
+        ]
         return {"filtered_messages": filtered}
 
     def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
```
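The added `msg.type != "system"` clause also keeps system messages out of the filtered history. A quick self-contained illustration (the sample messages are invented):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are Willa, an archives chatbot."),  # dropped: type == "system"
    HumanMessage(content="Who founded the archive?"),              # kept
    AIMessage(content="It was founded in 1952."),                  # kept
]

filtered = [
    msg for msg in messages
    if 'tind' not in msg.response_metadata and msg.type != "system"
]
assert [m.type for m in filtered] == ["human", "ai"]
```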
```diff
@@ -79,60 +84,78 @@ def _prepare_search_query(self, state: WillaChatbotState) -> dict[str, str]:
         # summarization may include a system message as well as any human or ai messages
         search_query = '\n'.join(str(msg.content) for msg in messages if hasattr(msg, 'content'))
 
+        # if summarization fails or some other issue, truncate to the last 2048 characters
```

> **Member:** What are the conditions under which summarization can fail?
>
> **Contributor (author):** When I added this line, summarization was failing. It's possible when not properly filtering what we send to summarization, and possibly with the length of our prompts.
```diff
+        if len(search_query) > 2048:
+            search_query = search_query[-2048:]
 
         return {"search_query": search_query}
```
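Because the query string is joined oldest to newest, slicing from the end keeps the most recent turns, which are the ones most relevant to retrieval. A small runnable check (sample strings invented):

```python
# Tail truncation preserves the newest content of the joined query.
search_query = "\n".join(["old turn " * 300, "newest question about TIND records"])
if len(search_query) > 2048:
    search_query = search_query[-2048:]
assert len(search_query) == 2048
assert search_query.endswith("newest question about TIND records")
```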
```diff
-    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str]:
+    def _retrieve_context(self, state: WillaChatbotState) -> dict[str, str | list[Any]]:
```

> **Member:** Are we sure we want `list[Any]` here? Is the intention that `documents` will contain either a list of `dict`s or an empty list?
>
> **Contributor (author):** Good point. That is indeed the expectation. I'll change it.
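The agreed change is not in this diff; one plausible shape for it, sketched here purely as an assumption, is a `TypedDict` describing the formatted document records:

```python
# Hypothetical tightening of the annotation (not part of this PR).
from typing import TypedDict


class FormattedDocument(TypedDict):
    id: str
    page_content: str
    title: str
    project: str
    tind_link: str

# _retrieve_context could then return dict[str, str | list[FormattedDocument]],
# and the state field could become documents: NotRequired[list[FormattedDocument]].
```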
```diff
         """Retrieve relevant context from vector store."""
         search_query = state.get("search_query", "")
         vector_store = self._vector_store
 
         if not search_query or not vector_store:
-            return {"docs_context": "", "tind_metadata": ""}
+            return {"tind_metadata": "", "documents": []}
 
         # Search for relevant documents
         retriever = vector_store.as_retriever(search_kwargs={"k": int(CONFIG['K_VALUE'])})
         matching_docs = retriever.invoke(search_query)
```
```diff
         # Format context and metadata
-        docs_context = '\n\n'.join(doc.page_content for doc in matching_docs)
+        formatted_documents = [
+            {
+                "id": f"{i}_{doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0]}",
+                "page_content": doc.page_content,
+                "title": doc.metadata.get('tind_metadata', {}).get('title', [''])[0],
+                "project": doc.metadata.get('tind_metadata', {}).get('isPartOf', [''])[0],
+                "tind_link": format_tind_context.get_tind_url(
+                    doc.metadata.get('tind_metadata', {}).get('tind_id', [''])[0])
+            }
+            for i, doc in enumerate(matching_docs, 1)
+        ]
```

> **Member** (on lines +105 to +114): This could be simplified by having something like … before the …
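The reviewer's suggestion is cut off in the capture. One plausible reading, offered only as an assumption, is hoisting the repeated `tind_metadata` lookup into a helper defined before the comprehension; this sketch assumes the module's existing `format_tind_context` import and the `matching_docs` variable from above:

```python
def _format_doc(i: int, doc) -> dict:
    # Look up the TIND metadata once instead of four times per document.
    tind_md = doc.metadata.get('tind_metadata', {})
    tind_id = tind_md.get('tind_id', [''])[0]
    return {
        "id": f"{i}_{tind_id}",
        "page_content": doc.page_content,
        "title": tind_md.get('title', [''])[0],
        "project": tind_md.get('isPartOf', [''])[0],
        "tind_link": format_tind_context.get_tind_url(tind_id),
    }

formatted_documents = [_format_doc(i, doc) for i, doc in enumerate(matching_docs, 1)]
```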
```diff
         # Format tind metadata
         tind_metadata = format_tind_context.get_tind_context(matching_docs)
 
-        return {"docs_context": docs_context, "tind_metadata": tind_metadata}
+        return {"tind_metadata": tind_metadata, "documents": formatted_documents}
```
```diff
-    # This should be refactored probably. Very bulky
-    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
-        """Generate response using the model."""
+    def _prepare_for_generation(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Prepare the current and past messages for response generation."""
         messages = state["messages"]
         summarized_conversation = state.get("summarized_messages", messages)
-        docs_context = state.get("docs_context", "")
-        tind_metadata = state.get("tind_metadata", "")
-        model = self._model
-
-        if not model:
-            return {"messages": [AIMessage(content="Model not available.")]}
-
-        # Get the latest human message
-        latest_message = next(
-            (msg for msg in reversed(messages) if isinstance(msg, HumanMessage)),
-            None
-        )
-
-        if not latest_message:
+        if not any(isinstance(msg, HumanMessage) for msg in messages):
             return {"messages": [AIMessage(content="I'm sorry, I didn't receive a question.")]}
 
         prompt = get_langfuse_prompt()
-        system_messages = prompt.invoke({'context': docs_context,
-                                         'question': latest_message.content})
+        system_messages = prompt.invoke({})
```
*anarchivist marked this conversation as resolved.*
```diff
         if hasattr(system_messages, "messages"):
             all_messages = summarized_conversation + system_messages.messages
         else:
             all_messages = summarized_conversation + [system_messages]
 
+        return {"messages_for_generation": all_messages}
```
```diff
+    def _generate_response(self, state: WillaChatbotState) -> dict[str, list[AnyMessage]]:
+        """Generate response using the model."""
+        tind_metadata = state.get("tind_metadata", "")
+        model = self._model
+        documents = state.get("documents", [])
+        messages = state["messages_for_generation"]
+
+        if not model:
+            return {"messages": [AIMessage(content="Model not available.")]}
+
         # Get response from model
-        response = model.invoke(all_messages)
+        response = model.invoke(
+            messages,
+            additional_model_request_fields={"documents": documents}
+        )
```

> **Member:** To be completely clear: Cohere does the right thing here and uses the last provided …
>
> **Contributor (author):** Not sure. I'll have to look in the docs. I would expect it to take the last HumanMessage as the query.
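For reference, each entry that `_retrieve_context` builds and `_generate_response` forwards to the model has this shape; the field values below are invented for illustration, and the URL is hypothetical:

```python
# One formatted document, matching the comprehension in _retrieve_context.
# Cohere on Bedrock receives the list through additional_model_request_fields
# and can ground its answer against these records.
documents = [
    {
        "id": "1_12345",
        "page_content": "Letter from the provost, 1952 ...",
        "title": "Provost correspondence",
        "project": "University Archives",
        "tind_link": "https://example.tind.io/record/12345",
    }
]
```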
```diff
         # Create clean response content
         response_content = str(response.content) if hasattr(response, 'content') else str(response)
 
         response_messages: list[AnyMessage] = [AIMessage(content=response_content),
                                                ChatMessage(content=tind_metadata, role='TIND',
                                                            response_metadata={'tind': True})]
```
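The `{'tind': True}` marker set here is the same flag `_filter_messages` checks on the next turn, so the TIND metadata message is shown to the user but never fed back into summarization or generation. A minimal runnable check (message content invented):

```python
from langchain_core.messages import ChatMessage

tind_msg = ChatMessage(content="Source records: ...", role="TIND",
                       response_metadata={"tind": True})

# The predicate from _filter_messages excludes it.
filtered = [m for m in [tind_msg]
            if "tind" not in m.response_metadata and m.type != "system"]
assert filtered == []
```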