-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapp2.py
More file actions
215 lines (178 loc) · 6.33 KB
/
app2.py
File metadata and controls
215 lines (178 loc) · 6.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import json
import os
import re
from contextlib import contextmanager
from decimal import Decimal

import psycopg2
import requests
from dotenv import load_dotenv
from flask import Flask, jsonify, request
from flask_cors import CORS
# Populate os.environ from a local .env file; must happen before any os.getenv below.
load_dotenv()

# Flask application; CORS(app) is applied with its default settings.
app = Flask(__name__)
CORS(app)

# --- Gemini API configuration -------------------------------------------------
API_URL = (
    "https://generativelanguage.googleapis.com/v1beta/models/"
    "gemini-1.5-flash:generateContent"
)
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
headers = {"Content-Type": "application/json"}

# --- Database configuration ---------------------------------------------------
# Connection keyword -> environment variable that supplies its value.
_DB_ENV_VARS = {
    "dbname": "DB_NAME",
    "user": "DB_USER",
    "password": "DB_PASSWORD",
    "host": "DB_HOST",
    "port": "DB_PORT",
}
# Unset variables map to None.
DB_CONFIG = {keyword: os.getenv(env_var) for keyword, env_var in _DB_ENV_VARS.items()}
# PostgreSQL connection context manager
@contextmanager
def get_connection():
conn = None
cur = None
try:
conn = psycopg2.connect(os.getenv("DB_URI"))
cur = conn.cursor()
yield conn, cur
finally:
if cur:
cur.close()
if conn:
conn.close()
# Every table/view name in the "public" schema, in a stable order.
_SCHEMA_TABLES_SQL = """
    SELECT table_name
    FROM information_schema.tables
    WHERE table_schema = 'public'
    ORDER BY table_name
"""

# Column metadata for one table in the "public" schema (table name passed as a
# query parameter). The primary-key flag is resolved per column through
# key_column_usage, so only columns that belong to the PRIMARY KEY constraint
# are marked -- joining table_constraints on table name alone flagged every column.
_SCHEMA_COLUMNS_SQL = """
    SELECT
        c.column_name,
        c.data_type,
        c.is_nullable,
        c.column_default,
        CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' ELSE '' END AS key_type
    FROM information_schema.columns c
    LEFT JOIN (
        SELECT kcu.table_schema, kcu.table_name, kcu.column_name
        FROM information_schema.table_constraints tc
        JOIN information_schema.key_column_usage kcu
          ON kcu.constraint_schema = tc.constraint_schema
         AND kcu.constraint_name = tc.constraint_name
         AND kcu.table_name = tc.table_name
        WHERE tc.constraint_type = 'PRIMARY KEY'
    ) pk
      ON pk.table_schema = c.table_schema
     AND pk.table_name = c.table_name
     AND pk.column_name = c.column_name
    WHERE c.table_schema = 'public'
      AND c.table_name = %s
    ORDER BY c.ordinal_position
"""


def _format_create_table(table_name, columns):
    """Render one table's column rows as a ``CREATE TABLE`` statement string.

    ``columns`` is a sequence of ``(name, data_type, is_nullable, default,
    key_type)`` rows, i.e. the shape selected by ``_SCHEMA_COLUMNS_SQL``.
    """
    column_lines = []
    for name, data_type, nullable, default, key_type in columns:
        parts = [f" {name} {data_type}"]
        if key_type:
            parts.append(key_type)
        if nullable == "NO":
            parts.append("NOT NULL")
        if default:
            parts.append(f"DEFAULT {default}")
        column_lines.append(" ".join(parts))
    body = ",\n".join(column_lines)
    return f"CREATE TABLE {table_name} (\n{body}\n);\n\n"


def get_schema():
    """Describe every table in the ``public`` schema as ``CREATE TABLE`` text.

    The text is passed to ``get_sql_query`` so the model knows which tables and
    columns exist.

    Returns:
        str | None: the rendered schema, or None if any exception occurs (the
        error is printed).
    """
    try:
        with get_connection() as (_conn, cur):
            cur.execute(_SCHEMA_TABLES_SQL)
            table_names = [row[0] for row in cur.fetchall()]
            statements = []
            for table_name in table_names:
                # Bind the table name as a parameter instead of interpolating it:
                # a name containing a quote would otherwise break (or inject into)
                # the query.
                cur.execute(_SCHEMA_COLUMNS_SQL, (table_name,))
                statements.append(_format_create_table(table_name, cur.fetchall()))
        return "CREATE TABLE statements:\n\n" + "".join(statements)
    except Exception as e:
        print(f"Error fetching schema: {e}")
        return None
# Matches a Markdown code fence (```sql ...```, ``` ...```, ```SQL ...```) anywhere
# in the model's reply and captures its contents; a missing closing fence is
# tolerated by matching to the end of the text.
_CODE_FENCE_RE = re.compile(
    r"```(?:[\w+-]*\r?\n|sql\s+)?(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE
)

# Upper bound, in seconds, on waiting for the Gemini API so a stalled request
# cannot block the request handler indefinitely.
GEMINI_TIMEOUT_SECONDS = 30


def _strip_code_fence(text):
    """Return the stripped contents of the first code fence in *text*, else *text* stripped."""
    match = _CODE_FENCE_RE.search(text)
    return (match.group(1) if match else text).strip()


def get_sql_query(user_query, schema):
    """Ask Gemini to translate a natural-language question into one SQL query.

    Args:
        user_query: the user's question in plain language.
        schema: ``CREATE TABLE`` text describing the database.

    Returns:
        str | None: the generated SQL, always terminated with ``;``, or None if
        the API call fails, returns a non-200 status, or yields no usable text
        (details are printed).
    """
    try:
        prompt = f"""### PostgreSQL database schema:
{schema}
### Instructions:
- Generate a single, correct SQL query that answers the user's question.
- Include appropriate table aliases, JOIN conditions, and WHERE clauses.
- Provide the SQL query only, with no additional explanations or questions.
### User Query: {user_query}
### SQL Query:"""
        payload = {"contents": [{"parts": [{"text": prompt}]}]}
        response = requests.post(
            API_URL,
            # Send the key as a header rather than in the URL query string so it
            # does not end up in access logs or exception messages.
            headers={**headers, "x-goog-api-key": GEMINI_API_KEY},
            json=payload,
            timeout=GEMINI_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            print(f"Error from API: {response.text}")
            return None

        response_data = response.json()
        candidates = response_data.get("candidates") or []
        if not candidates:
            print("No candidates in response")
            return None

        # Don't assume every candidate carries content.parts[0].text; walk the
        # structure defensively instead of relying on a KeyError.
        parts = candidates[0].get("content", {}).get("parts") or []
        raw_text = parts[0].get("text", "") if parts else ""

        # Models often wrap code in a Markdown fence (with or without a language
        # tag, sometimes after a sentence of prose); keep only the SQL inside.
        sql_query = _strip_code_fence(raw_text)
        if not sql_query:
            print("Empty SQL in response")
            return None
        if not sql_query.endswith(";"):
            sql_query += ";"
        return sql_query
    except Exception as e:
        print(f"Error generating SQL: {e}")
        return None
def execute_query(sql_query):
    """Run *sql_query* and return its rows, or a status message if it returns none.

    Whether the statement produced a result set is decided from
    ``cursor.description`` (None when there is nothing to fetch) rather than from
    the statement text. As a result, ``WITH ... SELECT``, ``VALUES``, ``EXPLAIN``
    and ``... RETURNING`` statements also return their rows.

    Returns:
        list[dict] | dict | None: one dict per row mapping column name to value
        (``Decimal`` converted to ``float``; if two columns share a name, the
        later one wins); ``{"message": ...}`` for statements without a result
        set; None if anything raises (the error is printed).
    """
    # NOTE(review): sql_query is LLM output derived from user text and is executed
    # verbatim and committed, so any statement the model emits (DELETE, DROP, ...)
    # will run. Consider a read-only DB role or a READ ONLY transaction if writes
    # are not intended.
    try:
        with get_connection() as (conn, cur):
            cur.execute(sql_query)
            if cur.description is None:
                conn.commit()
                return {"message": "Query executed successfully"}

            column_names = [desc[0] for desc in cur.description]
            rows = cur.fetchall()
            # Commit as well: a row-returning statement may still modify data
            # (e.g. INSERT ... RETURNING), and those writes must be persisted.
            conn.commit()
            return [
                {
                    name: float(value) if isinstance(value, Decimal) else value
                    for name, value in zip(column_names, row)
                }
                for row in rows
            ]
    except Exception as e:
        print(f"Error executing query: {e}")
        return None
@app.route("/query", methods=["POST"])
def query():
    """Answer a natural-language question about the database.

    Expects a JSON object body ``{"query": "<question>"}``. The handler fetches
    the schema, asks Gemini for SQL, runs it, and responds with
    ``{"sql_query": ..., "results": ...}``.

    Error responses:
        400 if the body is not a JSON object with a non-blank string ``query``;
        500 if schema retrieval, SQL generation or execution fails, or on any
        other unexpected exception.
    """
    try:
        # silent=True yields None instead of raising on a missing/invalid JSON
        # body, so malformed requests get a 400 instead of reaching the 500 handler.
        data = request.get_json(silent=True)
        user_query = data.get("query") if isinstance(data, dict) else None
        if not isinstance(user_query, str) or not user_query.strip():
            return jsonify({"error": "Query is required"}), 400

        schema = get_schema()
        if not schema:
            return jsonify({"error": "Failed to fetch schema"}), 500

        sql_query = get_sql_query(user_query, schema)
        if not sql_query:
            return jsonify({"error": "Failed to generate SQL query"}), 500

        results = execute_query(sql_query)
        if results is None:
            return jsonify({"error": "Failed to execute query"}), 500

        return jsonify({"sql_query": sql_query, "results": results})
    except Exception as e:
        # Last-resort boundary, e.g. for result values jsonify cannot serialize.
        return jsonify({"error": str(e)}), 500
@app.route("/health", methods=["GET"])
def health():
    """Report that the service is running; touches neither the database nor Gemini."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)
if __name__ == "__main__":
    # Start the Flask development server when this file is executed directly.
    # Debug mode enables the interactive Werkzeug debugger, which allows arbitrary
    # code execution; combined with host="0.0.0.0" that is reachable from other
    # machines. It is therefore opt-in via FLASK_DEBUG instead of hard-coded on.
    debug_enabled = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    # PORT overrides the listening port; defaults to the previous fixed 4000.
    port = int(os.getenv("PORT", "4000"))
    app.run(debug=debug_enabled, host="0.0.0.0", port=port)