Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions app/api.py

This file was deleted.

24 changes: 14 additions & 10 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import alembic.migration as alembic_migration
import redis
from calypsso import get_calypsso_app
from fastapi import FastAPI, HTTPException, Request, Response, status
from fastapi import APIRouter, FastAPI, HTTPException, Request, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -24,7 +24,6 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from app import api
from app.core.core_endpoints import coredata_core, models_core
from app.core.google_api.google_api import GoogleAPI
from app.core.groups import models_groups
Expand All @@ -41,7 +40,7 @@
get_redis_client,
init_state,
)
from app.module import all_modules, module_list
from app.module import get_all_modules, get_module_list, init_module_list
from app.types.exceptions import (
ContentHTTPException,
GoogleAPIInvalidCredentialsError,
Expand Down Expand Up @@ -240,7 +239,7 @@ async def run_factories(
hyperion_error_logger.info("Startup: Factories enabled")
# Importing the core_factory at the beginning of the factories.
factories_list: list[Factory] = []
for module in all_modules:
for module in get_all_modules():
if module.factory:
factories_list.append(module.factory)
hyperion_error_logger.info(
Expand Down Expand Up @@ -299,10 +298,11 @@ def initialize_module_visibility(
coredata_core.ModuleVisibilityAwareness,
db,
)
known_roots = module_awareness.roots

new_modules = [
module
for module in module_list
for module in get_module_list()
if module.root not in module_awareness.roots
]
# Is run to create default module visibilities or when the table is empty
Expand All @@ -311,6 +311,7 @@ def initialize_module_visibility(
f"Startup: Some modules visibility settings are empty, initializing them ({[module.root for module in new_modules]})",
)
for module in new_modules:
known_roots.append(module.root)
if module.default_allowed_groups_ids is not None:
for group_id in module.default_allowed_groups_ids:
module_group_visibility = models_core.ModuleGroupVisibility(
Expand Down Expand Up @@ -344,9 +345,7 @@ def initialize_module_visibility(
f"Startup: Could not add module visibility {module.root} in the database: {error}",
)
initialization.set_core_data_sync(
coredata_core.ModuleVisibilityAwareness(
roots=[module.root for module in module_list],
),
coredata_core.ModuleVisibilityAwareness(roots=known_roots),
db,
)
hyperion_error_logger.info(
Expand All @@ -365,7 +364,7 @@ async def initialize_notification_topics(
) -> None:
existing_topics = await get_notification_topic(db=db)
existing_topics_id = [topic.id for topic in existing_topics]
for module in all_modules:
for module in get_all_modules():
if module.registred_topics:
for registred_topic in module.registred_topics:
if registred_topic.id not in existing_topics_id:
Expand Down Expand Up @@ -623,13 +622,18 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[LifespanState, None]:
hyperion_error_logger=hyperion_error_logger,
)

init_module_list(settings=settings)

# Initialize app
app = FastAPI(
title="Hyperion",
version=settings.HYPERION_VERSION,
lifespan=lifespan,
)
app.include_router(api.api_router)
api_router = APIRouter()
for module in get_all_modules():
api_router.include_router(module.router)
app.include_router(api_router)
use_route_path_as_operation_ids(app)

app.add_middleware(
Expand Down
4 changes: 2 additions & 2 deletions app/core/checkout/endpoints_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NotificationResultContent,
)
from app.dependencies import get_db
from app.module import all_modules
from app.module import get_all_modules
from app.types.module import CoreModule

router = APIRouter(tags=["Checkout"])
Expand Down Expand Up @@ -134,7 +134,7 @@ async def webhook(

# If a callback is defined for the module, we want to call it
try:
for module in all_modules:
for module in get_all_modules():
if module.root == checkout.module:
if module.payment_callback is None:
hyperion_error_logger.info(
Expand Down
11 changes: 8 additions & 3 deletions app/core/core_endpoints/endpoints_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
is_user,
is_user_super_admin,
)
from app.module import module_list
from app.module import get_module_list
from app.types.module import CoreModule
from app.utils.tools import is_group_id_valid, patch_identity_in_text

Expand Down Expand Up @@ -219,7 +219,7 @@ async def get_module_visibility(
"""

return_module_visibilities = []
for module in module_list:
for module in get_module_list():
allowed_group_ids = await cruds_core.get_allowed_groups_by_root(
root=module.root,
db=db,
Expand Down Expand Up @@ -247,14 +247,19 @@ async def get_module_visibility(
async def get_user_modules_visibility(
db: AsyncSession = Depends(get_db),
user: models_users.CoreUser = Depends(is_user()),
settings: Settings = Depends(get_settings),
):
"""
Get group user accessible root

**This endpoint is only usable by everyone**
"""

return await cruds_core.get_modules_by_user(user=user, db=db)
modules = await cruds_core.get_modules_by_user(user=user, db=db)
if settings.RESTRICT_TO_MODULES:
return [module for module in modules if module in settings.RESTRICT_TO_MODULES]

return modules


@router.post(
Expand Down
4 changes: 4 additions & 0 deletions app/core/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,10 @@ def settings_customise_sources(
# If self registration is disabled, users will need to be invited by an administrator to be able to register
ALLOW_SELF_REGISTRATION: bool = True

# Restrict to a list of module roots
# CoreModules can not be disabled
RESTRICT_TO_MODULES: list[str] | None = None

############################
# PostgreSQL configuration #
############################
Expand Down
84 changes: 58 additions & 26 deletions app/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,69 @@
import logging
from pathlib import Path

from app.core.utils.config import Settings
from app.types.exceptions import InvalidModuleRootInDotenvError
from app.types.module import CoreModule, Module

hyperion_error_logger = logging.getLogger("hyperion.error")

module_list: list[Module] = []
core_module_list: list[CoreModule] = []
all_modules: list[CoreModule] = []

for endpoints_file in Path().glob("app/modules/*/endpoints_*.py"):
endpoint_module = importlib.import_module(
".".join(endpoints_file.with_suffix("").parts),
)
if hasattr(endpoint_module, "module"):
module: Module = endpoint_module.module
module_list.append(module)
all_modules.append(module)
else:
hyperion_error_logger.error(
f"Module {endpoints_file} does not declare a module. It won't be enabled.",
_module_list: list[Module] = []
_core_module_list: list[CoreModule] = []
_all_modules: list[Module | CoreModule] = []


def init_module_list(settings: Settings):
_module_list.clear()
_core_module_list.clear()

module_list = []
for endpoints_file in Path().glob("app/modules/*/endpoints_*.py"):
endpoint_module = importlib.import_module(
".".join(endpoints_file.with_suffix("").parts),
)
if hasattr(endpoint_module, "module"):
module: Module = endpoint_module.module
module_list.append(module)
else:
hyperion_error_logger.error(
f"Module {endpoints_file} does not declare a module. It won't be enabled.",
)

if settings.RESTRICT_TO_MODULES:
existing_module_roots = [module.root for module in module_list]
for root in settings.RESTRICT_TO_MODULES:
if root not in existing_module_roots:
raise InvalidModuleRootInDotenvError(root)
for module in module_list:
if (
settings.RESTRICT_TO_MODULES
and module.root not in settings.RESTRICT_TO_MODULES
):
continue
_module_list.append(module)
_all_modules.append(module)

for endpoints_file in Path().glob("app/core/*/endpoints_*.py"):
endpoint_module = importlib.import_module(
".".join(endpoints_file.with_suffix("").parts),
)
if hasattr(endpoint_module, "core_module"):
core_module: CoreModule = endpoint_module.core_module
core_module_list.append(core_module)
all_modules.append(core_module)
else:
hyperion_error_logger.error(
f"Core module {endpoints_file} does not declare a core module. It won't be enabled.",
for endpoints_file in Path().glob("app/core/*/endpoints_*.py"):
endpoint_module = importlib.import_module(
".".join(endpoints_file.with_suffix("").parts),
)
if hasattr(endpoint_module, "core_module"):
core_module: CoreModule = endpoint_module.core_module
_core_module_list.append(core_module)
_all_modules.append(core_module)
else:
hyperion_error_logger.error(
f"Core module {endpoints_file} does not declare a core module. It won't be enabled.",
)


def get_module_list() -> list[Module]:
return _module_list


def get_core_module_list() -> list[CoreModule]:
return _core_module_list


def get_all_modules() -> list[Module | CoreModule]:
return _all_modules
5 changes: 5 additions & 0 deletions app/types/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def __init__(self):
super().__init__("Google API is not configured in dotenv")


class InvalidModuleRootInDotenvError(Exception):
def __init__(self, root: str):
super().__init__(f"Module root {root} does not exist")


class ContentHTTPException(HTTPException):
"""
A custom HTTPException allowing to return custom content.
Expand Down
4 changes: 4 additions & 0 deletions config.template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ SQLITE_DB: "app.db"
# If True, will print all SQL queries in the console
DATABASE_DEBUG: False

# Restrict to a list of module roots
# CoreModules can not be disabled
#RESTRICT_TO_MODULES: []

#####################################
# SMTP configuration using starttls #
#####################################
Expand Down
8 changes: 4 additions & 4 deletions tests/test_checkout.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@ async def test_webhook_payment_callback(
factory=None,
)
mocker.patch(
"app.core.checkout.endpoints_checkout.all_modules",
[test_module],
"app.core.checkout.endpoints_checkout.get_all_modules",
return_value=[test_module],
)

response = client.post(
Expand Down Expand Up @@ -350,8 +350,8 @@ async def test_webhook_payment_callback_fail(
factory=None,
)
mocker.patch(
"app.core.checkout.endpoints_checkout.all_modules",
[test_module],
"app.core.checkout.endpoints_checkout.get_all_modules",
return_value=[test_module],
)

mocked_hyperion_security_logger = mocker.patch(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_factories.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from fastapi.testclient import TestClient

from app.module import all_modules
from app.module import get_all_modules
from tests.commons import get_TestingSessionLocal


@pytest.mark.parametrize("client", [True], indirect=True)
async def test_factories(client: TestClient) -> None:
async with get_TestingSessionLocal()() as db:
factories = [
module.factory for module in all_modules if module.factory is not None
module.factory for module in get_all_modules() if module.factory is not None
]
for factory in factories:
assert not await factory.should_run(
Expand Down
Loading