22Authentication and authorization utilities.
33"""
44
5- import secrets
65import hashlib
6+ import secrets
7+ from collections import defaultdict
78from datetime import datetime , timedelta , timezone
8- from typing import Any , Dict , Optional , Union
99from enum import Enum
10+ from time import time
11+ from typing import Any , Dict , Optional , Union
1012
13+ from fastapi import Depends , HTTPException , status
14+ from fastapi .security import HTTPAuthorizationCredentials , HTTPBearer
1115from jose import jwt
1216from passlib .context import CryptContext
13- from fastapi import HTTPException , status , Depends
14- from fastapi .security import HTTPBearer , HTTPAuthorizationCredentials
1517from sqlalchemy .ext .asyncio import AsyncSession
1618
1719from src .config import get_settings
2729
2830class UserRole (str , Enum ):
2931 """User roles for authorization."""
32+
3033 ADMIN = "admin"
3134 USER = "user"
3235 GUEST = "guest"
3336
3437
3538class TokenType (str , Enum ):
3639 """Token types."""
40+
3741 ACCESS = "access"
3842 REFRESH = "refresh"
3943 EMAIL_VERIFY = "email_verify"
@@ -43,24 +47,28 @@ class TokenType(str, Enum):
4347# Exception classes
4448class AuthenticationError (HTTPException ):
4549 """Authentication failed."""
50+
4651 def __init__ (self , detail : str = "Authentication failed" ):
4752 super ().__init__ (status_code = status .HTTP_401_UNAUTHORIZED , detail = detail )
4853
4954
5055class AuthorizationError (HTTPException ):
5156 """Authorization failed."""
57+
5258 def __init__ (self , detail : str = "Insufficient permissions" ):
5359 super ().__init__ (status_code = status .HTTP_403_FORBIDDEN , detail = detail )
5460
5561
5662class TokenExpiredError (AuthenticationError ):
5763 """Token has expired."""
64+
5865 def __init__ (self , detail : str = "Token has expired" ):
5966 super ().__init__ (detail = detail )
6067
6168
6269class InvalidTokenError (AuthenticationError ):
6370 """Token is invalid."""
71+
6472 def __init__ (self , detail : str = "Invalid token" ):
6573 super ().__init__ (detail = detail )
6674
@@ -87,67 +95,51 @@ def create_access_token(
8795 expires_delta : Optional [timedelta ] = None ,
8896 user_id : Optional [int ] = None ,
8997 role : Optional [str ] = None ,
90- token_type : TokenType = TokenType .ACCESS
98+ token_type : TokenType = TokenType .ACCESS ,
9199) -> str :
92100 """Create a JWT access token."""
93101 if expires_delta :
94102 expire = datetime .now (timezone .utc ) + expires_delta
95103 else :
96- expire = datetime .now (timezone .utc ) + timedelta (
97- minutes = settings .access_token_expire_minutes
98- )
99-
100- to_encode = {
101- "exp" : expire ,
102- "iat" : datetime .now (timezone .utc ),
103- "sub" : str (subject ),
104- "type" : token_type .value
105- }
106-
104+ expire = datetime .now (timezone .utc ) + timedelta (minutes = settings .access_token_expire_minutes )
105+
106+ to_encode = {"exp" : expire , "iat" : datetime .now (timezone .utc ), "sub" : str (subject ), "type" : token_type .value }
107+
107108 if user_id :
108109 to_encode ["user_id" ] = user_id
109110 if role :
110111 to_encode ["role" ] = role
111-
112+
112113 return jwt .encode (to_encode , settings .secret_key , algorithm = settings .algorithm )
113114
114115
115- def create_refresh_token (
116- subject : Union [str , Any ],
117- user_id : Optional [int ] = None
118- ) -> str :
116+ def create_refresh_token (subject : Union [str , Any ], user_id : Optional [int ] = None ) -> str :
119117 """Create a JWT refresh token."""
120- expire = datetime .now (timezone .utc ) + timedelta (
121- days = settings .refresh_token_expire_days
122- )
123-
118+ expire = datetime .now (timezone .utc ) + timedelta (days = settings .refresh_token_expire_days )
119+
124120 to_encode = {
125121 "exp" : expire ,
126122 "iat" : datetime .now (timezone .utc ),
127123 "sub" : str (subject ),
128124 "user_id" : user_id ,
129125 "type" : TokenType .REFRESH .value ,
130- "jti" : generate_secure_token (16 ) # JWT ID for invalidation
126+ "jti" : generate_secure_token (16 ), # JWT ID for invalidation
131127 }
132-
128+
133129 return jwt .encode (to_encode , settings .secret_key , algorithm = settings .algorithm )
134130
135131
136132def verify_token (token : str , token_type : Optional [TokenType ] = None ) -> Dict [str , Any ]:
137133 """Verify and decode a JWT token."""
138134 try :
139- payload = jwt .decode (
140- token ,
141- settings .secret_key ,
142- algorithms = [settings .algorithm ]
143- )
144-
135+ payload = jwt .decode (token , settings .secret_key , algorithms = [settings .algorithm ])
136+
145137 # Check token type if specified
146138 if token_type and payload .get ("type" ) != token_type .value :
147139 raise InvalidTokenError ("Invalid token type" )
148-
140+
149141 return payload
150-
142+
151143 except jwt .ExpiredSignatureError :
152144 raise TokenExpiredError ()
153145 except jwt .InvalidTokenError :
@@ -172,68 +164,66 @@ def verify_api_key(api_key: str, hashed_key: str) -> bool:
172164
173165# Authentication dependencies
174166async def get_current_user_from_token (
175- credentials : HTTPAuthorizationCredentials = Depends (security ),
176- db : AsyncSession = Depends (get_async_db )
167+ credentials : HTTPAuthorizationCredentials = Depends (security ), db : AsyncSession = Depends (get_async_db )
177168):
178169 """Get current user from JWT token."""
179170 from src .models .user import User # Import here to avoid circular imports
180-
171+
181172 try :
182173 payload = verify_token (credentials .credentials , TokenType .ACCESS )
183174 user_id : int = payload .get ("user_id" )
184-
175+
185176 if user_id is None :
186177 raise InvalidTokenError ("Token missing user_id" )
187-
178+
188179 user = await User .get (db , user_id )
189180 if user is None :
190181 raise AuthenticationError ("User not found" )
191-
182+
192183 if not user .is_active :
193184 raise AuthenticationError ("User account is disabled" )
194-
185+
195186 return user
196-
187+
197188 except (InvalidTokenError , TokenExpiredError , AuthenticationError ):
198189 raise
199190 except Exception as e :
200191 raise AuthenticationError (f"Authentication failed: { str (e )} " )
201192
202193
203194async def get_current_user_from_api_key (
204- credentials : HTTPAuthorizationCredentials = Depends (security ),
205- db : AsyncSession = Depends (get_async_db )
195+ credentials : HTTPAuthorizationCredentials = Depends (security ), db : AsyncSession = Depends (get_async_db )
206196):
207197 """Get current user from API key."""
208- from src .models .user import User , APIKey # Import here to avoid circular imports
209-
198+ from src .models .user import APIKey , User # Import here to avoid circular imports
199+
210200 try :
211201 api_key = credentials .credentials
212202 if not api_key .startswith ("ws_" ):
213203 raise InvalidTokenError ("Invalid API key format" )
214-
204+
215205 # Hash the provided API key
216206 hashed_key = hash_api_key (api_key )
217-
207+
218208 # Find the API key in database
219209 api_key_obj = await APIKey .get_by_hash (db , hashed_key )
220210 if not api_key_obj or not api_key_obj .is_active :
221211 raise AuthenticationError ("Invalid or inactive API key" )
222-
212+
223213 # Check expiration
224214 if api_key_obj .expires_at and api_key_obj .expires_at < datetime .utcnow ():
225215 raise AuthenticationError ("API key has expired" )
226-
216+
227217 # Get associated user
228218 user = await User .get (db , api_key_obj .user_id )
229219 if not user or not user .is_active :
230220 raise AuthenticationError ("User not found or inactive" )
231-
221+
232222 # Update last used timestamp
233223 await api_key_obj .update (db , last_used_at = datetime .utcnow ())
234-
224+
235225 return user
236-
226+
237227 except (InvalidTokenError , AuthenticationError ):
238228 raise
239229 except Exception as e :
@@ -242,8 +232,7 @@ async def get_current_user_from_api_key(
242232
243233# Flexible authentication - tries both token and API key
244234async def get_current_user (
245- credentials : HTTPAuthorizationCredentials = Depends (security ),
246- db : AsyncSession = Depends (get_async_db )
235+ credentials : HTTPAuthorizationCredentials = Depends (security ), db : AsyncSession = Depends (get_async_db )
247236):
248237 """Get current user from either JWT token or API key."""
249238 try :
@@ -259,10 +248,12 @@ async def get_current_user(
259248
260249def require_role (required_role : UserRole ):
261250 """Dependency to require specific role."""
262- async def role_checker (current_user = Depends (get_current_user )):
251+
252+ async def role_checker (current_user = Depends (get_current_user )):
263253 if current_user .role != required_role and current_user .role != UserRole .ADMIN :
264254 raise AuthorizationError (f"Role '{ required_role } ' required" )
265255 return current_user
256+
266257 return role_checker
267258
268259
@@ -272,61 +263,50 @@ def require_admin():
272263
273264
274265async def get_optional_user (
275- credentials : Optional [HTTPAuthorizationCredentials ] = Depends (
276- HTTPBearer (auto_error = False )
277- ),
278- db : AsyncSession = Depends (get_async_db )
266+ credentials : Optional [HTTPAuthorizationCredentials ] = Depends (HTTPBearer (auto_error = False )),
267+ db : AsyncSession = Depends (get_async_db ),
279268):
280269 """Get current user if authenticated, None otherwise."""
281270 if not credentials :
282271 return None
283-
272+
284273 try :
285274 return await get_current_user (credentials , db )
286275 except (AuthenticationError , AuthorizationError ):
287276 return None
288277
289278
290279# Rate limiting helper (simple in-memory implementation)
291- from collections import defaultdict
292- from time import time
293280
294281_rate_limit_cache = defaultdict (list )
295282
296- def check_rate_limit (
297- identifier : str ,
298- max_requests : int = 100 ,
299- window_seconds : int = 3600
300- ) -> bool :
283+
284+ def check_rate_limit (identifier : str , max_requests : int = 100 , window_seconds : int = 3600 ) -> bool :
301285 """Simple in-memory rate limiting."""
302286 now = time ()
303287 window_start = now - window_seconds
304-
288+
305289 # Clean old entries
306290 _rate_limit_cache [identifier ] = [
307- timestamp for timestamp in _rate_limit_cache [identifier ]
308- if timestamp > window_start
291+ timestamp for timestamp in _rate_limit_cache [identifier ] if timestamp > window_start
309292 ]
310-
293+
311294 # Check if under limit
312295 if len (_rate_limit_cache [identifier ]) >= max_requests :
313296 return False
314-
297+
315298 # Add current request
316299 _rate_limit_cache [identifier ].append (now )
317300 return True
318301
319302
320303def require_rate_limit (max_requests : int = 100 , window_seconds : int = 3600 ):
321304 """Dependency for rate limiting."""
322- async def rate_limiter (
323- credentials : HTTPAuthorizationCredentials = Depends (security )
324- ):
305+
306+ async def rate_limiter (credentials : HTTPAuthorizationCredentials = Depends (security )):
325307 # Use token/api key as identifier
326308 if not check_rate_limit (credentials .credentials , max_requests , window_seconds ):
327- raise HTTPException (
328- status_code = status .HTTP_429_TOO_MANY_REQUESTS ,
329- detail = "Rate limit exceeded"
330- )
309+ raise HTTPException (status_code = status .HTTP_429_TOO_MANY_REQUESTS , detail = "Rate limit exceeded" )
331310 return True
332- return rate_limiter
311+
312+ return rate_limiter
0 commit comments