diff --git a/tests/conftest.py b/tests/conftest.py index e724d77..68097fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,2 +1,5 @@ -def pytest_addoption(parser) -> None: # type: ignore[no-untyped-def] +import pytest + + +def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption('--postgres-version', action='store', default='latest') diff --git a/tests/v1/test_unit/test_enums_and_exceptions.py b/tests/v1/test_unit/test_enums_and_exceptions.py new file mode 100644 index 0000000..785e25e --- /dev/null +++ b/tests/v1/test_unit/test_enums_and_exceptions.py @@ -0,0 +1,125 @@ +from uuid import uuid4 + +import pytest + +from notora.v1.enums.base import OrderByDirections +from notora.v1.exceptions.common import AlreadyExistsError, FKNotFoundError, NotFoundError + +_ENTITY_ID_INT = 42 + + +def test_order_by_directions_asc_value() -> None: + assert OrderByDirections.ASC.value == 'asc' + + +def test_order_by_directions_desc_value() -> None: + assert OrderByDirections.DESC.value == 'desc' + + +def test_order_by_directions_is_str_enum() -> None: + assert isinstance(OrderByDirections.ASC, str) + assert isinstance(OrderByDirections.DESC, str) + + +def test_order_by_directions_can_be_used_as_string() -> None: + assert f'{OrderByDirections.ASC}' == 'asc' + assert f'{OrderByDirections.DESC}' == 'desc' + + +def test_fk_not_found_error_stores_fk_name() -> None: + err = FKNotFoundError('msg', fk_name='user_id_fkey', table_name='post') + assert err.fk_name == 'user_id_fkey' + + +def test_fk_not_found_error_stores_table_name() -> None: + err = FKNotFoundError('msg', fk_name='user_id_fkey', table_name='post') + assert err.table_name == 'post' + + +def test_fk_not_found_error_message_is_accessible() -> None: + err = FKNotFoundError('Related object not found.', fk_name='fk', table_name='tbl') + assert str(err) == 'Related object not found.' + + +def test_fk_not_found_error_is_exception() -> None: + err = FKNotFoundError('msg', fk_name='fk', table_name='tbl') + assert isinstance(err, Exception) + + +def test_fk_not_found_error_can_be_raised_and_caught() -> None: + msg = 'err' + with pytest.raises(FKNotFoundError) as exc_info: + raise FKNotFoundError(msg, fk_name='fk', table_name='tbl') + assert exc_info.value.fk_name == 'fk' + + +def test_already_exists_error_default_message() -> None: + err = AlreadyExistsError() + assert str(err) == 'Entity already exists.' + + +def test_already_exists_error_custom_message() -> None: + err = AlreadyExistsError('Custom message.') + assert str(err) == 'Custom message.' + + +def test_already_exists_error_constraint_name_stored() -> None: + err = AlreadyExistsError(constraint_name='users_email_key') + assert err.constraint_name == 'users_email_key' + + +def test_already_exists_error_constraint_name_none_by_default() -> None: + err = AlreadyExistsError() + assert err.constraint_name is None + + +def test_already_exists_error_message_and_constraint_together() -> None: + err = AlreadyExistsError('Dup', constraint_name='my_constraint') + assert str(err) == 'Dup' + assert err.constraint_name == 'my_constraint' + + +def test_already_exists_error_is_exception() -> None: + assert isinstance(AlreadyExistsError(), Exception) + + +def test_already_exists_error_can_be_raised_and_caught() -> None: + msg = 'dup' + with pytest.raises(AlreadyExistsError): + raise AlreadyExistsError(msg) + + +def test_not_found_error_entity_id_none_by_default() -> None: + err: NotFoundError[None] = NotFoundError('not found') + assert err.entity_id is None + + +def test_not_found_error_entity_id_stored() -> None: + err = NotFoundError('not found', entity_id=_ENTITY_ID_INT) + assert err.entity_id == _ENTITY_ID_INT + + +def test_not_found_error_entity_id_uuid() -> None: + uid = uuid4() + err = NotFoundError('not found', entity_id=uid) + assert err.entity_id == uid + + +def test_not_found_error_message_preserved() -> None: + err: NotFoundError[None] = NotFoundError('Resource not found.') + assert str(err) == 'Resource not found.' + + +def test_not_found_error_is_exception() -> None: + assert isinstance(NotFoundError('x'), Exception) + + +def test_not_found_error_can_be_raised_and_caught() -> None: + msg = 'missing' + with pytest.raises(NotFoundError): + raise NotFoundError(msg) + + +def test_not_found_error_no_positional_args() -> None: + err: NotFoundError[None] = NotFoundError() + assert err.entity_id is None diff --git a/tests/v1/test_unit/test_schemas_base.py b/tests/v1/test_unit/test_schemas_base.py new file mode 100644 index 0000000..d7ecb13 --- /dev/null +++ b/tests/v1/test_unit/test_schemas_base.py @@ -0,0 +1,277 @@ +from datetime import UTC, datetime, timedelta, timezone +from ipaddress import IPv4Address, IPv6Address + +from notora.v1.enums.base import OrderByDirections +from notora.v1.schemas.base import ( + AdminMeta, + ClientMeta, + Filter, + OrderBy, + OrFilterGroup, + PaginationMetaSchema, + datetime_encoder, + normalize_datetime_to_utc, + utc_datetime_encoder, +) + +_LIMIT = 10 +_TOTAL_100 = 100 +_TOTAL_25 = 25 +_TOTAL_20 = 20 +_TOTAL_5 = 5 +_LAST_PAGE_10 = 10 +_LAST_PAGE_3 = 3 +_LAST_PAGE_2 = 2 +_SECOND_PAGE = 2 +_FILTER_COUNT = 2 + + +def test_normalize_datetime_naive_gets_utc_tzinfo() -> None: + naive = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + result = normalize_datetime_to_utc(naive) + assert result.tzinfo == UTC + + +def test_normalize_datetime_naive_value_unchanged() -> None: + naive = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + result = normalize_datetime_to_utc(naive) + assert result.replace(tzinfo=None) == naive + + +def test_normalize_datetime_utc_aware_unchanged() -> None: + aware = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + result = normalize_datetime_to_utc(aware) + assert result == aware + + +def test_normalize_datetime_offset_aware_converted_to_utc() -> None: + tz_plus2 = timezone(timedelta(hours=2)) + aware = datetime(2024, 6, 15, 14, 0, 0, tzinfo=tz_plus2) + result = normalize_datetime_to_utc(aware) + assert result == datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + assert result.tzinfo == UTC + + +def test_normalize_datetime_negative_offset_converted_to_utc() -> None: + tz_minus5 = timezone(timedelta(hours=-5)) + aware = datetime(2024, 6, 15, 7, 0, 0, tzinfo=tz_minus5) + result = normalize_datetime_to_utc(aware) + assert result == datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + + +def test_utc_datetime_encoder_returns_iso_string_with_z() -> None: + dt = datetime(2024, 1, 20, 9, 15, 30, tzinfo=UTC) + result = utc_datetime_encoder(dt) + assert result == '2024-01-20T09:15:30Z' + + +def test_utc_datetime_encoder_naive_datetime_treated_as_utc() -> None: + naive = datetime(2024, 1, 20, 9, 15, 30, tzinfo=UTC).replace(tzinfo=None) + result = utc_datetime_encoder(naive) + assert result == '2024-01-20T09:15:30Z' + + +def test_utc_datetime_encoder_offset_aware_converted() -> None: + tz_plus3 = timezone(timedelta(hours=3)) + dt = datetime(2024, 1, 20, 12, 15, 30, tzinfo=tz_plus3) + result = utc_datetime_encoder(dt) + assert result == '2024-01-20T09:15:30Z' + + +def test_utc_datetime_encoder_does_not_contain_plus00_00() -> None: + dt = datetime(2024, 6, 1, 0, 0, 0, tzinfo=UTC) + result = utc_datetime_encoder(dt) + assert '+00:00' not in result + + +def test_datetime_encoder_returns_float_timestamp() -> None: + dt = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + result = datetime_encoder(dt) + assert isinstance(result, float) + + +def test_datetime_encoder_naive_datetime_treated_as_utc() -> None: + naive = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + aware = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + assert datetime_encoder(naive) == datetime_encoder(aware) + + +def test_datetime_encoder_offset_datetime_normalized() -> None: + tz_plus2 = timezone(timedelta(hours=2)) + offset = datetime(2024, 1, 1, 2, 0, 0, tzinfo=tz_plus2) + utc = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + assert datetime_encoder(offset) == datetime_encoder(utc) + + +def test_pagination_meta_first_page_full() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_100, limit=_LIMIT, offset=0) + assert meta.current_page == 1 + assert meta.last_page == _LAST_PAGE_10 + assert meta.total == _TOTAL_100 + assert meta.limit == _LIMIT + + +def test_pagination_meta_second_page() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_100, limit=_LIMIT, offset=_LIMIT) + assert meta.current_page == _SECOND_PAGE + + +def test_pagination_meta_last_page_calculated() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_25, limit=_LIMIT, offset=0) + assert meta.last_page == _LAST_PAGE_3 + + +def test_pagination_meta_zero_total_gives_page_1() -> None: + meta = PaginationMetaSchema.calculate(total=0, limit=_LIMIT, offset=0) + assert meta.current_page == 1 + assert meta.last_page == 1 + assert meta.total == 0 + + +def test_pagination_meta_exact_multiple_total() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_20, limit=_LIMIT, offset=0) + assert meta.last_page == _LAST_PAGE_2 + + +def test_pagination_meta_total_less_than_limit_gives_page_1() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_5, limit=_LIMIT, offset=0) + assert meta.current_page == 1 + assert meta.last_page == 1 + + +def test_admin_meta_deleted_at_is_none_by_default() -> None: + meta = AdminMeta( + created_at=datetime(2024, 1, 1, tzinfo=UTC), + updated_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + assert meta.deleted_at is None + + +def test_admin_meta_deleted_at_can_be_set() -> None: + dt = datetime(2024, 6, 1, 12, 0, tzinfo=UTC) + meta = AdminMeta( + created_at=datetime(2024, 1, 1, tzinfo=UTC), + updated_at=datetime(2024, 1, 1, tzinfo=UTC), + deleted_at=dt, + ) + assert meta.deleted_at == dt + + +def test_admin_meta_timestamps_normalized_to_utc() -> None: + naive = datetime(2024, 1, 1, 10, 0, 0, tzinfo=UTC).replace(tzinfo=None) + meta = AdminMeta(created_at=naive, updated_at=naive) + assert meta.created_at.tzinfo == UTC + assert meta.updated_at.tzinfo == UTC + + +def test_client_meta_both_fields_none_by_default() -> None: + client = ClientMeta() + assert client.ip_address is None + assert client.user_agent is None + + +def test_client_meta_ipv4_address_accepted() -> None: + client = ClientMeta(ip_address=IPv4Address('127.0.0.1')) + assert isinstance(client.ip_address, IPv4Address) + + +def test_client_meta_ipv6_address_accepted() -> None: + client = ClientMeta(ip_address=IPv6Address('::1')) + assert isinstance(client.ip_address, IPv6Address) + + +def test_client_meta_user_agent_stored() -> None: + client = ClientMeta(user_agent='Mozilla/5.0') + assert client.user_agent == 'Mozilla/5.0' + + +def test_client_meta_ip_address_serialized_as_string() -> None: + client = ClientMeta(ip_address=IPv4Address('192.168.0.1')) + dumped = client.model_dump() + assert dumped['ip_address'] == '192.168.0.1' + + +def test_filter_default_op_is_equals_sign() -> None: + f = Filter(field='name', value='alice') + assert f.op == '=' + + +def test_filter_custom_op() -> None: + f = Filter(field='age', op='gt', value=18) + assert f.op == 'gt' + + +def test_filter_value_none_allowed() -> None: + f = Filter(field='deleted_at', op='is', value=None) + assert f.value is None + + +def test_filter_model_none_by_default() -> None: + f = Filter(field='name', value='x') + assert f.model is None + + +def test_filter_model_can_be_set() -> None: + class FakeModel: + pass + + f = Filter(field='name', value='x', model=FakeModel) + assert f.model is FakeModel + + +def test_filter_all_ops_accepted() -> None: + valid_ops = ( + 'eq', + '=', + 'ilike', + '~=', + 'is', + 'is_not', + 'in', + 'gt', + '>', + 'ge', + '>=', + 'lt', + '<', + 'le', + '<=', + ) + for op in valid_ops: + f = Filter(field='x', op=op, value=1) + assert f.op == op + + +def test_or_filter_group_stores_filters() -> None: + f1 = Filter(field='name', value='a') + f2 = Filter(field='name', value='b') + group = OrFilterGroup(filters=[f1, f2]) + assert len(group.filters) == _FILTER_COUNT + + +def test_or_filter_group_empty_filters_allowed() -> None: + group = OrFilterGroup(filters=[]) + assert group.filters == [] + + +def test_order_by_default_direction_is_asc() -> None: + ob = OrderBy(field='name') + assert ob.direction == OrderByDirections.ASC + + +def test_order_by_desc_direction() -> None: + ob = OrderBy(field='name', direction=OrderByDirections.DESC) + assert ob.direction == OrderByDirections.DESC + + +def test_order_by_model_none_by_default() -> None: + ob = OrderBy(field='name') + assert ob.model is None + + +def test_order_by_model_can_be_set() -> None: + class FakeModel: + pass + + ob = OrderBy(field='name', model=FakeModel) + assert ob.model is FakeModel diff --git a/tests/v1/test_unit/test_utils.py b/tests/v1/test_unit/test_utils.py new file mode 100644 index 0000000..1d205a5 --- /dev/null +++ b/tests/v1/test_unit/test_utils.py @@ -0,0 +1,60 @@ +from datetime import UTC, datetime + +import pytest + +from notora.utils.time import now_without_tz +from notora.utils.validation import validate_exclusive_presence + + +def test_now_without_tz_returns_datetime_without_tzinfo() -> None: + result = now_without_tz() + assert isinstance(result, datetime) + assert result.tzinfo is None + + +def test_now_without_tz_is_close_to_utc_now() -> None: + before = datetime.now(UTC).replace(tzinfo=None) + result = now_without_tz() + after = datetime.now(UTC).replace(tzinfo=None) + assert before <= result <= after + + +def test_now_without_tz_called_twice_is_non_decreasing() -> None: + first = now_without_tz() + second = now_without_tz() + assert first <= second + + +def test_validate_exclusive_presence_first_only_does_not_raise() -> None: + validate_exclusive_presence('value', None) + + +def test_validate_exclusive_presence_second_only_does_not_raise() -> None: + validate_exclusive_presence(None, 'value') + + +def test_validate_exclusive_presence_both_provided_raises() -> None: + with pytest.raises(ValueError, match='Exactly one'): + validate_exclusive_presence('a', 'b') + + +def test_validate_exclusive_presence_neither_provided_raises() -> None: + with pytest.raises(ValueError, match='Exactly one'): + validate_exclusive_presence(None, None) + + +def test_validate_exclusive_presence_falsy_non_none_first_counts_as_provided() -> None: + # 0, '', [] are not None — should NOT raise + validate_exclusive_presence(0, None) + validate_exclusive_presence('', None) + validate_exclusive_presence([], None) + + +def test_validate_exclusive_presence_falsy_non_none_both_raises() -> None: + with pytest.raises(ValueError): + validate_exclusive_presence(0, 0) + + +def test_validate_exclusive_presence_non_string_values_accepted() -> None: + validate_exclusive_presence(42, None) + validate_exclusive_presence(None, {'key': 'val'}) diff --git a/tests/v2/conftest.py b/tests/v2/conftest.py index b638fde..01c937b 100644 --- a/tests/v2/conftest.py +++ b/tests/v2/conftest.py @@ -26,7 +26,7 @@ @pytest.fixture(scope='session') -def postgres_db(request) -> Iterator[PostgresContainer]: # type: ignore[no-untyped-def] +def postgres_db(request: pytest.FixtureRequest) -> Iterator[PostgresContainer]: postgres_version = request.config.getoption('--postgres-version') with PostgresContainer(f'postgres:{postgres_version}') as db: yield db diff --git a/tests/v2/test_unit/test_exceptions.py b/tests/v2/test_unit/test_exceptions.py new file mode 100644 index 0000000..85ec359 --- /dev/null +++ b/tests/v2/test_unit/test_exceptions.py @@ -0,0 +1,107 @@ +from uuid import uuid4 + +import pytest + +from notora.v2.exceptions.common import AlreadyExistsError, FKNotFoundError, NotFoundError + +_ENTITY_ID_INT = 42 + + +def test_fk_not_found_error_stores_fk_name() -> None: + err = FKNotFoundError('msg', fk_name='profile_user_id_fkey', table_name='profile') + assert err.fk_name == 'profile_user_id_fkey' + + +def test_fk_not_found_error_stores_table_name() -> None: + err = FKNotFoundError('msg', fk_name='fk', table_name='orders') + assert err.table_name == 'orders' + + +def test_fk_not_found_error_message_is_accessible() -> None: + err = FKNotFoundError('Related object not found.', fk_name='fk', table_name='tbl') + assert str(err) == 'Related object not found.' + + +def test_fk_not_found_error_is_exception() -> None: + err = FKNotFoundError('msg', fk_name='fk', table_name='tbl') + assert isinstance(err, Exception) + + +def test_fk_not_found_error_can_be_raised_and_caught() -> None: + msg = 'err' + with pytest.raises(FKNotFoundError) as exc_info: + raise FKNotFoundError(msg, fk_name='fk', table_name='tbl') + assert exc_info.value.fk_name == 'fk' + assert exc_info.value.table_name == 'tbl' + + +def test_already_exists_error_default_message() -> None: + err = AlreadyExistsError() + assert str(err) == 'Entity already exists.' + + +def test_already_exists_error_custom_message() -> None: + err = AlreadyExistsError('Custom message.') + assert str(err) == 'Custom message.' + + +def test_already_exists_error_constraint_name_stored() -> None: + err = AlreadyExistsError(constraint_name='users_email_key') + assert err.constraint_name == 'users_email_key' + + +def test_already_exists_error_constraint_name_none_by_default() -> None: + err = AlreadyExistsError() + assert err.constraint_name is None + + +def test_already_exists_error_message_and_constraint_together() -> None: + err = AlreadyExistsError('Dup', constraint_name='my_constraint') + assert str(err) == 'Dup' + assert err.constraint_name == 'my_constraint' + + +def test_already_exists_error_is_exception() -> None: + assert isinstance(AlreadyExistsError(), Exception) + + +def test_already_exists_error_can_be_raised_and_caught() -> None: + msg = 'dup' + with pytest.raises(AlreadyExistsError): + raise AlreadyExistsError(msg) + + +def test_not_found_error_entity_id_none_by_default() -> None: + err: NotFoundError[None] = NotFoundError('not found') + assert err.entity_id is None + + +def test_not_found_error_entity_id_stored() -> None: + err = NotFoundError('not found', entity_id=_ENTITY_ID_INT) + assert err.entity_id == _ENTITY_ID_INT + + +def test_not_found_error_entity_id_uuid() -> None: + uid = uuid4() + err = NotFoundError('not found', entity_id=uid) + assert err.entity_id == uid + + +def test_not_found_error_message_preserved() -> None: + err: NotFoundError[None] = NotFoundError('Resource not found.') + assert str(err) == 'Resource not found.' + + +def test_not_found_error_is_exception() -> None: + assert isinstance(NotFoundError('x'), Exception) + + +def test_not_found_error_can_be_raised_and_caught() -> None: + msg = 'missing' + with pytest.raises(NotFoundError): + raise NotFoundError(msg) + + +def test_not_found_error_no_positional_args() -> None: + err: NotFoundError[None] = NotFoundError() + assert err.entity_id is None diff --git a/tests/v2/test_unit/test_factory.py b/tests/v2/test_unit/test_factory.py new file mode 100644 index 0000000..8ad0073 --- /dev/null +++ b/tests/v2/test_unit/test_factory.py @@ -0,0 +1,147 @@ +"""Tests for build_repository, build_service, and build_service_for_repo factories.""" + +from typing import Any + +import pytest +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column + +from notora.v2.models.base import GenericBaseModel, SoftDeletableModel +from notora.v2.repositories.base import Repository, SoftDeleteRepository +from notora.v2.repositories.config import RepoConfig +from notora.v2.repositories.factory import AnyRepository, build_repository +from notora.v2.schemas.base import BaseResponseSchema +from notora.v2.services.base import RepositoryService, SoftDeleteRepositoryService +from notora.v2.services.factory import AnyService, build_service, build_service_for_repo + +_DEFAULT_LIMIT = 7 +_REPO_CONFIG_LIMIT = 3 + + +class _Widget(GenericBaseModel): + name: Mapped[str] = mapped_column(String) + + +class _SoftWidget(SoftDeletableModel): + name: Mapped[str] = mapped_column(String) + + +class _WidgetSchema(BaseResponseSchema): + pass + + +def test_build_repository_returns_standard_repo_by_default() -> None: + repo: AnyRepository[object, _Widget] = build_repository(_Widget) + assert isinstance(repo, Repository) + assert not isinstance(repo, SoftDeleteRepository) + + +def test_build_repository_soft_delete_flag_returns_soft_delete_repo() -> None: + repo: AnyRepository[object, _SoftWidget] = build_repository(_SoftWidget, soft_delete=True) + assert isinstance(repo, SoftDeleteRepository) + + +def test_build_repository_config_is_applied() -> None: + config = RepoConfig[_Widget](default_limit=_DEFAULT_LIMIT) + repo: AnyRepository[object, _Widget] = build_repository(_Widget, config=config) + assert repo.default_limit == _DEFAULT_LIMIT + + +def test_build_repository_custom_repo_class_used() -> None: + class _CustomRepo(Repository[object, _Widget]): + pass + + repo: AnyRepository[object, _Widget] = build_repository(_Widget, repo_cls=_CustomRepo) + assert isinstance(repo, _CustomRepo) + + +def test_build_repository_model_attribute_set() -> None: + repo: AnyRepository[object, _Widget] = build_repository(_Widget) + assert repo.model is _Widget + + +def test_build_service_returns_repository_service_by_default() -> None: + svc: AnyService[object, _Widget, Any, Any] = build_service(_Widget) + assert isinstance(svc, RepositoryService) + + +def test_build_service_soft_delete_flag_returns_soft_delete_service() -> None: + svc: AnyService[object, _SoftWidget, Any, Any] = build_service(_SoftWidget, soft_delete=True) + assert isinstance(svc, SoftDeleteRepositoryService) + + +def test_build_service_custom_repo_passed_directly() -> None: + repo = Repository[object, _Widget](_Widget) + svc: AnyService[object, _Widget, Any, Any] = build_service(_Widget, repo=repo) + assert isinstance(svc, RepositoryService) + assert svc.repo is repo + + +def test_build_service_soft_delete_repo_infers_soft_delete_service() -> None: + repo = SoftDeleteRepository[object, _SoftWidget](_SoftWidget) + svc: AnyService[object, _SoftWidget, Any, Any] = build_service(_SoftWidget, repo=repo) + assert isinstance(svc, SoftDeleteRepositoryService) + + +def test_build_service_soft_delete_service_class_with_non_soft_delete_repo_raises() -> None: + repo = Repository[object, _Widget](_Widget) + with pytest.raises(TypeError, match='Soft-delete service requires'): + build_service(_Widget, repo=repo, service_cls=SoftDeleteRepositoryService) + + +def test_build_service_soft_delete_flag_with_explicit_service_class() -> None: + svc: AnyService[object, _SoftWidget, Any, Any] = build_service( + _SoftWidget, + soft_delete=True, + service_cls=SoftDeleteRepositoryService, + ) + assert isinstance(svc, SoftDeleteRepositoryService) + + +def test_build_service_repo_config_applied() -> None: + repo_config = RepoConfig[_Widget](default_limit=_REPO_CONFIG_LIMIT) + svc: AnyService[object, _Widget, Any, Any] = build_service(_Widget, repo_config=repo_config) + assert svc.repo.default_limit == _REPO_CONFIG_LIMIT + + +def test_build_service_soft_delete_repo_with_soft_delete_true() -> None: + repo = SoftDeleteRepository[object, _SoftWidget](_SoftWidget) + svc: AnyService[object, _SoftWidget, Any, Any] = build_service( + _SoftWidget, soft_delete=True, repo=repo + ) + assert isinstance(svc, SoftDeleteRepositoryService) + + +def test_build_service_for_repo_standard_repo_returns_repository_service() -> None: + repo = Repository[object, _Widget](_Widget) + svc: AnyService[object, _Widget, Any, Any] = build_service_for_repo(repo) + assert isinstance(svc, RepositoryService) + + +def test_build_service_for_repo_soft_delete_repo_returns_soft_delete_service() -> None: + repo = SoftDeleteRepository[object, _SoftWidget](_SoftWidget) + svc: AnyService[object, _SoftWidget, Any, Any] = build_service_for_repo(repo) + assert isinstance(svc, SoftDeleteRepositoryService) + + +def test_build_service_for_repo_custom_service_class_used() -> None: + class _CustomService(RepositoryService[object, _Widget, _WidgetSchema]): + pass + + repo = Repository[object, _Widget](_Widget) + svc: AnyService[object, _Widget, Any, Any] = build_service_for_repo( + repo, service_cls=_CustomService + ) + assert isinstance(svc, _CustomService) + + +def test_build_service_for_repo_soft_delete_service_cls_with_non_soft_delete_repo_raises() -> None: + repo = Repository[object, _Widget](_Widget) + with pytest.raises(TypeError, match='Soft-delete service requires'): + build_service_for_repo(repo, service_cls=SoftDeleteRepositoryService) + + +def test_build_service_for_repo_repo_is_wired_to_service() -> None: + repo = Repository[object, _Widget](_Widget) + svc: AnyService[object, _Widget, Any, Any] = build_service_for_repo(repo) + assert svc.repo is repo diff --git a/tests/v2/test_unit/test_payload_mixin.py b/tests/v2/test_unit/test_payload_mixin.py new file mode 100644 index 0000000..8a0d3d4 --- /dev/null +++ b/tests/v2/test_unit/test_payload_mixin.py @@ -0,0 +1,51 @@ +"""Tests for PayloadMixin._dump_payload.""" + +from pydantic import BaseModel as PydanticModel + +from notora.v2.services.mixins.payload import PayloadMixin + + +class _SomeSchema(PydanticModel): + name: str + score: int = 0 + + +def test_payload_mixin_dict_input_returned_as_copy() -> None: + original = {'name': 'Alice', 'score': 5} + result = PayloadMixin._dump_payload(original, exclude_unset=True) + assert result == original + # Ensure it's a copy, not the same object + result['name'] = 'Bob' + assert original['name'] == 'Alice' + + +def test_payload_mixin_pydantic_model_dump_with_exclude_unset_true() -> None: + schema = _SomeSchema(name='Alice') + result = PayloadMixin._dump_payload(schema, exclude_unset=True) + # 'score' was not explicitly set, so it should be excluded + assert 'name' in result + assert 'score' not in result + + +def test_payload_mixin_pydantic_model_dump_with_exclude_unset_false() -> None: + schema = _SomeSchema(name='Alice') + result = PayloadMixin._dump_payload(schema, exclude_unset=False) + assert result == {'name': 'Alice', 'score': 0} + + +def test_payload_mixin_pydantic_model_fully_set() -> None: + score = 10 + schema = _SomeSchema(name='Bob', score=score) + result = PayloadMixin._dump_payload(schema, exclude_unset=True) + assert result == {'name': 'Bob', 'score': score} + + +def test_payload_mixin_empty_dict_returns_empty_dict() -> None: + result = PayloadMixin._dump_payload({}, exclude_unset=False) + assert result == {} + + +def test_payload_mixin_non_string_dict_values_preserved() -> None: + payload = {'count': 42, 'active': True, 'tags': ['a', 'b']} + result = PayloadMixin._dump_payload(payload, exclude_unset=True) + assert result == payload diff --git a/tests/v2/test_unit/test_pydantic_query_schemas.py b/tests/v2/test_unit/test_pydantic_query_schemas.py index 50aff41..ef0e677 100644 --- a/tests/v2/test_unit/test_pydantic_query_schemas.py +++ b/tests/v2/test_unit/test_pydantic_query_schemas.py @@ -1,11 +1,11 @@ -from typing import Annotated, Any, ClassVar, Literal, cast +from typing import Annotated, Any, ClassVar, Literal from uuid import UUID, uuid4 import pytest from pydantic import BaseModel, Field -from sqlalchemy import Boolean, Integer, String -from sqlalchemy.dialects import postgresql +from sqlalchemy import Boolean, Integer, String, create_engine from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.sql import ColumnElement, or_ @@ -59,12 +59,14 @@ class ThingOrdering(PydanticOrderBySchema[Thing]): } +_PG_DIALECT: Dialect = create_engine('postgresql+asyncpg://').dialect + + def _render(spec: FilterSpec[Any] | OrderSpec[Any]) -> str: assert not callable(spec) - clause = cast(ColumnElement[Any], spec) return str( - clause.compile( - dialect=postgresql.dialect(), # type: ignore[no-untyped-call] + spec.compile( + dialect=_PG_DIALECT, compile_kwargs={'literal_binds': True}, ), ) @@ -241,7 +243,9 @@ class WithFilter(BaseModel): def test_extract_annotated_filters_skips_non_filter_metadata() -> None: class Mixed(BaseModel): - name: Annotated[str | None, Filter(resolver=Thing.name)] = Field(default=None, description='X') + name: Annotated[str | None, Filter(resolver=Thing.name)] = Field( + default=None, description='X' + ) plain: str | None = None out = _extract_annotated_filters(Mixed) @@ -250,6 +254,7 @@ class Mixed(BaseModel): def test_extract_annotated_filters_raises_on_multiple_filters_in_one_field() -> None: with pytest.raises(TypeError, match='multiple Filter'): + class Conflict(BaseModel): name: Annotated[ str | None, @@ -319,6 +324,7 @@ class LegacyParent(PydanticFiltersSchema[Thing]): } with pytest.raises(TypeError, match='mixes legacy `filter_fields` ClassVar'): + class AnnotatedChild(LegacyParent): age: Annotated[int | None, Filter(resolver=Thing.age)] = None @@ -355,10 +361,12 @@ class ThingFiltersAnnotated(PydanticFiltersSchema[Thing]): is_active: Annotated[bool | None, Filter(resolver=Thing.is_active)] = None q: Annotated[ str | None, - Filter(predicate=lambda m, _op, v: or_( - m.name.ilike(f'%{v}%'), - m.owner_id.cast(String).ilike(f'%{v}%'), - )), + Filter( + predicate=lambda m, _op, v: or_( + m.name.ilike(f'%{v}%'), + m.owner_id.cast(String).ilike(f'%{v}%'), + ) + ), ] = None diff --git a/tests/v2/test_unit/test_query_dsl_tokens.py b/tests/v2/test_unit/test_query_dsl_tokens.py new file mode 100644 index 0000000..e28a7c0 --- /dev/null +++ b/tests/v2/test_unit/test_query_dsl_tokens.py @@ -0,0 +1,403 @@ +"""Tests for query_dsl token parsers, filter/sort clause builders, and build_query_params.""" + +from typing import Any + +import pytest +from pydantic import ValidationError +from sqlalchemy import Integer, String, create_engine, select +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlalchemy.sql import ColumnElement + +from notora.v2.models.base import GenericBaseModel +from notora.v2.repositories.query_dsl import ( + FilterField, + FilterToken, + QueryInput, + SortField, + SortToken, + apply_filter_operator, + build_filter_clauses, + build_query_params, + build_sort_clauses, + parse_filter_token, + parse_sort_token, + resolve_to_column, +) + +_MULTI_CLAUSE_COUNT = 2 +_POSITIVE_OFFSET = 100 +_POSITIVE_LIMIT = 50 +_LIMIT_SMALL = 5 +_OFFSET_SMALL = 10 + +_PG_DIALECT: Dialect = create_engine('postgresql+asyncpg://').dialect + + +def _render(clause: ColumnElement[Any] | InstrumentedAttribute[Any]) -> str: + return str( + clause.compile( + dialect=_PG_DIALECT, + compile_kwargs={'literal_binds': True}, + ) + ) + + +class SampleModel(GenericBaseModel): + name: Mapped[str] = mapped_column(String) + score: Mapped[int] = mapped_column(Integer) + + +def test_parse_filter_token_parses_field_op_value() -> None: + token = parse_filter_token('name:eq:alice') + assert token.field == 'name' + assert token.operator == 'eq' + assert token.raw_value == 'alice' + + +def test_parse_filter_token_parses_operator_only_for_isnull() -> None: + token = parse_filter_token('name:isnull') + assert token.field == 'name' + assert token.operator == 'isnull' + assert token.raw_value is None + + +def test_parse_filter_token_raises_for_missing_colon() -> None: + with pytest.raises(ValueError, match='"field:op:value"'): + parse_filter_token('nocolon') + + +def test_parse_filter_token_raises_for_empty_field_name() -> None: + with pytest.raises(ValueError, match='field name cannot be empty'): + parse_filter_token(':eq:value') + + +def test_parse_filter_token_raises_for_unsupported_operator() -> None: + with pytest.raises(ValueError, match='Unsupported filter operator'): + parse_filter_token('name:contains:hello') + + +def test_parse_filter_token_value_with_colons_preserved() -> None: + token = parse_filter_token('name:eq:a:b:c') + assert token.raw_value == 'a:b:c' + + +def test_parse_filter_token_whitespace_stripped_from_field_and_op() -> None: + token = parse_filter_token(' name : eq : alice ') + assert token.field == 'name' + assert token.operator == 'eq' + + +def test_parse_filter_token_whitespace_only_value_becomes_none() -> None: + token = parse_filter_token('name:eq: ') + assert token.raw_value is None + + +def test_parse_filter_token_all_operators_accepted() -> None: + valid_ops = ('eq', 'ne', 'lt', 'lte', 'gt', 'gte', 'in', 'ilike', 'isnull') + for op in valid_ops: + token = parse_filter_token(f'name:{op}:x') + assert token.operator == op + + +def test_parse_filter_token_isnull_with_false_value() -> None: + token = parse_filter_token('name:isnull:false') + assert token.raw_value == 'false' + + +def test_parse_filter_token_in_with_comma_separated_value() -> None: + token = parse_filter_token('score:in:1,2,3') + assert token.raw_value == '1,2,3' + + +def test_parse_sort_token_plain_field_is_ascending() -> None: + token = parse_sort_token('name') + assert token.field == 'name' + assert token.direction == 'asc' + + +def test_parse_sort_token_plus_prefix_is_ascending() -> None: + token = parse_sort_token('+name') + assert token.field == 'name' + assert token.direction == 'asc' + + +def test_parse_sort_token_minus_prefix_is_descending() -> None: + token = parse_sort_token('-score') + assert token.field == 'score' + assert token.direction == 'desc' + + +def test_parse_sort_token_empty_string_raises() -> None: + with pytest.raises(ValueError, match='cannot be empty'): + parse_sort_token('') + + +def test_parse_sort_token_only_minus_raises() -> None: + with pytest.raises(ValueError, match='cannot be empty'): + parse_sort_token('-') + + +def test_parse_sort_token_only_plus_raises() -> None: + with pytest.raises(ValueError, match='cannot be empty'): + parse_sort_token('+') + + +def test_parse_sort_token_whitespace_stripped() -> None: + token = parse_sort_token(' name ') + assert token.field == 'name' + + +def test_parse_sort_token_returns_sort_token_dataclass() -> None: + token = parse_sort_token('name') + assert isinstance(token, SortToken) + + +def test_apply_filter_operator_eq() -> None: + clause = apply_filter_operator(SampleModel.name, 'eq', 'alice') + assert "sample_model.name = 'alice'" in _render(clause) + + +def test_apply_filter_operator_ne() -> None: + clause = apply_filter_operator(SampleModel.name, 'ne', 'alice') + rendered = _render(clause) + assert 'sample_model.name' in rendered + assert '!=' in rendered or '<>' in rendered + + +def test_apply_filter_operator_lt() -> None: + clause = apply_filter_operator(SampleModel.score, 'lt', 5) + assert 'sample_model.score < 5' in _render(clause) + + +def test_apply_filter_operator_lte() -> None: + clause = apply_filter_operator(SampleModel.score, 'lte', 5) + assert 'sample_model.score <= 5' in _render(clause) + + +def test_apply_filter_operator_gt() -> None: + clause = apply_filter_operator(SampleModel.score, 'gt', 5) + assert 'sample_model.score > 5' in _render(clause) + + +def test_apply_filter_operator_gte() -> None: + clause = apply_filter_operator(SampleModel.score, 'gte', 5) + assert 'sample_model.score >= 5' in _render(clause) + + +def test_apply_filter_operator_in() -> None: + clause = apply_filter_operator(SampleModel.score, 'in', [1, 2, 3]) + assert 'IN' in _render(clause) + + +def test_apply_filter_operator_ilike() -> None: + clause = apply_filter_operator(SampleModel.name, 'ilike', '%alice%') + assert 'ILIKE' in _render(clause) + + +def test_apply_filter_operator_isnull_true() -> None: + clause = apply_filter_operator(SampleModel.name, 'isnull', value=True) + assert 'IS NULL' in _render(clause) + + +def test_apply_filter_operator_isnull_false() -> None: + clause = apply_filter_operator(SampleModel.name, 'isnull', value=False) + assert 'IS NOT NULL' in _render(clause) + + +def test_apply_filter_operator_unsupported_operator_raises() -> None: + with pytest.raises(ValueError, match='Unsupported filter operator'): + bad_op: Any = 'contains' + apply_filter_operator(SampleModel.name, bad_op, 'x') + + +def test_resolve_to_column_direct_column_returned_unchanged() -> None: + col = resolve_to_column(SampleModel.name, SampleModel) + assert 'sample_model.name' in _render(col) + + +def test_resolve_to_column_callable_resolver_called_with_model() -> None: + col = resolve_to_column(lambda m: m.score, SampleModel) + assert 'sample_model.score' in _render(col) + + +def test_build_filter_clauses_single_eq_clause() -> None: + tokens = [FilterToken(field='name', operator='eq', raw_value='alice')] + fields: dict[str, FilterField[SampleModel]] = { + 'name': FilterField(resolver=SampleModel.name, value_type=str) + } + clauses = build_filter_clauses(tokens, model=SampleModel, fields=fields) + assert len(clauses) == 1 + assert "sample_model.name = 'alice'" in _render(clauses[0]) + + +def test_build_filter_clauses_unknown_field_raises() -> None: + tokens = [FilterToken(field='unknown', operator='eq', raw_value='x')] + with pytest.raises(ValueError, match='Unsupported filter field'): + build_filter_clauses(tokens, model=SampleModel, fields={}) + + +def test_build_filter_clauses_disallowed_operator_raises() -> None: + tokens = [FilterToken(field='name', operator='gt', raw_value='5')] + fields: dict[str, FilterField[SampleModel]] = { + 'name': FilterField(resolver=SampleModel.name, operators=frozenset({'eq'})) + } + with pytest.raises(ValueError, match='Operator'): + build_filter_clauses(tokens, model=SampleModel, fields=fields) + + +def test_build_filter_clauses_predicate_field() -> None: + def pred(model: type[SampleModel], _op: str, value: str) -> ColumnElement[bool]: + return model.name.ilike(f'%{value}%') + + tokens = [FilterToken(field='q', operator='eq', raw_value='alice')] + fields = {'q': FilterField(predicate=pred)} + clauses = build_filter_clauses(tokens, model=SampleModel, fields=fields) + assert 'ILIKE' in _render(clauses[0]) + + +def test_build_filter_clauses_field_without_resolver_or_predicate_raises() -> None: + tokens = [FilterToken(field='name', operator='eq', raw_value='x')] + fields: dict[str, FilterField[SampleModel]] = {'name': FilterField()} + with pytest.raises(ValueError, match='resolver or predicate'): + build_filter_clauses(tokens, model=SampleModel, fields=fields) + + +def test_build_filter_clauses_empty_tokens_returns_empty() -> None: + clauses = build_filter_clauses([], model=SampleModel, fields={}) + assert clauses == [] + + +def test_build_filter_clauses_isnull_no_value() -> None: + tokens = [FilterToken(field='name', operator='isnull', raw_value=None)] + fields: dict[str, FilterField[SampleModel]] = {'name': FilterField(resolver=SampleModel.name)} + clauses = build_filter_clauses(tokens, model=SampleModel, fields=fields) + assert 'IS NULL' in _render(clauses[0]) + + +def test_build_filter_clauses_in_operator_requires_value() -> None: + tokens = [FilterToken(field='name', operator='in', raw_value=None)] + fields: dict[str, FilterField[SampleModel]] = {'name': FilterField(resolver=SampleModel.name)} + with pytest.raises(ValueError, match='requires a value'): + build_filter_clauses(tokens, model=SampleModel, fields=fields) + + +def test_build_filter_clauses_non_isnull_without_value_raises() -> None: + tokens = [FilterToken(field='name', operator='eq', raw_value=None)] + fields: dict[str, FilterField[SampleModel]] = {'name': FilterField(resolver=SampleModel.name)} + with pytest.raises(ValueError, match='requires a value'): + build_filter_clauses(tokens, model=SampleModel, fields=fields) + + +def test_build_filter_clauses_callable_resolver_in_field() -> None: + tokens = [FilterToken(field='name', operator='eq', raw_value='x')] + fields: dict[str, FilterField[SampleModel]] = { + 'name': FilterField(resolver=lambda m: m.name, value_type=str) + } + clauses = build_filter_clauses(tokens, model=SampleModel, fields=fields) + assert "sample_model.name = 'x'" in _render(clauses[0]) + + +def test_build_sort_clauses_ascending() -> None: + tokens = [SortToken(field='name', direction='asc')] + fields: dict[str, SortField[SampleModel]] = {'name': SortField(resolver=SampleModel.name)} + clauses = build_sort_clauses(tokens, model=SampleModel, fields=fields) + assert len(clauses) == 1 + assert 'ASC' in _render(clauses[0]) + + +def test_build_sort_clauses_descending() -> None: + tokens = [SortToken(field='score', direction='desc')] + fields: dict[str, SortField[SampleModel]] = {'score': SortField(resolver=SampleModel.score)} + clauses = build_sort_clauses(tokens, model=SampleModel, fields=fields) + assert 'DESC' in _render(clauses[0]) + + +def test_build_sort_clauses_unknown_field_raises() -> None: + tokens = [SortToken(field='unknown', direction='asc')] + with pytest.raises(ValueError, match='Unsupported sort field'): + build_sort_clauses(tokens, model=SampleModel, fields={}) + + +def test_build_sort_clauses_empty_tokens_returns_empty() -> None: + clauses = build_sort_clauses([], model=SampleModel, fields={}) + assert clauses == [] + + +def test_build_sort_clauses_callable_resolver() -> None: + tokens = [SortToken(field='name', direction='asc')] + fields: dict[str, SortField[SampleModel]] = {'name': SortField(resolver=lambda m: m.name)} + clauses = build_sort_clauses(tokens, model=SampleModel, fields=fields) + assert 'sample_model.name' in _render(clauses[0]) + + +def test_build_sort_clauses_multiple_tokens() -> None: + tokens = [ + SortToken(field='name', direction='asc'), + SortToken(field='score', direction='desc'), + ] + fields: dict[str, SortField[SampleModel]] = { + 'name': SortField(resolver=SampleModel.name), + 'score': SortField(resolver=SampleModel.score), + } + clauses = build_sort_clauses(tokens, model=SampleModel, fields=fields) + assert len(clauses) == _MULTI_CLAUSE_COUNT + + +def test_query_input_negative_offset_raises() -> None: + with pytest.raises(ValidationError, match='offset must be zero or a positive integer'): + QueryInput(offset=-1) + + +def test_query_input_zero_offset_accepted() -> None: + q = QueryInput(offset=0) + assert q.offset == 0 + + +def test_query_input_positive_offset_accepted() -> None: + q = QueryInput(offset=_POSITIVE_OFFSET) + assert q.offset == _POSITIVE_OFFSET + + +def test_query_input_none_limit_accepted() -> None: + q = QueryInput(limit=None) + assert q.limit is None + + +def test_query_input_positive_limit_accepted() -> None: + q = QueryInput(limit=_POSITIVE_LIMIT) + assert q.limit == _POSITIVE_LIMIT + + +def test_build_query_params_filters_without_filter_fields_raises() -> None: + query = QueryInput(filter=['name:eq:x']) + with pytest.raises(ValueError, match='Filter fields mapping is required'): + build_query_params(query, model=SampleModel, filter_fields={}) + + +def test_build_query_params_sort_without_sort_fields_raises() -> None: + query = QueryInput(sort=['-score']) + with pytest.raises(ValueError, match='Sort fields mapping is required'): + build_query_params(query, model=SampleModel, sort_fields={}) + + +def test_build_query_params_no_filter_no_sort_returns_none_for_both() -> None: + query = QueryInput() + params = build_query_params(query, model=SampleModel) + assert params.filters is None + assert params.ordering is None + + +def test_build_query_params_explicit_limit_and_offset_passed_through() -> None: + query = QueryInput(limit=_LIMIT_SMALL, offset=_OFFSET_SMALL) + params = build_query_params(query, model=SampleModel) + assert params.limit == _LIMIT_SMALL + assert params.offset == _OFFSET_SMALL + + +def test_build_query_params_base_query_forwarded() -> None: + base = select(SampleModel) + query = QueryInput() + params = build_query_params(query, model=SampleModel, base_query=base) + assert params.base_query is base diff --git a/tests/v2/test_unit/test_schemas_base.py b/tests/v2/test_unit/test_schemas_base.py new file mode 100644 index 0000000..8b3af37 --- /dev/null +++ b/tests/v2/test_unit/test_schemas_base.py @@ -0,0 +1,50 @@ +"""Tests for v2 schemas.base — ClientMeta, PaginationMetaSchema.""" + +from ipaddress import IPv4Address, IPv6Address + +from notora.v2.schemas.base import ClientMeta, PaginationMetaSchema + +_TOTAL_100 = 100 +_LIMIT = 10 + + +def test_client_meta_both_fields_none_by_default() -> None: + client = ClientMeta() + assert client.ip_address is None + assert client.user_agent is None + + +def test_client_meta_ipv4_address_accepted() -> None: + client = ClientMeta(ip_address=IPv4Address('192.168.1.1')) + assert isinstance(client.ip_address, IPv4Address) + + +def test_client_meta_ipv6_address_accepted() -> None: + client = ClientMeta(ip_address=IPv6Address('::1')) + assert isinstance(client.ip_address, IPv6Address) + + +def test_client_meta_user_agent_stored() -> None: + client = ClientMeta(user_agent='Mozilla/5.0') + assert client.user_agent == 'Mozilla/5.0' + + +def test_client_meta_ip_serialized_as_string_in_dict() -> None: + client = ClientMeta(ip_address=IPv4Address('10.0.0.1')) + dumped = client.model_dump() + assert dumped['ip_address'] == '10.0.0.1' + + +def test_pagination_meta_negative_total_clamped_to_zero() -> None: + meta = PaginationMetaSchema.calculate(total=-5, limit=_LIMIT, offset=0) + assert meta.total == 0 + + +def test_pagination_meta_zero_total_preserved() -> None: + meta = PaginationMetaSchema.calculate(total=0, limit=_LIMIT, offset=0) + assert meta.total == 0 + + +def test_pagination_meta_positive_total_preserved() -> None: + meta = PaginationMetaSchema.calculate(total=_TOTAL_100, limit=_LIMIT, offset=0) + assert meta.total == _TOTAL_100 diff --git a/tests/v2/test_unit/test_serializer_mixin.py b/tests/v2/test_unit/test_serializer_mixin.py new file mode 100644 index 0000000..890b2ba --- /dev/null +++ b/tests/v2/test_unit/test_serializer_mixin.py @@ -0,0 +1,133 @@ +"""Tests for SerializerMixin — edge cases not covered by integration tests.""" + +import pytest +from pydantic import ConfigDict + +from notora.v2.models.base import GenericBaseModel +from notora.v2.schemas.base import BaseResponseSchema +from notora.v2.services.mixins.serializer import SerializerMixin + +_ITEM_COUNT = 5 + + +class _Item(GenericBaseModel): + pass + + +class _DetailSchema(BaseResponseSchema): + model_config = ConfigDict(from_attributes=True) + + +class _ListSchema(BaseResponseSchema): + model_config = ConfigDict(from_attributes=True) + + +def _make_obj() -> _Item: + return _Item() + + +def _make_mixin() -> SerializerMixin[_Item, _DetailSchema, _ListSchema]: + mixin: SerializerMixin[_Item, _DetailSchema, _ListSchema] = SerializerMixin() + return mixin + + +def test_serialize_one_uses_explicit_schema_arg() -> None: + mixin = _make_mixin() + item = _make_obj() + result = mixin.serialize_one(item, schema=_DetailSchema) + assert isinstance(result, _DetailSchema) + + +def test_serialize_one_falls_back_to_detail_schema_attribute() -> None: + mixin = _make_mixin() + mixin.detail_schema = _DetailSchema + item = _make_obj() + result = mixin.serialize_one(item) + assert isinstance(result, _DetailSchema) + + +def test_serialize_one_raises_when_no_schema_and_no_detail_schema() -> None: + mixin = _make_mixin() + item = _make_obj() + with pytest.raises(ValueError, match='schema is required'): + mixin.serialize_one(item) + + +def test_serialize_one_explicit_schema_overrides_detail_schema() -> None: + mixin = _make_mixin() + mixin.detail_schema = _DetailSchema + + class _AltSchema(_DetailSchema): + pass + + item = _make_obj() + result = mixin.serialize_one(item, schema=_AltSchema) + assert isinstance(result, _AltSchema) + + +def test_serialize_many_empty_list_returns_empty() -> None: + mixin = _make_mixin() + mixin.list_schema = _ListSchema + result = mixin.serialize_many([]) + assert result == [] + + +def test_serialize_many_uses_list_schema_by_default() -> None: + mixin = _make_mixin() + mixin.list_schema = _ListSchema + item = _make_obj() + results = mixin.serialize_many([item]) + assert all(isinstance(r, _ListSchema) for r in results) + + +def test_serialize_many_falls_back_to_detail_schema_when_list_schema_absent() -> None: + mixin = _make_mixin() + mixin.detail_schema = _DetailSchema + item = _make_obj() + results = mixin.serialize_many([item]) + assert all(isinstance(r, _DetailSchema) for r in results) + + +def test_serialize_many_explicit_schema_arg_overrides_list_schema() -> None: + mixin = _make_mixin() + mixin.list_schema = _ListSchema + + class _AltSchema(_ListSchema): + pass + + item = _make_obj() + results = mixin.serialize_many([item], schema=_AltSchema) + assert all(isinstance(r, _AltSchema) for r in results) + + +def test_serialize_many_prefer_list_schema_false_uses_explicit_schema() -> None: + mixin = _make_mixin() + mixin.detail_schema = _DetailSchema + mixin.list_schema = _ListSchema + item = _make_obj() + results = mixin.serialize_many([item], schema=_DetailSchema, prefer_list_schema=False) + assert all(isinstance(r, _DetailSchema) for r in results) + + +def test_serialize_many_raises_when_no_schema_at_all() -> None: + mixin = _make_mixin() + item = _make_obj() + with pytest.raises(ValueError, match='schema is required'): + mixin.serialize_many([item]) + + +def test_serialize_many_prefer_list_schema_false_no_schema_raises() -> None: + mixin = _make_mixin() + mixin.detail_schema = None + mixin.list_schema = None + item = _make_obj() + with pytest.raises(ValueError, match='schema is required'): + mixin.serialize_many([item], prefer_list_schema=False) + + +def test_serialize_many_serializes_multiple_items() -> None: + mixin = _make_mixin() + mixin.list_schema = _ListSchema + items = [_make_obj() for _ in range(_ITEM_COUNT)] + results = mixin.serialize_many(items) + assert len(results) == _ITEM_COUNT diff --git a/tests/v2/test_unit/test_updated_by_mixin.py b/tests/v2/test_unit/test_updated_by_mixin.py new file mode 100644 index 0000000..b2748ef --- /dev/null +++ b/tests/v2/test_unit/test_updated_by_mixin.py @@ -0,0 +1,79 @@ +"""Tests for UpdatedByServiceMixin.""" + +from uuid import UUID, uuid4 + +import pytest +from sqlalchemy import String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from notora.v2.models.base import GenericBaseModel +from notora.v2.repositories.base import Repository +from notora.v2.services.mixins.updated_by import UpdatedByServiceMixin + + +class _WithUpdatedBy(GenericBaseModel): + name: Mapped[str] = mapped_column(String) + updated_by: Mapped[UUID | None] = mapped_column(Uuid, nullable=True) + + +class _WithoutUpdatedBy(GenericBaseModel): + name: Mapped[str] = mapped_column(String) + + +class _Mixin(UpdatedByServiceMixin[object, _WithUpdatedBy]): + def __init__(self) -> None: + self.repo = Repository[object, _WithUpdatedBy](_WithUpdatedBy) + + +class _MixinNoAttr(UpdatedByServiceMixin[object, _WithoutUpdatedBy]): + def __init__(self) -> None: + self.repo = Repository[object, _WithoutUpdatedBy](_WithoutUpdatedBy) + + +def test_apply_updated_by_actor_id_none_returns_payload_unchanged() -> None: + mixin = _Mixin() + payload = {'name': 'Alice'} + result = mixin._apply_updated_by(payload, actor_id=None) + assert result == {'name': 'Alice'} + + +def test_apply_updated_by_actor_id_set_injects_updated_by() -> None: + mixin = _Mixin() + actor_id = uuid4() + payload: dict[str, object] = {'name': 'Alice'} + result = mixin._apply_updated_by(payload, actor_id=actor_id) + assert result['updated_by'] == actor_id + + +def test_apply_updated_by_existing_value_not_overwritten() -> None: + mixin = _Mixin() + original_actor = uuid4() + new_actor = uuid4() + payload: dict[str, object] = {'name': 'Alice', 'updated_by': original_actor} + result = mixin._apply_updated_by(payload, actor_id=new_actor) + assert result['updated_by'] == original_actor + + +def test_apply_updated_by_model_without_attribute_raises() -> None: + mixin = _MixinNoAttr() + payload: dict[str, object] = {'name': 'Bob'} + with pytest.raises(ValueError, match='is not defined on'): + mixin._apply_updated_by(payload, actor_id=uuid4()) + + +def test_apply_updated_by_custom_attribute_name_used() -> None: + class _WithCustomAttr(GenericBaseModel): + name: Mapped[str] = mapped_column(String) + modified_by: Mapped[UUID | None] = mapped_column(Uuid, nullable=True) + + class _CustomMixin(UpdatedByServiceMixin[object, _WithCustomAttr]): + updated_by_attribute = 'modified_by' + + def __init__(self) -> None: + self.repo = Repository[object, _WithCustomAttr](_WithCustomAttr) + + mixin = _CustomMixin() + actor_id = uuid4() + payload: dict[str, object] = {'name': 'Charlie'} + result = mixin._apply_updated_by(payload, actor_id=actor_id) + assert result['modified_by'] == actor_id