Skip to content

Commit e5448b9

Browse files
authored
Fixes and optimisations for the Murfey authentication API (#730)
* Forwards only essential headers to the auth server to prevent timeouts due to mismatch between header, body, and methods. * Migrate authentication server querying logic out into a helper function to minimise repetition. * Improved test coverage of the murfey.server.api.auth module.
1 parent 467c2d0 commit e5448b9

2 files changed

Lines changed: 634 additions & 93 deletions

File tree

src/murfey/server/api/auth.py

Lines changed: 108 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import secrets
44
import time
55
from logging import getLogger
6-
from typing import Dict
76
from uuid import uuid4
87

98
import aiohttp
@@ -18,7 +17,7 @@
1817
from passlib.context import CryptContext
1918
from pydantic import BaseModel
2019
from sqlmodel import Session, create_engine, select
21-
from typing_extensions import Annotated
20+
from typing_extensions import Annotated, Any
2221

2322
from murfey.server.murfey_db import murfey_db, url
2423
from murfey.util.api import url_path_for
@@ -40,17 +39,19 @@
4039
auth_url = security_config.auth_url
4140
ALGORITHM = security_config.auth_algorithm or "HS256"
4241
SECRET_KEY = security_config.auth_key or secrets.token_hex(32)
43-
if security_config.auth_type == "password":
44-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
45-
else:
46-
oauth2_scheme = APIKeyCookie(name=security_config.cookie_key)
47-
if security_config.instrument_auth_type == "token":
48-
instrument_oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth/token")
49-
else:
50-
instrument_oauth2_scheme = lambda *args, **kwargs: None
42+
oauth2_scheme = (
43+
OAuth2PasswordBearer(tokenUrl="auth/token")
44+
if security_config.auth_type == "password"
45+
else APIKeyCookie(name=security_config.cookie_key)
46+
)
47+
instrument_oauth2_scheme = (
48+
OAuth2PasswordBearer(tokenUrl="auth/token")
49+
if security_config.instrument_auth_type == "token"
50+
else lambda *args, **kwargs: None
51+
)
5152
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
5253

53-
instrument_server_tokens: Dict[float, dict] = {}
54+
instrument_server_tokens: dict[float, dict] = {}
5455

5556
# Set up database engine
5657
try:
@@ -66,14 +67,30 @@ def hash_password(password: str) -> str:
6667

6768
"""
6869
=======================================================================================
69-
TOKEN VALIDATION FUNCTIONS
70+
VALIDATION FUNCTIONS
7071
=======================================================================================
7172
7273
Functions and helpers used to validate incoming requests from both the client and
73-
the frontend. 'validate_token()' and 'validate_instrument_token()' are imported
74-
int the other FastAPI modules and attached as dependencies to the routers.
74+
the frontend.
75+
76+
'validate_token()' and 'validate_instrument_token()' are imported in the other FastAPI
77+
modules and attached as dependencies to the routers. They validate the tokens passed
78+
around internally by Murfey to ensure that the request is valid.
79+
80+
'validate_instrument_server_session_access()' and 'validate_frontend_session_access()'
81+
are used to verify the IDs of sessions ot be accessed, and are attached as dependencies
82+
to them.
83+
84+
'validate_user_instrument_access()' is used to verify the instrument server name being
85+
accessed by the frontend, and is attached as a dependency as well.
7586
"""
7687

88+
# Essential headers used for authentication to forward along if present
89+
AUTH_HEADERS = (
90+
"authorization",
91+
"x-auth-request-access-token",
92+
)
93+
7794

7895
def check_user(username: str) -> bool:
7996
try:
@@ -84,6 +101,39 @@ def check_user(username: str) -> bool:
84101
return username in [u.username for u in users]
85102

86103

104+
async def submit_to_auth_endpoint(
105+
url_subpath: str,
106+
request: Request,
107+
token: str,
108+
) -> dict[str, Any]:
109+
"""
110+
Helper function to forward incoming requests to an authentication server
111+
to verify that they are allowed to inspect the
112+
"""
113+
114+
# Forward only essentials auth-related headers
115+
headers = {
116+
key: value
117+
for key, value in dict(request.headers).items()
118+
if key.lower() in AUTH_HEADERS
119+
}
120+
if security_config.auth_type == "password":
121+
headers["authorization"] = f"Bearer {token}"
122+
cookies = (
123+
{security_config.cookie_key: token}
124+
if security_config.auth_type == "cookie"
125+
else {}
126+
)
127+
async with aiohttp.ClientSession(cookies=cookies) as session:
128+
async with session.get(
129+
f"{auth_url}/{url_subpath}",
130+
headers=headers,
131+
) as response:
132+
success = response.status == 200
133+
validation_outcome: dict[str, Any] = await response.json()
134+
return validation_outcome if success and validation_outcome else {"valid": False}
135+
136+
87137
async def validate_token(
88138
token: Annotated[str, Depends(oauth2_scheme)],
89139
request: Request,
@@ -94,25 +144,9 @@ async def validate_token(
94144
try:
95145
# Validate using auth URL if provided; will error if invalid
96146
if auth_url:
97-
# Extract and forward headers as-is
98-
headers = dict(request.headers)
99-
# Update/add authorization header if authenticating using password
100-
if security_config.auth_type == "password":
101-
headers["authorization"] = f"Bearer {token}"
102-
# Forward the cookie along if authenticating using cookie
103-
cookies = (
104-
{security_config.cookie_key: token}
105-
if security_config.auth_type == "cookie"
106-
else {}
107-
)
108-
async with aiohttp.ClientSession(cookies=cookies) as session:
109-
async with session.get(
110-
f"{auth_url}/validate_token",
111-
headers=headers,
112-
) as response:
113-
success = response.status == 200
114-
validation_outcome = await response.json()
115-
if not (success and validation_outcome.get("valid")):
147+
if not (
148+
await submit_to_auth_endpoint("validate_token", request, token)
149+
).get("valid"):
116150
raise JWTError
117151
# If authenticating using cookies; an auth URL MUST be provided
118152
else:
@@ -199,20 +233,6 @@ async def validate_instrument_token(
199233
return None
200234

201235

202-
"""
203-
=======================================================================================
204-
SESSION ID VALIDATION
205-
=======================================================================================
206-
207-
Annotated ints are defined here that trigger validation of the session IDs in incoming
208-
requests, verifying that the session is allowed to access the particular visit.
209-
210-
The 'MurfeySessionID...' types are imported and used in the type hints of the endpoint
211-
functions in the other FastAPI routers, depending on whether requests from the frontend
212-
or the instrument are expected.
213-
"""
214-
215-
216236
def get_visit_name(session_id: int) -> str:
217237
with Session(engine) as murfey_db:
218238
return (
@@ -222,46 +242,6 @@ def get_visit_name(session_id: int) -> str:
222242
)
223243

224244

225-
async def submit_to_auth_endpoint(url_subpath: str, token: str) -> None:
226-
if auth_url:
227-
headers = (
228-
{}
229-
if security_config.auth_type == "cookie"
230-
else {"Authorization": f"Bearer {token}"}
231-
)
232-
cookies = (
233-
{security_config.cookie_key: token}
234-
if security_config.auth_type == "cookie"
235-
else {}
236-
)
237-
async with aiohttp.ClientSession(cookies=cookies) as session:
238-
async with session.get(
239-
f"{auth_url}/{url_subpath}",
240-
headers=headers,
241-
) as response:
242-
success = response.status == 200
243-
validation_outcome: dict = await response.json()
244-
if not (success and validation_outcome.get("valid")):
245-
logger.warning("Unauthorised visit access request from frontend")
246-
raise HTTPException(
247-
status_code=status.HTTP_401_UNAUTHORIZED,
248-
detail="You do not have access to this visit",
249-
headers={"WWW-Authenticate": "Bearer"},
250-
)
251-
252-
253-
async def validate_frontend_session_access(
254-
session_id: int,
255-
token: Annotated[str, Depends(oauth2_scheme)],
256-
) -> int:
257-
"""
258-
Validates whether a frontend request can access information about this session
259-
"""
260-
visit_name = get_visit_name(session_id)
261-
await submit_to_auth_endpoint(f"validate_visit_access/{visit_name}", token)
262-
return session_id
263-
264-
265245
async def validate_instrument_server_session_access(
266246
session_id: int,
267247
token: Annotated[str, Depends(instrument_oauth2_scheme)],
@@ -294,25 +274,60 @@ async def validate_instrument_server_session_access(
294274
return session_id
295275

296276

277+
async def validate_frontend_session_access(
278+
session_id: int,
279+
request: Request,
280+
token: Annotated[str, Depends(oauth2_scheme)],
281+
) -> int:
282+
"""
283+
Validates whether a frontend request can access information about this session
284+
"""
285+
visit_name = get_visit_name(session_id)
286+
if auth_url:
287+
if not (
288+
await submit_to_auth_endpoint(
289+
f"validate_visit_access/{visit_name}",
290+
request,
291+
token,
292+
)
293+
).get("valid"):
294+
raise HTTPException(
295+
status_code=status.HTTP_401_UNAUTHORIZED,
296+
detail="You do not have access to this visit",
297+
headers={"WWW-Authenticate": "Bearer"},
298+
)
299+
return session_id
300+
301+
297302
async def validate_user_instrument_access(
298303
instrument_name: str,
304+
request: Request,
299305
token: Annotated[str, Depends(oauth2_scheme)],
300306
) -> str:
301307
"""
302308
Validates whether a frontend request can access information about this instrument
303309
"""
304-
await submit_to_auth_endpoint(
305-
f"validate_instrument_access/{instrument_name}", token
306-
)
310+
if auth_url:
311+
if not (
312+
await submit_to_auth_endpoint(
313+
f"validate_instrument_access/{instrument_name}",
314+
request,
315+
token,
316+
)
317+
).get("valid"):
318+
raise HTTPException(
319+
status_code=status.HTTP_401_UNAUTHORIZED,
320+
detail="You do not have access to this instrument",
321+
headers={"WWW-Authenticate": "Bearer"},
322+
)
307323
return instrument_name
308324

309325

310-
# Set validation conditions for the session ID based on where the request is from
311-
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
326+
# Create annotated session ID and instrument name for endpoints that need to verify them
312327
MurfeySessionIDInstrument = Annotated[
313328
int, Depends(validate_instrument_server_session_access)
314329
]
315-
330+
MurfeySessionIDFrontend = Annotated[int, Depends(validate_frontend_session_access)]
316331
MurfeyInstrumentNameFrontend = Annotated[str, Depends(validate_user_instrument_access)]
317332

318333

0 commit comments

Comments
 (0)