-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
114 lines (100 loc) · 3.58 KB
/
server.py
File metadata and controls
114 lines (100 loc) · 3.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
'''
WebSocket chat server.
The asyncio module lets us handle asynchronous tasks.
connected_clients is a dict that keeps track of all connected clients, mapping each websocket to its username.
Each time a new client joins, the handle_client function verifies its JWT, adds it to that dict, sends it the last 10 stored messages, then listens for anything it sends, saves each message to the database and passes it along to every connected client (the sender included).
The main function starts the server on localhost at port 5000 and waits for clients to connect.
asyncio.run(main()) is called to start the server when the script is executed directly.
'''
import asyncio
import os
from datetime import datetime, timedelta
from urllib.parse import parse_qs, urlsplit

import jwt
import websockets
from sqlalchemy import create_engine, Integer, String, Column, DateTime
from sqlalchemy.orm import declarative_base, sessionmaker
# ---------------------------
# Config / DB setup
# ---------------------------
# Key used to verify HS256-signed JWTs. It is read from the JWT_SECRET_KEY
# environment variable so a deployment does not have to ship the key in source;
# the literal fallback is the previous hard-coded value and keeps local runs
# working unchanged. Whatever issues the tokens must sign with the same key.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "supersecretkey")

# SQLite database file in the working directory. echo=True makes SQLAlchemy
# log the SQL it emits, which is useful while developing but noisy otherwise.
engine = create_engine("sqlite:///messages.db", echo=True)

# Declarative base class that the ORM models in this file inherit from.
Base = declarative_base()

# Factory for database sessions bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
# ---------------------------
# Database model
# ---------------------------
class Message(Base):
    """ORM model for one chat message stored in the ``messages`` table."""

    __tablename__ = "messages"

    # Integer primary key.
    id = Column("id", Integer, primary_key=True)
    # Username stored with the message; must not be NULL.
    username = Column("username", String, nullable=False)
    # Text of the message; must not be NULL.
    content = Column("content", String, nullable=False)
    # Defaults to datetime.now() (naive local time) when no value is supplied.
    timestamp = Column("timestamp", DateTime, default=datetime.now)
# Create any tables declared on Base that do not exist yet in the database.
Base.metadata.create_all(bind=engine)

# Registry of live connections: websocket object -> authenticated username.
connected_clients: dict = {}
# ---------------------------
# JWT helper functions
# ---------------------------
def verify_token(token):
    """Return the username carried by a valid JWT, or None.

    Args:
        token: The encoded JWT taken from the client's request, or None/empty
            when the client supplied no token.

    Returns:
        The value of the ``username`` claim when the token decodes and
        verifies with SECRET_KEY using HS256; None when the token is missing,
        fails verification, or has no ``username`` claim.
    """
    # No token supplied: reject without calling into the JWT library.
    if not token:
        return None
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # Covers malformed tokens, bad signatures and expired tokens. Anything
        # else (e.g. a programming error) is allowed to surface instead of
        # being silently treated as "not authenticated".
        return None
    return payload.get("username")
# ---------------------------
# WebSocket handler
# ---------------------------
def _extract_token(path):
    """Return the first ``token`` query parameter of a request path, or None."""
    # parse_qs splits on the first "=" of each pair and percent-decodes the
    # value, so a token containing "=" or URL-encoded characters is handled.
    values = parse_qs(urlsplit(path).query).get("token")
    return values[0] if values else None


async def handle_client(websocket):
    """Authenticate one client, replay recent history, then relay its messages.

    The JWT is read from the ``token`` query parameter of the request path
    (``ws://localhost:5000/?token=...``). A client whose token does not verify
    is closed immediately. An accepted client is registered in
    ``connected_clients``, receives the last 10 stored messages in
    chronological order, and every message it sends afterwards is saved to the
    database and broadcast to all connected clients, the sender included.

    Args:
        websocket: The server-side connection object passed in by
            ``websockets.serve``.

    NOTE: the database calls are synchronous and run on the event loop, so a
    slow query delays every connection.
    """
    # Extract JWT from query string: ws://localhost:5000/?token=...
    token = _extract_token(websocket.request.path)
    username = verify_token(token)
    if not username:
        await websocket.close()
        return

    db = SessionLocal()
    # Add websocket -> username mapping
    connected_clients[websocket] = username
    try:
        # Send last 10 messages from DB to the client. This runs inside the
        # try block so that a failure here still unregisters the client and
        # closes the session.
        recent = db.query(Message).order_by(Message.timestamp.desc()).limit(10).all()
        for msg in reversed(recent):
            await websocket.send(f"{msg.username}: {msg.content}")

        async for message_text in websocket:
            try:
                # Save message to DB
                db.add(Message(username=username, content=message_text))
                db.commit()
            except Exception as e:
                # Best effort: a failed save must not stop the chat, so the
                # message is still broadcast below.
                db.rollback()
                print("DB error:", e)

            # Broadcast to all connected clients. Iterate over a snapshot
            # because other handlers may add or remove entries while we await.
            outgoing = f"{username}: {message_text}"
            recipients = list(connected_clients)
            results = await asyncio.gather(
                *(client.send(outgoing) for client in recipients),
                return_exceptions=True,
            )
            # A recipient that has gone away must not take down the sender's
            # connection; report the failure and carry on. The recipient's own
            # handler removes it from connected_clients when it exits.
            for result in results:
                if isinstance(result, Exception):
                    print("Broadcast error:", result)
    finally:
        # Remove client
        connected_clients.pop(websocket, None)
        db.close()
# ---------------------------
# Main server function
# ---------------------------
async def main(host="localhost", port=5000):
    """Start the WebSocket server and keep it running until cancelled.

    Args:
        host: Interface to listen on. Defaults to "localhost".
        port: TCP port to listen on. Defaults to 5000.

    Every connection is served by ``handle_client``. Keepalive pings are sent
    every 30 seconds with a 30 second timeout.
    """
    async with websockets.serve(
        handle_client, host, port, ping_interval=30, ping_timeout=30
    ):
        print(f"Server started on ws://{host}:{port}")
        await asyncio.Future()  # never completes: run until the task is cancelled
def _run_until_interrupted():
    """Run the server until Ctrl+C, reporting shutdown and cleanup on stdout."""
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the server; exit without a traceback.
        print("Server is shutting down")
    finally:
        print("Clean up complete")


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    _run_until_interrupted()