11"""
2- FastAPI dependencies for authentication, database, etc.
2+ FastAPI dependencies for authentication, database, and common utilities.
3+
4+ This module uses modern FastAPI 0.124+ patterns including:
5+ - Annotated type hints with Doc for better documentation
6+ - Dependency scopes for proper resource management
7+ - Enhanced type safety with Pydantic 2.10+
38"""
49from typing import Optional , Annotated , AsyncGenerator
510
6- from fastapi import Depends , HTTPException , Header , Request
11+ from annotated_doc import Doc
12+ from fastapi import Depends , HTTPException , Header , Request , status
713from sqlalchemy .ext .asyncio import AsyncSession
814import structlog
915
1319logger = structlog .get_logger ()
1420
1521
22+ # Type aliases for cleaner code
23+ APIKey = Annotated [str , Doc ("Valid API key for authentication" )]
24+ OptionalAPIKey = Annotated [Optional [str ], Doc ("Optional API key from headers" )]
25+
26+
1627async def get_db () -> AsyncGenerator [AsyncSession , None ]:
17- """Get database session dependency."""
28+ """
29+ Get database session dependency.
30+
31+ Uses FastAPI's dependency injection to provide database sessions
32+ that are automatically closed after the request completes.
33+ """
1834 async for session in get_session ():
1935 yield session
2036
2137
38+ # Create typed dependency for database session
39+ DatabaseSession = Annotated [
40+ AsyncSession ,
41+ Depends (get_db ),
42+ Doc ("Async database session for database operations" )
43+ ]
44+
45+
2246async def get_api_key (
23- x_api_key : Annotated [Optional [str ], Header ()] = None ,
24- authorization : Annotated [Optional [str ], Header ()] = None ,
47+ x_api_key : Annotated [
48+ Optional [str ],
49+ Header (
50+ alias = "X-API-Key" ,
51+ description = "API key for authentication" ,
52+ example = "rnd_live_abcdef123456789"
53+ )
54+ ] = None ,
55+ authorization : Annotated [
56+ Optional [str ],
57+ Header (
58+ description = "Bearer token authorization" ,
59+ example = "Bearer rnd_live_abcdef123456789"
60+ )
61+ ] = None ,
2562) -> Optional [str ]:
26- """Extract API key from headers."""
63+ """
64+ Extract API key from request headers.
65+
66+ Supports two authentication methods:
67+ 1. X-API-Key header: Direct API key
68+ 2. Authorization header: Bearer token format
69+ """
2770 if x_api_key :
2871 return x_api_key
29-
72+
3073 if authorization and authorization .startswith ("Bearer " ):
3174 return authorization [7 :]
32-
75+
3376 return None
3477
3578
3679async def require_api_key (
3780 request : Request ,
38- api_key : Optional [str ] = Depends (get_api_key ),
39- db : AsyncSession = Depends (get_db ),
81+ api_key : Annotated [
82+ Optional [str ],
83+ Depends (get_api_key ),
84+ Doc ("API key extracted from request headers" )
85+ ] = None ,
86+ db : DatabaseSession = None ,
4087) -> str :
41- """Require valid API key for endpoint access."""
88+ """
89+ Require valid API key for endpoint access.
90+
91+ This dependency:
92+ - Validates API key format and existence
93+ - Uses timing attack protection
94+ - Supports IP whitelist validation
95+ - Updates API key usage statistics
96+
97+ Returns:
98+ str: Validated API key
99+
100+ Raises:
101+ HTTPException: 401 if API key is missing or invalid
102+ HTTPException: 403 if IP is not in whitelist
103+ """
42104 if not settings .ENABLE_API_KEYS :
43105 return "anonymous"
44-
106+
45107 if not api_key :
46108 raise HTTPException (
47- status_code = 401 ,
48- detail = "API key required" ,
109+ status_code = status .HTTP_401_UNAUTHORIZED ,
110+ detail = {
111+ "error" : "authentication_required" ,
112+ "message" : "API key required" ,
113+ "help" : "Include X-API-Key header or Authorization: Bearer <key>"
114+ },
49115 headers = {"WWW-Authenticate" : "Bearer" },
50116 )
51-
117+
52118 # Validate API key against database with timing attack protection
53119 import asyncio
54120 from api .services .api_key import APIKeyService
55-
121+
56122 # Always take the same amount of time regardless of key validity
57123 start_time = asyncio .get_event_loop ().time ()
58-
124+
59125 api_key_model = await APIKeyService .validate_api_key (
60126 db , api_key , update_usage = True
61127 )
62-
128+
63129 # Ensure constant time execution (minimum 100ms)
64130 elapsed = asyncio .get_event_loop ().time () - start_time
65131 min_time = 0.1 # 100ms
66132 if elapsed < min_time :
67133 await asyncio .sleep (min_time - elapsed )
68-
134+
69135 if not api_key_model :
70136 logger .warning (
71137 "Invalid API key attempted" ,
72138 api_key_prefix = api_key [:8 ] + "..." if len (api_key ) > 8 else api_key ,
73- client_ip = request .client .host ,
139+ client_ip = request .client .host if request . client else "unknown" ,
74140 )
75141 raise HTTPException (
76- status_code = 401 ,
77- detail = "Invalid API key" ,
142+ status_code = status .HTTP_401_UNAUTHORIZED ,
143+ detail = {
144+ "error" : "invalid_api_key" ,
145+ "message" : "Invalid API key"
146+ },
78147 )
79-
148+
80149 # Check IP whitelist if enabled
81150 if settings .ENABLE_IP_WHITELIST :
82151 import ipaddress
83- client_ip = request .client .host
84-
152+ client_ip = request .client .host if request . client else "unknown"
153+
85154 # Validate client IP against CIDR ranges
86- client_ip_obj = ipaddress .ip_address (client_ip )
87- allowed = False
88-
89- for allowed_range in settings .ip_whitelist_parsed :
90- try :
91- if client_ip_obj in ipaddress .ip_network (allowed_range , strict = False ):
92- allowed = True
93- break
94- except (ipaddress .AddressValueError , ipaddress .NetmaskValueError ):
95- # Fallback to string comparison for invalid CIDR
96- if client_ip .startswith (allowed_range ):
97- allowed = True
98- break
99-
100- if not allowed :
101- logger .warning (
102- "IP not in whitelist" ,
103- client_ip = client_ip ,
104- api_key_id = str (api_key_model .id ),
105- user_id = api_key_model .user_id ,
106- )
107- raise HTTPException (
108- status_code = 403 ,
109- detail = "IP address not authorized" ,
110- )
111-
155+ try :
156+ client_ip_obj = ipaddress .ip_address (client_ip )
157+ allowed = False
158+
159+ for allowed_range in settings .ip_whitelist_parsed :
160+ try :
161+ if client_ip_obj in ipaddress .ip_network (allowed_range , strict = False ):
162+ allowed = True
163+ break
164+ except (ipaddress .AddressValueError , ipaddress .NetmaskValueError ):
165+ # Fallback to string comparison for invalid CIDR
166+ if client_ip .startswith (allowed_range ):
167+ allowed = True
168+ break
169+
170+ if not allowed :
171+ logger .warning (
172+ "IP not in whitelist" ,
173+ client_ip = client_ip ,
174+ api_key_id = str (api_key_model .id ),
175+ user_id = api_key_model .user_id ,
176+ )
177+ raise HTTPException (
178+ status_code = status .HTTP_403_FORBIDDEN ,
179+ detail = {
180+ "error" : "ip_not_authorized" ,
181+ "message" : "IP address not authorized"
182+ },
183+ )
184+ except ValueError :
185+ # Invalid IP address format
186+ pass
187+
112188 # Store API key model in request state for other endpoints
113189 request .state .api_key_model = api_key_model
114-
190+
115191 return api_key
116192
117193
194+ # Create typed dependency for API key requirement
195+ RequiredAPIKey = Annotated [
196+ str ,
197+ Depends (require_api_key ),
198+ Doc ("Validated API key from request" )
199+ ]
200+
201+
118202async def get_current_user (
119203 request : Request ,
120- api_key : str = Depends ( require_api_key ) ,
204+ api_key : RequiredAPIKey ,
121205) -> dict :
122- """Get current user from validated API key."""
206+ """
207+ Get current user information from validated API key.
208+
209+ Returns a dictionary containing user details, quotas, and usage statistics.
210+ """
123211 # Get API key model from request state (set by require_api_key)
124212 api_key_model = getattr (request .state , 'api_key_model' , None )
125-
213+
126214 if not api_key_model :
127215 # Fallback for anonymous access
128216 return {
@@ -134,7 +222,7 @@ async def get_current_user(
134222 "monthly_minutes" : 100 ,
135223 },
136224 }
137-
225+
138226 return {
139227 "id" : api_key_model .user_id or f"api_key_{ api_key_model .id } " ,
140228 "api_key_id" : str (api_key_model .id ),
@@ -152,4 +240,27 @@ async def get_current_user(
152240 },
153241 "expires_at" : api_key_model .expires_at .isoformat () if api_key_model .expires_at else None ,
154242 "is_admin" : api_key_model .is_admin ,
155- }
243+ }
244+
245+
246+ # Create typed dependency for current user
247+ CurrentUser = Annotated [
248+ dict ,
249+ Depends (get_current_user ),
250+ Doc ("Current authenticated user information" )
251+ ]
252+
253+
254+ # Optional API key dependency (doesn't require authentication)
255+ async def get_optional_api_key (
256+ api_key : Annotated [Optional [str ], Depends (get_api_key )] = None ,
257+ ) -> Optional [str ]:
258+ """Get API key if provided, without requiring it."""
259+ return api_key
260+
261+
262+ OptionalAuth = Annotated [
263+ Optional [str ],
264+ Depends (get_optional_api_key ),
265+ Doc ("Optional API key for endpoints that support anonymous access" )
266+ ]
0 commit comments