-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathagent.py
More file actions
79 lines (62 loc) · 2.68 KB
/
agent.py
File metadata and controls
79 lines (62 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Define the SQLAgent class, which wraps a LangChain SQL agent.

Importing this module has side effects: it loads ``.env`` into
``os.environ`` and installs an SQLite-backed LLM cache at
``data/.langchain.db``.
"""
import os
from config import EXAMPLE, Config
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_core.runnables import RunnableConfig
from langchain.cache import SQLiteCache
from langchain.globals import set_llm_cache
from typing import Any, Dict, Optional
from agent_utils.extra_tools import visualize_data
def set_env_vars(env_file_path='.env'):
    """Read ``KEY=value`` lines from a dotenv-style file into ``os.environ``.

    Blank lines and lines starting with ``#`` are skipped. Whitespace around
    the key and the value is stripped. Only the first ``=`` splits a line,
    so values may themselves contain ``=``.

    Args:
        env_file_path (str): Path of the file to read. Defaults to ``'.env'``.

    Raises:
        FileNotFoundError: If ``env_file_path`` does not exist.
        ValueError: If a line has no ``=``, has an empty key, or has an
            empty value.
    """
    with open(env_file_path, 'r', encoding='utf-8') as file:
        for line_number, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            # Blank lines and comments are normal in .env files; skip them
            # instead of failing on the key/value split below.
            if not line or line.startswith('#'):
                continue
            key, sep, value = line.partition('=')
            key, value = key.strip(), value.strip()
            if not sep or not key:
                raise ValueError(
                    f"Malformed line {line_number} in {env_file_path}: "
                    "expected KEY=value.\n Please double check the .env file.")
            if not value:
                raise ValueError(
                    f"Invalid value for environment variable: {key}\n Please double check the .env file.")
            os.environ[key] = value
# Import-time setup: copy the key=value pairs from .env into os.environ.
# open() raises FileNotFoundError here if .env is absent, so the module
# cannot be imported without it.
set_env_vars('.env')
# Cache LLM responses in a local SQLite database.
# NOTE(review): assumes the "data/" directory already exists relative to the
# current working directory — confirm SQLiteCache does not create it.
set_llm_cache(SQLiteCache(database_path="data/.langchain.db"))
class SQLAgent:
    """
    Agent to convert user queries into SQL queries and process results.

    Attributes:
        agent_executor: The executor returned by ``create_sql_agent``. It is
            built with ``return_intermediate_steps`` and
            ``handle_parsing_errors`` enabled.
    """

    def __init__(self, config: Config, *, max_iterations: int = 10):
        """
        Initializes the SQLAgent.

        Args:
            config (Config): The configuration for processing queries. Its
                ``llm``, ``sql_connector``, ``agent_type``, ``prompt`` and
                ``enable_chart`` attributes are read here.
            max_iterations (int): Maximum number of agent iterations, passed
                to ``create_sql_agent``. Defaults to 10, the value that was
                previously hard-coded.
        """
        # The chart tool is only offered to the agent when charts are enabled.
        extra_tools = [visualize_data] if config.enable_chart else []
        self.agent_executor = create_sql_agent(
            llm=config.llm,
            toolkit=config.sql_connector,
            agent_type=config.agent_type,
            max_iterations=max_iterations,
            prompt=config.prompt,
            verbose=True,
            extra_tools=extra_tools,
            agent_executor_kwargs={
                # Include the agent's intermediate steps in the invoke() result.
                "return_intermediate_steps": True,
                # NOTE: per LangChain's AgentExecutor docs, True reports LLM
                # output-parsing errors back to the agent instead of raising.
                "handle_parsing_errors": True,
            })

    def run(self, user_query: str,
            st_callback: Optional[StreamlitCallbackHandler] = None) -> Dict[str, Any]:
        """
        Processes a user query with the SQL agent.

        Args:
            user_query (str): The user query to process.
            st_callback (Optional[StreamlitCallbackHandler]): If given, it is
                attached via ``RunnableConfig`` as the only callback for this
                invocation.

        Returns:
            Dict[str, Any]: The mapping returned by ``agent_executor.invoke``.
            Because ``__init__`` enables ``return_intermediate_steps``, it is
            expected to also carry the intermediate steps.
        """
        if st_callback is None:
            return self.agent_executor.invoke(user_query)
        return self.agent_executor.invoke(
            user_query, config=RunnableConfig(callbacks=[st_callback]))
# Demo entry point: build an agent from
# Config.create_custom_openai_custom_sqllite_with_chart() and run
# EXAMPLE.query through it once.
if __name__ == "__main__":
    sql_agent = SQLAgent(
        config=Config.create_custom_openai_custom_sqllite_with_chart())
    sql_agent.run(EXAMPLE.query)