Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mrok/authentication/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mrok.authentication.backends.jwt import JWTAuthenticationBackend # noqa: F401
from mrok.authentication.backends.oidc import OIDCJWTAuthenticationBackend # noqa: F401
from mrok.authentication.base import AuthIdentity, BaseHTTPAuthBackend
from mrok.authentication.credentials import BearerCredentials, Credentials
Expand Down
40 changes: 40 additions & 0 deletions mrok/authentication/backends/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging

import jwt
from jwt import InvalidKeyError, InvalidTokenError

from mrok.authentication.base import AuthIdentity, BaseHTTPAuthBackend
from mrok.authentication.credentials import BearerCredentials, Credentials
from mrok.authentication.registry import register_authentication_backend
from mrok.types.proxy import Scope

logger = logging.getLogger("mrok.authentication")


@register_authentication_backend("jwt")
class JWTAuthenticationBackend(BaseHTTPAuthBackend):
def get_credentials(self, scope: Scope) -> Credentials | None:
return BearerCredentials.extract_from_asgi_scope(scope)

async def authenticate(self, credentials: Credentials) -> AuthIdentity | None:
try:
jwt_token = credentials.credentials
claims = jwt.decode(
jwt_token,
key=self.config.secret,
audience=self.config.audience,
algorithms=["HS256"],
options={
"verify_signature": True,
"verify_exp": True,
"verify_aud": True,
"verify_iss": False,
},
)
return AuthIdentity(subject=claims["sub"], metadata=claims)
except InvalidTokenError as error:
logger.exception(f"Error decoding token {error}")
return None
except InvalidKeyError as error:
logger.exception(f"Error decoding token {error}")
return None
2 changes: 1 addition & 1 deletion mrok/authentication/backends/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from mrok.authentication.registry import register_authentication_backend
from mrok.types.proxy import Scope

logger = logging.getLogger("mrok.controller")
logger = logging.getLogger("mrok.authentication")


@register_authentication_backend("oidc")
Expand Down
5 changes: 3 additions & 2 deletions settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@ controller:
auth:
backends:
- oidc
- jwt
oidc:
config_url: https://example.com/openid-configuration
audience: my-audience
audience: mrok
subject_claim: sub
jwt:
audience: my-audience
audience: mrok
pagination:
limit: 50

Expand Down
130 changes: 106 additions & 24 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import base64
import json
import tempfile
from collections.abc import AsyncGenerator, Generator
from collections.abc import AsyncGenerator, Callable, Generator
from datetime import UTC, datetime, timedelta
from typing import Any

Expand All @@ -20,6 +20,13 @@
from tests.types import ReceiveFactory, SendFactory, SettingsFactory, StatusEventFactory


@pytest.fixture
def settings(request, settings_factory):
if hasattr(request, "param"):
return request.param(settings_factory)
return settings_factory()


@pytest.fixture(scope="session")
def settings_factory() -> SettingsFactory:
def _get_settings(
Expand Down Expand Up @@ -166,6 +173,16 @@ def jwt_token(jwt_signing_key: str) -> str:
return jwt.encode(payload, jwt_signing_key, algorithm="RS256", headers={"kid": "test-key"})


@pytest.fixture()
def jwt_token_symmetric_key() -> str:
payload = {
"sub": "user123",
"aud": "mrok-audience",
"exp": datetime.now(UTC) + timedelta(minutes=30),
}
return jwt.encode(payload, "my_secret_key_for_testing_only12", algorithm="HS256")


@pytest.fixture()
def jwks_json() -> dict:
return {
Expand Down Expand Up @@ -196,13 +213,25 @@ def openid_config() -> dict:


@pytest.fixture()
def fastapi_app(settings_factory: SettingsFactory) -> FastAPI:
settings = settings_factory()
from mrok.controller.app import setup_app
def fastapi_app_factory(settings_factory: SettingsFactory) -> Callable[[Settings], FastAPI]:
def _fastapi_app(settings: Settings) -> FastAPI:
settings = settings or settings_factory()
from mrok.controller.app import setup_app

app = setup_app(settings)
app.dependency_overrides[get_settings] = lambda: settings
return app

app = setup_app(settings)
app.dependency_overrides[get_settings] = lambda: settings
return app
return _fastapi_app


@pytest.fixture()
def fastapi_app(
fastapi_app_factory: Callable[[Settings], FastAPI],
settings_factory: SettingsFactory,
) -> FastAPI:
settings = settings_factory()
return fastapi_app_factory(settings)


@pytest.fixture()
Expand All @@ -213,36 +242,76 @@ async def app_lifespan_manager(fastapi_app: FastAPI) -> AsyncGenerator[LifespanM

@pytest.fixture
async def api_client(
fastapi_app: FastAPI,
app_lifespan_manager: LifespanManager,
settings_factory: SettingsFactory,
fastapi_app_factory: Callable[[Settings], FastAPI],
settings: Settings,
httpx_mock: HTTPXMock,
openid_config: dict,
jwks_json: dict,
jwt_token: str,
app_lifespan_manager,
jwt_token_symmetric_key,
) -> AsyncGenerator[AsyncClient]:
settings = settings_factory()
httpx_mock.add_response(
method="GET",
url=settings.controller.auth.oidc.config_url,
json=openid_config,
is_reusable=True,
)
httpx_mock.add_response(
method="GET",
url="http://example.com/jwks.json",
json=jwks_json,
is_reusable=True,
)
if "oidc" in settings.controller.auth.backends:
httpx_mock.add_response(
method="GET",
url=settings.controller.auth.oidc.config_url,
json=openid_config,
is_reusable=True,
is_optional=True,
)
httpx_mock.add_response(
method="GET",
url="http://example.com/jwks.json",
json=jwks_json,
is_reusable=True,
is_optional=True,
)
app = fastapi_app_factory(settings)

async with AsyncClient(
transport=ASGITransport(app=app_lifespan_manager.app),
base_url=f"http://localhost/{fastapi_app.root_path.removeprefix('/')}/",
base_url=f"http://localhost/{app.root_path.removeprefix('/')}/",
headers={"Authorization": f"Bearer {jwt_token}"},
) as client:
yield client


@pytest.fixture
async def api_client_dual_backend(
fastapi_app_factory: Callable[[Settings], FastAPI],
settings_factory: SettingsFactory,
openid_config: dict,
jwks_json: dict,
jwt_token_symmetric_key: str,
) -> AsyncGenerator[AsyncClient]:
controller = {
"auth": {
"backends": ["oidc", "jwt"],
"oidc": {
"config_url": "http://example.com/openid-configuration",
"audience": "mrok-audience",
"subject_claim": "sub",
},
"jwt": {
"secret": "my_secret_key_for_testing_only12",
"audience": "mrok-audience",
},
}
}

settings = settings_factory(controller=controller)

app = fastapi_app_factory(settings)

async with LifespanManager(app) as manager:
async with AsyncClient(
transport=ASGITransport(app=manager.app),
base_url="http://localhost/",
headers={"Authorization": f"Bearer {jwt_token_symmetric_key}"},
) as client:
yield client


@pytest.fixture
def receive_factory() -> ReceiveFactory:
def _factory(messages: list[Message] | None = None) -> ASGIReceive:
Expand Down Expand Up @@ -445,3 +514,16 @@ def ziti_frontend_error_template_json_file(
f.write(ziti_frontend_error_template_json)
f.seek(0)
yield f.name


@pytest.fixture
def mock_empty_services(settings, httpx_mock):
httpx_mock.add_response(
method="GET",
url=f"{settings.ziti.base_urls.management}/edge/management/v1/services?limit=50&offset=0",
json={
"meta": {"pagination": {"totalCount": 0, "limit": 50, "offset": 0}},
"data": [],
},
is_optional=True,
)
Loading
Loading