Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 108 additions & 67 deletions src/api/app/functions/chat_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ async def __create_query_embeddings(self, user_query: str):
query_embeddings = await self.embedding_client.aembed_query(user_query)
return query_embeddings

async def __execute_query(self, query: str):
async def __execute_query(self, query: str, *params):
"""
Executes a query on the database and returns the results.
"""
# Acquire a connection from the pool and execute the query
async with self.pool.acquire() as conn:
rows = await conn.fetch(query)
rows = await conn.fetch(query, *params)
return [dict(row) for row in rows]

async def __execute_graph_query(self, query: str):
async def __execute_graph_query(self, query: str, *params):
"""
Executes a query on the database and returns the results.
"""
Expand All @@ -31,24 +31,24 @@ async def __execute_graph_query(self, query: str):
# Execute a query to set the search path on the connection
await conn.execute('SET search_path = ag_catalog, "$user", public;')
# Execute the graph query
rows = await conn.fetch(query)
rows = await conn.fetch(query, *params)
return [dict(row) for row in rows]

async def __execute_scalar_query(self, query: str):
async def __execute_scalar_query(self, query: str, *params):
"""
Executes a scalar query on the database and returns the result.
"""
# Acquire a connection from the pool and execute the query
async with self.pool.acquire() as conn:
row = await conn.fetchrow(query)
row = await conn.fetchrow(query, *params)
return row

async def get_invoice_id(self, number: str) -> int:
"""
Retrieves the ID of a specific invoice by its number.
"""
query = f"SELECT id FROM invoices WHERE number = '{number}';"
row = await self.__execute_scalar_query(query)
query = "SELECT id FROM invoices WHERE number = $1;"
row = await self.__execute_scalar_query(query, number)
return row['id'] or None

async def get_invoice_line_items(self, invoice_id: int):
Expand All @@ -58,9 +58,9 @@ async def get_invoice_line_items(self, invoice_id: int):
# Define the columns to retrieve from the table
# Exclude the embedding column in results
columns = ["id", "invoice_id", "description", "amount", "status"]
query = f'SELECT {", ".join(columns)} FROM invoice_line_items WHERE invoice_id = {invoice_id};'
query = f'SELECT {", ".join(columns)} FROM invoice_line_items WHERE invoice_id = $1;'

rows = await self.__execute_query(query)
rows = await self.__execute_query(query, invoice_id)
return [dict(row) for row in rows]

async def get_invoice_validation_results(self, invoice_id: int = None):
Expand All @@ -72,11 +72,13 @@ async def get_invoice_validation_results(self, invoice_id: int = None):
# This excludes the embedding column in results
columns = ["invoice_id", "datestamp", "result", "validation_passed"]
query = f'SELECT {", ".join(columns)} FROM invoice_validation_results'
params = []
# Filter the validation results by invoice_id
if invoice_id is not None:
query += f' WHERE invoice_id = {invoice_id}'
query += ' WHERE invoice_id = $1'
params.append(invoice_id)

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def get_invoices(self, invoice_id: int = None, vendor_id: int = None, sow_id: int = None):
Expand All @@ -88,19 +90,30 @@ async def get_invoices(self, invoice_id: int = None, vendor_id: int = None, sow_
# This excludes a few columns that are large and not needed for the chat function
columns = ["id", "number", "vendor_id", "sow_id", "amount", "invoice_date", "payment_status"]
query = f'SELECT {", ".join(columns)} FROM invoices'

params = []
where_clauses = []
param_count = 1

# Filter the invoices by invoice_id, vendor_id or sow_id, if provided
if invoice_id is not None:
query += f' WHERE id = {invoice_id}'
where_clauses.append(f'id = ${param_count}')
params.append(invoice_id)
param_count += 1
else:
if vendor_id is not None:
query += f' WHERE vendor_id = {vendor_id}'
if sow_id is not None:
query += f' AND sow_id = {sow_id}'
elif sow_id is not None:
query += f' WHERE sow_id = {sow_id}'
where_clauses.append(f'vendor_id = ${param_count}')
params.append(vendor_id)
param_count += 1
if sow_id is not None:
where_clauses.append(f'sow_id = ${param_count}')
params.append(sow_id)
param_count += 1

if where_clauses:
query += ' WHERE ' + ' AND '.join(where_clauses)

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def get_unpaid_invoices_for_vendor(self, vendor_id: int):
Expand All @@ -109,20 +122,20 @@ async def get_unpaid_invoices_for_vendor(self, vendor_id: int):
"""
# Define the graph query
graph_query = f"""SELECT * FROM ag_catalog.cypher('vendor_graph', $$
MATCH (v:vendor {{id: '{vendor_id}'}})-[rel:has_invoices]->(s:sow)
MATCH (v:vendor {{id: $1}})-[rel:has_invoices]->(s:sow)
WHERE rel.payment_status <> 'Paid'
RETURN v.id AS vendor_id, v.name AS vendor_name, s.id AS sow_id, s.number AS sow_number, rel.id AS invoice_id, rel.number AS invoice_number, rel.payment_status AS payment_status
$$) as (vendor_id BIGINT, vendor_name TEXT, sow_id BIGINT, sow_number TEXT, invoice_id BIGINT, invoice_number TEXT, payment_status TEXT);
$$, $2) as (vendor_id BIGINT, vendor_name TEXT, sow_id BIGINT, sow_number TEXT, invoice_id BIGINT, invoice_number TEXT, payment_status TEXT);
"""
rows = await self.__execute_graph_query(graph_query)
rows = await self.__execute_graph_query(graph_query, str(vendor_id), f'[{{"id": "{vendor_id}"}}]')
return [dict(row) for row in rows]

async def get_sow_id(self, number: str) -> int:
"""
Retrieves the ID of a specific SOW by its number.
"""
query = f"SELECT id FROM sows WHERE number = '{number}';"
row = await self.__execute_scalar_query(query)
query = "SELECT id FROM sows WHERE number = $1;"
row = await self.__execute_scalar_query(query, number)
return row.get('id', None)

async def get_sow_chunks(self, sow_id: int):
Expand All @@ -134,17 +147,17 @@ async def get_sow_chunks(self, sow_id: int):
# Define the columns to retrieve from the table
# This excludes the embedding column in results
columns = ["id", "sow_id", "heading", "content", "page_number"]
query = f'SELECT {", ".join(columns)} FROM sow_chunks WHERE sow_id = {sow_id};'
query = f'SELECT {", ".join(columns)} FROM sow_chunks WHERE sow_id = $1;'

rows = await self.__execute_query(query)
rows = await self.__execute_query(query, sow_id)
return [dict(row) for row in rows]

async def get_sow_milestones(self, sow_id: int):
"""
Retrieves a list of milestones for a specific statement of work (SOW) by its ID.
"""
query = f'SELECT * FROM milestones WHERE sow_id = {sow_id};'
rows = await self.__execute_query(query)
query = 'SELECT * FROM milestones WHERE sow_id = $1;'
rows = await self.__execute_query(query, sow_id)
return [dict(row) for row in rows]

async def get_milestone_deliverables(self, milestone_id: int):
Expand All @@ -154,9 +167,9 @@ async def get_milestone_deliverables(self, milestone_id: int):
# Define the columns to retrieve from the table
# This excludes the embedding column in results
columns = ["id", "milestone_id", "description", "amount", "status", "due_date"]
query = f'SELECT {", ".join(columns)} FROM deliverables WHERE milestone_id = {milestone_id}'
query = f'SELECT {", ".join(columns)} FROM deliverables WHERE milestone_id = $1'

rows = await self.__execute_query(f'{query}')
rows = await self.__execute_query(query, milestone_id)
return [dict(row) for row in rows]

async def get_sow_validation_results(self, sow_id: int = None):
Expand All @@ -168,11 +181,13 @@ async def get_sow_validation_results(self, sow_id: int = None):
# This excludes the embedding column in results
columns = ["sow_id", "datestamp", "result", "validation_passed"]
query = f'SELECT {", ".join(columns)} FROM sow_validation_results'
params = []
# Filter the validation results by sow_id
if sow_id is not None:
query += f' WHERE sow_id = {sow_id}'
query += ' WHERE sow_id = $1'
params.append(sow_id)

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def get_sows(self, sow_id: int = None, vendor_id: int = None):
Expand All @@ -186,13 +201,25 @@ async def get_sows(self, sow_id: int = None, vendor_id: int = None):

# Build a SELECT query and JOIN from the tables and columns
query = f'SELECT {", ".join(columns)} FROM sows'

params = []
where_clauses = []
param_count = 1

# Filter the SOWs by vendor_id, if provided
if sow_id is not None:
query += f' WHERE id = {sow_id}'
where_clauses.append(f'id = ${param_count}')
params.append(sow_id)
param_count += 1
elif vendor_id is not None:
query += f' WHERE vendor_id = {vendor_id}'
where_clauses.append(f'vendor_id = ${param_count}')
params.append(vendor_id)
param_count += 1

rows = await self.__execute_query(f'{query};')
if where_clauses:
query += ' WHERE ' + ' AND '.join(where_clauses)

rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def get_vendors(self):
Expand All @@ -218,18 +245,20 @@ async def find_milestone_deliverables(self, user_query: str, sow_id: int = None)

# Use hybrid search to rank records, with exact matches ranked highest
columns.append(f"""CASE
WHEN description ILIKE '%{user_query}%' THEN 0
ELSE (embedding <=> '{query_embeddings}')::real
WHEN description ILIKE $1 THEN 0
ELSE (embedding <=> $2)::real
END as rank""")

query = f'SELECT {", ".join(columns)} FROM deliverables'
params = [f'%{user_query}%', query_embeddings]
# Filter the deliverables by sow_id, if provided
if sow_id is not None:
query += f' WHERE sow_id = {sow_id}'
query += ' WHERE sow_id = $3'
params.append(sow_id)

query += f' ORDER BY rank ASC'
query += ' ORDER BY rank ASC'

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def find_invoice_line_items(self, user_query: str, invoice_id: int = None):
Expand All @@ -246,18 +275,20 @@ async def find_invoice_line_items(self, user_query: str, invoice_id: int = None)

# Use hybrid search to rank records, with exact matches ranked highest
columns.append(f"""CASE
WHEN description ILIKE '%{user_query}%' THEN 0
ELSE (embedding <=> '{query_embeddings}')::real
WHEN description ILIKE $1 THEN 0
ELSE (embedding <=> $2)::real
END as rank""")

query = f'SELECT {", ".join(columns)} FROM invoice_line_items'
params = [f'%{user_query}%', query_embeddings]
# Filter the line items by invoice_id, if provided
if invoice_id is not None:
query += f' WHERE invoice_id = {invoice_id}'
query += ' WHERE invoice_id = $3'
params.append(invoice_id)

query += f' ORDER BY rank ASC'
query += ' ORDER BY rank ASC'

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def find_invoice_validation_results(self, user_query: str, invoice_id: int = None):
Expand All @@ -274,20 +305,22 @@ async def find_invoice_validation_results(self, user_query: str, invoice_id: int

# Use hybrid search to rank records, with exact matches ranked highest
columns.append(f"""CASE
WHEN result ILIKE '%{user_query}%' THEN 0
ELSE (embedding <=> '{query_embeddings}')::real
WHEN result ILIKE $1 THEN 0
ELSE (embedding <=> $2)::real
END as rank""")

query = f'SELECT {", ".join(columns)} FROM invoice_validation_results'

params = [f'%{user_query}%', query_embeddings]
# Filter by invoice_id, if provided
if invoice_id is not None:
query += f' WHERE invoice_id = {invoice_id}'
query += ' WHERE invoice_id = $3'
params.append(invoice_id)

# Order the results by rank
query += f' ORDER BY rank ASC'
query += ' ORDER BY rank ASC'

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def find_sow_chunks(self, user_query: str, sow_id: int = None):
Expand All @@ -303,17 +336,19 @@ async def find_sow_chunks(self, user_query: str, sow_id: int = None):

# Use hybrid search to rank records, with exact matches ranked highest
columns.append(f"""CASE
WHEN content ILIKE '%{user_query}%' THEN 0
ELSE (embedding <=> '{query_embeddings}')::real
WHEN content ILIKE $1 THEN 0
ELSE (embedding <=> $2)::real
END as rank""")

query = f'SELECT {", ".join(columns)} FROM sow_chunks'
params = [f'%{user_query}%', query_embeddings]
if sow_id is not None:
query += f' WHERE sow_id = {sow_id}'
query += ' WHERE sow_id = $3'
params.append(sow_id)

query += f' ORDER BY rank ASC'
query += ' ORDER BY rank ASC'

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]

async def find_sow_chunks_with_semantic_ranking(self, user_query: str, sow_id: int = None, max_results: int = 3):
Expand All @@ -325,23 +360,27 @@ async def find_sow_chunks_with_semantic_ranking(self, user_query: str, sow_id: i
query_embeddings = await self.__create_query_embeddings(user_query)

# Create a vector search query
cte_query = f"SELECT content FROM sow_chunks"
cte_query += f" WHERE sow_id = {sow_id}" if sow_id is not None else ""
cte_query += f" ORDER BY embedding <=> '{query_embeddings}'"
cte_query += f" LIMIT 10"
cte_query = "SELECT content FROM sow_chunks"
params = [query_embeddings]
if sow_id is not None:
cte_query += " WHERE sow_id = $2"
params.append(sow_id)

cte_query += " ORDER BY embedding <=> $1"
cte_query += " LIMIT 10"

# Create the semantic ranker query
query = f"""
WITH vector_results AS (
{cte_query}
)
SELECT content, relevance
FROM semantic_reranking('{user_query}', ARRAY(SELECT content from vector_results))
FROM semantic_reranking($1, ARRAY(SELECT content from vector_results))
ORDER BY relevance DESC
LIMIT {max_results};
"""

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(query, user_query, *params)
return [dict(row) for row in rows]

async def find_sow_validation_results(self, user_query: str, sow_id: int = None):
Expand All @@ -358,18 +397,20 @@ async def find_sow_validation_results(self, user_query: str, sow_id: int = None)

# Get the embeddings for the user query
columns.append(f"""CASE
WHEN result ILIKE '%{user_query}%' THEN 0
ELSE (embedding <=> '{query_embeddings}')::real
WHEN result ILIKE $1 THEN 0
ELSE (embedding <=> $2)::real
END as rank""")

# Use hybrid search to rank records, with exact matches ranked highest
columns.append(f"(embedding <=> '{query_embeddings}')::real as rank")
columns.append(f"(embedding <=> $2)::real as rank")

query = f'SELECT {", ".join(columns)} FROM sow_validation_results'
params = [f'%{user_query}%', query_embeddings]
if sow_id is not None:
query += f' WHERE sow_id = {sow_id}'
query += ' WHERE sow_id = $3'
params.append(sow_id)

query += f' ORDER BY rank ASC'
query += ' ORDER BY rank ASC'

rows = await self.__execute_query(f'{query};')
rows = await self.__execute_query(f'{query};', *params)
return [dict(row) for row in rows]