-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathdbSQLAlchemy.py
More file actions
115 lines (89 loc) · 3.96 KB
/
dbSQLAlchemy.py
File metadata and controls
115 lines (89 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from sqlalchemy import Column, String, DateTime, create_engine, Integer, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from datetime import datetime
# Connect to the SQLite database file (it is created if it does not exist).
# echo=True makes the engine log the SQL statements it emits.
_DATABASE_URL = "sqlite:///my_database.db"
engine = create_engine(_DATABASE_URL, echo=True)

# Declarative base class that the ORM models below inherit from.
Base = declarative_base()
class User(Base):
    """ORM model mapped to the ``users`` table."""

    __tablename__ = "users"

    # The username itself is the primary key.
    username = Column(String, primary_key=True)
    # Session identifier stored for the user; must always be present.
    session_id = Column(String, nullable=False)
    # Defaults to the current naive-UTC time when a row is inserted.
    created_at = Column(DateTime, default=datetime.utcnow)
class ChatMessage(Base):
    """ORM model mapped to the ``chat_messages`` table."""

    __tablename__ = "chat_messages"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owner of the message; references ``users.username``.
    username = Column(
        String,
        ForeignKey("users.username"),
        nullable=False,
    )
    # Who produced the message (free-form string, required).
    role = Column(String, nullable=False)
    # The message text (required).
    content = Column(String, nullable=False)
    # Defaults to the current naive-UTC time when a row is inserted.
    created_at = Column(DateTime, default=datetime.utcnow)
class UserPreferences(Base):
    """ORM model mapped to the ``user_preferences`` table.

    The primary key is the username, so each user has at most one row.
    """

    __tablename__ = "user_preferences"

    # Primary key and foreign key to ``users.username``.
    username = Column(String, ForeignKey("users.username"), primary_key=True)
    # Serialized preferences (presumably a JSON document, going by the
    # name -- the column itself is a plain required string).
    preferences_json = Column(String, nullable=False)
    # Set on insert and refreshed on every UPDATE of the row. Previously only
    # ``default`` was given, so the timestamp never moved after creation.
    # ``onupdate`` is applied by SQLAlchemy when it emits the UPDATE, so no
    # schema change is needed for existing databases.
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
    )
    # Optional profile fields.
    age = Column(Integer)
    gender = Column(String)
    # Optional text stored for the user (written by update_user_pdf_text).
    pdf_text = Column(String, nullable=True)
# Create the tables declared on Base in the database bound to the engine.
Base.metadata.create_all(bind=engine)

# Session factory used to interact with the database. Sessions neither
# autocommit nor autoflush, so callers commit explicitly.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def get_db():
    """Yield a new SQLAlchemy session and close it when the caller is done.

    Meant to be used as a FastAPI dependency so that each request gets its
    own session.

    Background: an earlier version kept one module-level
    ``session = SessionLocal()`` shared by all requests. A single failed
    transaction (for example an HTTPException raised in the middle of a
    write) left that session unusable, and every later request failed with
    'This transaction is closed'. Creating one session per request avoids
    that problem.
    """
    request_session = SessionLocal()
    try:
        yield request_session
    finally:
        # Runs whether the request succeeded or raised.
        request_session.close()
# CRUD helpers
def get_user(db_session, username: str):
    """Return the ``User`` with the given username, or ``None`` if absent."""
    matching_users = db_session.query(User).filter_by(username=username)
    return matching_users.first()
def create_user(db_session, username: str, session_id: str):
    """Insert a new ``User`` row, commit it, and return the stored object."""
    new_user = User(session_id=session_id, username=username)
    db_session.add(new_user)
    db_session.commit()
    # Re-read the row so the returned object carries its stored values.
    db_session.refresh(new_user)
    return new_user
def update_session(db_session, username: str, session_id: str):
    """Store a new session id on an existing user.

    When no user with that username exists nothing is written. The
    username is returned unchanged in either case.
    """
    existing_user = get_user(db_session, username)
    if not existing_user:
        return username

    existing_user.session_id = session_id
    db_session.commit()
    db_session.refresh(existing_user)
    return username
def create_chat_message(db_session, username: str, role: str, content: str):
    """Insert a chat message for ``username``, commit it, and return it.

    Args:
        db_session: SQLAlchemy session used for the write.
        username: Owner of the message (foreign key to ``users.username``).
        role: Who produced the message.
        content: The message text.

    Returns:
        The persisted ``ChatMessage``, refreshed from the database so that
        generated values such as ``id`` are loaded. Earlier versions
        refreshed the object but returned ``None``; returning it matches the
        other ``create_*`` helpers and is harmless to callers that ignore
        the result.
    """
    chat_message = ChatMessage(username=username, role=role, content=content)
    db_session.add(chat_message)
    db_session.commit()
    db_session.refresh(chat_message)
    return chat_message
def create_user_preferences(db_session, username: str, preferences_json: str):
    """Insert a ``UserPreferences`` row, commit it, and return it."""
    new_preferences = UserPreferences(
        preferences_json=preferences_json,
        username=username,
    )
    db_session.add(new_preferences)
    db_session.commit()
    # Re-read the row so the returned object carries its stored values.
    db_session.refresh(new_preferences)
    return new_preferences
def get_user_preferences(db_session, username: str):
    """Return the user's ``UserPreferences`` row, or ``None`` if absent."""
    matching_rows = db_session.query(UserPreferences).filter_by(username=username)
    return matching_rows.first()
def update_user_preferences(db_session, username: str, preferences_json: str):
    """Replace the stored preferences string for ``username``.

    Returns the updated ``UserPreferences`` object, or ``None`` when the
    user has no preferences row (nothing is written in that case).
    """
    existing = get_user_preferences(db_session, username)
    if not existing:
        return None

    existing.preferences_json = preferences_json
    db_session.commit()
    db_session.refresh(existing)
    return existing
def get_chat_messages(db_session, username: str):
    """Return all chat messages of ``username``, oldest first.

    Args:
        db_session: SQLAlchemy session used for the query.
        username: Owner whose messages are fetched.

    Returns:
        A list of ``ChatMessage`` objects (empty when there are none),
        ordered by ``created_at`` ascending. The auto-incrementing ``id`` is
        used as a secondary sort key so that messages sharing the same
        timestamp still come back in a stable, insertion-based order.
    """
    return (
        db_session.query(ChatMessage)
        .filter_by(username=username)
        .order_by(ChatMessage.created_at.asc(), ChatMessage.id.asc())
        .all()
    )
def update_user_pdf_text(db_session, username: str, pdf_text: str):
    """Store ``pdf_text`` on the preferences row of ``username``.

    Returns the updated ``UserPreferences`` object, or ``None`` when the
    user has no preferences row (nothing is written in that case).
    """
    existing = get_user_preferences(db_session, username)
    if not existing:
        return None

    existing.pdf_text = pdf_text
    db_session.commit()
    db_session.refresh(existing)
    return existing