-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat.py
More file actions
257 lines (231 loc) · 8.5 KB
/
chat.py
File metadata and controls
257 lines (231 loc) · 8.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import argparse
import json
import torch
import transformers
from transformers import TextStreamer
from utils import extract_sql, extract_think
from utils.constraints import (
COLORED_BLUE, COLORED_GREEN, COLORED_RESET,
REASONING_START, REASONING_END, SOLUTION_START, SOLUTION_END
)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for the SQL chat program."""
    cli = argparse.ArgumentParser(
        prog="uv run chat.py",
        description=(
            "Chatbot for SQL generation and reasoning.\n"
            "This program uses a language model to generate SQL queries based on user prompts.\n"
            "It supports multiple database backends (DuckDB, SQLite, PostgreSQL) and allows for SQL execution.\n"
        ),
        epilog="Developed by Protons · GitHub: https://github.com/prodesk98",
    )
    cli.add_argument(
        "--model",
        type=str,
        default="proton98/sql-llama3.2-3b-it-reasoning",
        help="Model name. Default is 'proton98/sql-llama3.2-3b-it-reasoning'.",
    )
    cli.add_argument(
        "--max-new-tokens",
        type=int,
        default=1024,
        help="Maximum number of new tokens to generate. Default is 1024.",
    )
    cli.add_argument(
        "--db-uri",
        type=str,
        default=":memory:",
        help="Database URI. Default is ':memory:'.",
    )
    cli.add_argument(
        "--db-driver",
        type=str,
        default="duckdb",
        choices=["duckdb", "sqlite", "postgresql"],
        help="Database driver to use. Default is 'duckdb'.",
    )
    return cli


parser = _build_arg_parser()
args = parser.parse_args()
# Open the database connection for the selected backend. Driver modules are
# imported lazily so users only need the dependency for the backend they use.
if args.db_driver == "duckdb":
    try:
        import duckdb
    except ImportError as e:
        raise ImportError(
            "DuckDB driver is not installed. Please install it using 'pip install duckdb'."
        ) from e
    db = duckdb.connect(args.db_uri)
elif args.db_driver == "sqlite":
    # sqlite3 ships with the CPython standard library, so the import only fails
    # on stripped-down Python builds compiled without SQLite support.
    # (The old message suggesting 'pip install sqlite3' was wrong: no such
    # installable package provides the stdlib module.)
    try:
        import sqlite3
    except ImportError as e:
        raise ImportError(
            "SQLite support is missing from this Python build (the stdlib 'sqlite3' module could not be imported)."
        ) from e
    db = sqlite3.connect(args.db_uri)
elif args.db_driver == "postgresql":
    try:
        import psycopg2
    except ImportError as e:
        raise ImportError(
            "PostgreSQL driver is not installed. Please install it using 'pip install psycopg2'."
        ) from e
    db = psycopg2.connect(args.db_uri)
else:
    # Defensive: unreachable in practice because argparse 'choices' already
    # restricts --db-driver to the three supported values.
    raise ValueError(
        f"Unsupported database driver: {args.db_driver}. Supported drivers are 'duckdb', 'sqlite' and 'postgresql'."
    )
def get_context_schemas() -> dict | None:
    """
    Introspect the connected database and return its full schema.

    The result maps each table name to a list of column descriptors:
    ``{"name", "type"}`` for DuckDB, plus a ``"primary_key"`` flag for
    SQLite and PostgreSQL (where the metadata is readily available).

    :return: Structured schema dictionary, or None if introspection fails
             (the error is printed, not raised).
    """
    try:
        schema: dict[str, list[dict]] = {}
        if args.db_driver == "duckdb":
            tables = db.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'").fetchall()  # noqa
            for (table,) in tables:
                # Parameterized query: never interpolate identifiers into SQL.
                columns = db.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", [table]).fetchall()  # noqa
                schema[table] = [{"name": col, "type": dtype} for col, dtype in columns]
        elif args.db_driver == "sqlite":
            tables = db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'").fetchall()  # noqa
            for (table,) in tables:
                # PRAGMA statements cannot take bound parameters; the table
                # name comes straight from sqlite_master, not from user input.
                columns = db.execute(f"PRAGMA table_info('{table}')").fetchall()
                # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                schema[table] = [{
                    "name": col[1],
                    "type": col[2],
                    "primary_key": bool(col[5])
                } for col in columns]
        elif args.db_driver == "postgresql":
            cursor = db.cursor()
            cursor.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")  # noqa
            tables = cursor.fetchall()
            for (table,) in tables:
                cursor.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (table,))  # noqa
                columns = cursor.fetchall()
                # Primary keys via pg_index; the text parameter is cast to
                # regclass server-side instead of being spliced into the SQL.
                cursor.execute("SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) WHERE i.indrelid = %s::regclass AND i.indisprimary", (table,))  # noqa
                pk_columns = {col[0] for col in cursor.fetchall()}
                schema[table] = [{
                    "name": col[0],
                    "type": col[1],
                    "primary_key": col[0] in pk_columns
                } for col in columns]
        else:
            raise ValueError(
                f"Unsupported database driver: {args.db_driver}. Supported drivers are 'duckdb', 'sqlite' and 'postgresql'."
            )
        return schema
    except Exception as e:  # noqa
        print(f"Error getting context schema: {e}")
        return None
def execute_sql(sql_query: str):
    """
    Execute the given SQL query on the active connection and return all rows.

    :param sql_query: SQL statement to run.
    :return: List of result rows, or None if execution fails
             (the error is printed, not raised).
    """
    try:
        if args.db_driver in ("duckdb", "sqlite"):
            # Both connection objects expose execute() returning a cursor-like
            # object that supports fetchall() directly.
            return db.execute(sql_query).fetchall()
        elif args.db_driver == "postgresql":
            # Bug fix: psycopg2's cursor.execute() returns None, so the old
            # chained call db.cursor().execute(...).fetchall() always raised
            # AttributeError. Keep the cursor and fetch from it.
            cursor = db.cursor()
            cursor.execute(sql_query)
            return cursor.fetchall()
        else:
            raise ValueError(
                f"Unsupported database driver: {args.db_driver}. Supported drivers are 'duckdb', 'sqlite' and 'postgresql'."
            )
    except Exception as e:  # noqa
        print(f"Error executing SQL: {e}")
        return None
# Load the tokenizer for the selected model; left padding is set, which
# decoder-only generation expects when batching.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    args.model,
    use_fast=True,
    padding_side="left",
    trust_remote_code=True,
)
# Stream generated tokens to stdout as they arrive, without echoing the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)
# Text-generation pipeline in bfloat16, with devices chosen automatically
# ("auto" device map). pad_token_id is pinned to EOS to silence the
# missing-pad-token warning during generation.
pipeline = transformers.pipeline(
    "text-generation",
    model=args.model,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    streamer=streamer,
    pad_token_id=tokenizer.eos_token_id,
)
last_prompt: str | None = None  # most recent prompt actually sent to the model
last_sql: str | None = None     # most recent SQL extracted from model output

# Interactive REPL: read a prompt, dispatch slash-commands, otherwise ask the
# model for a reasoned SQL answer and remember it for /retry and /execute.
while True:
    try:
        # Refresh the schema context every turn so DDL changes are picked up.
        context: str | None = json.dumps(get_context_schemas())
        prompt = input("Prompt: ")
        command = prompt.lower()
        if command == "/exit":
            break
        elif command == "/retry":
            if last_prompt:
                prompt = last_prompt
            else:
                print("No last prompt found.")
                # Bug fix: without this continue, the literal string "/retry"
                # fell through and was sent to the model as a prompt.
                continue
        elif command == "/execute":
            if last_sql:
                result = execute_sql(last_sql)
                print(f"SQL Result: {result}")
            else:
                print("No SQL query found to execute.")
            continue
        elif command in ("help", "/help"):
            print(
                "Available commands:\n"
                "/exit - Exit the program\n"
                "/retry - Retry the last prompt\n"
                "/execute - Execute the last SQL query\n"
                "/context - Show the context of the database\n"
                "/clear - Clear the context\n"
                "/last - Show the last SQL query\n"
                "/help - Show this help message"
            )
            continue
        elif command == "/context":
            print(context)
            continue
        elif command == "/clear":
            print("Clearing context...")
            db.execute("DROP TABLE IF EXISTS context")  # noqa
            continue
        elif command == "/last":
            if last_sql:
                print(f"Last SQL: {last_sql}")
            else:
                print("No last SQL found.")
            continue
        # Normal turn: build the chat messages with the schema context and the
        # reasoning/solution tag protocol the model was trained on.
        messages = [
            {
                "role": "system",
                "content": f"You are an expert in writing optimized SQL queries.\n"
                           f"Think about the problem and provide your working out.\n"
                           f"Place it between {REASONING_START} and {REASONING_END}.\n"
                           f"Then, provide your solution between {SOLUTION_START}{SOLUTION_END}\n\n"
                           f"Context: {context}"
            },
            {"role": "user", "content": prompt},
        ]
        outputs = pipeline(
            messages,
            max_new_tokens=args.max_new_tokens,
        )
        # The pipeline returns the full chat; the last message is the reply.
        content = outputs[0]["generated_text"][-1]['content']
        sql = extract_sql(content)
        think = extract_think(content)
        if sql is None:
            print("No SQL query found in the output.")
            continue
        print(
            f"{COLORED_BLUE}{think}{COLORED_RESET}\n"
            f"{COLORED_GREEN}{sql}{COLORED_RESET}"
        )
        # Only remember state for turns that produced usable SQL.
        last_prompt = prompt
        last_sql = sql
    except KeyboardInterrupt:
        print("Exiting...")
        break
    except Exception as e:
        print(f"Error: {e}")
        continue