diff --git a/assets/coverage.svg b/assets/coverage.svg index 6bfc8fa..e5db27c 100644 --- a/assets/coverage.svg +++ b/assets/coverage.svg @@ -15,7 +15,7 @@ coverage coverage - 99% - 99% + 100% + 100% diff --git a/statica/validation.py b/statica/validation.py index 7b43cb1..f623aac 100644 --- a/statica/validation.py +++ b/statica/validation.py @@ -1,6 +1,16 @@ """ The backus naur grammar for types is as follows: -T ::= Statica | int | float | str | None | (T1 | T2) | list[T] | set[T] | dict[T1, T2] +T ::= Statica + | int + | float + | str + | None + | (T1 | T2) + | list[T] + | set[T] + | dict[T1, T2] + | Literal[V1, ...] + Where: - Statica: A class that inherits from Statica @@ -16,11 +26,37 @@ from __future__ import annotations from types import GenericAlias, UnionType -from typing import Any +from typing import Any, Literal, TypeGuard, Union from statica.config import StaticaConfig, default_config from statica.exceptions import ConstraintValidationError, TypeValidationError +######################################################################################## +#### MARK: Types + + +class LiteralGenericAlias: + """A type used in place of typing._LiteralGenericAlias to avoid private imports.""" + + __origin__ = Literal + __args__: tuple[Any, ...] + + +def is_literal_generic_alias(expected_type: Any) -> TypeGuard[LiteralGenericAlias]: + return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Literal + + +class UnionGenericAlias: + """A type used in place of typing._UnionGenericAlias to avoid private imports.""" + + __origin__ = Union + __args__: tuple[Any, ...] + + +def is_union_generic_alias(expected_type: Any) -> TypeGuard[UnionGenericAlias]: + return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Union + + ######################################################################################## #### MARK: Type Validation @@ -35,16 +71,28 @@ def validate_or_raise( are already initialized Statica objects. """ - # Handle union types + # Handle generic aliases if native python types, e.g. list[int], dict[str, int] - if isinstance(expected_type, UnionType): - validate_type_union(value, expected_type, config) + if isinstance(expected_type, GenericAlias): + validate_type_generic_alias(value, expected_type, config) return - # Handle generic aliases + # Handle parameterized generic types - if isinstance(expected_type, GenericAlias): - validate_type_generic_alias(value, expected_type, config) + if is_union_generic_alias(expected_type): + validate_type_union_generic_alias(value, expected_type, config) + return + + # Handle Literal (e.g. Literal["a", "b"], with any number and type of values) + + if is_literal_generic_alias(expected_type): + validate_literal(value, expected_type) + return + + # Handle union types + + if isinstance(expected_type, UnionType): + validate_type_union(value, expected_type, config) return # Handle all other types @@ -59,6 +107,19 @@ def validate_or_raise( raise TypeValidationError(msg) +def validate_literal( + value: Any, + expected_type: LiteralGenericAlias, +) -> None: + """ + Validate that the value matches one of the literals in the expected_type. + Throws TypeValidationError if the value is not one of the literals. + """ + if value not in expected_type.__args__: + msg = f"expected one of {expected_type.__args__}, got '{value}'" + raise TypeValidationError(msg) + + def validate_type_union( value: Any, expected_type: UnionType, @@ -83,6 +144,30 @@ def validate_type_union( raise TypeValidationError(msg) +def validate_type_union_generic_alias( + value: Any, + expected_type: UnionGenericAlias, + config: StaticaConfig = default_config, +) -> None: + """ + Validate that the value matches one of the types in the UnionGenericAlias. + Throws TypeValidationError if the type does not match any of the union types. + """ + for sub_type in expected_type.__args__: + try: + validate_or_raise(value, sub_type, config) + except TypeValidationError: + continue # Try the next sub-type + else: + return # Exit if one of the sub-types matches + + msg = config.type_error_message.format( + expected_type=expected_type.__args__, + found_type=type(value).__name__, + ) + raise TypeValidationError(msg) + + def validate_type_generic_alias( value: Any, expected_type: GenericAlias, diff --git a/tests/test_validation.py b/tests/test_validation.py index 5236117..1ee5084 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,3 +1,5 @@ +from typing import Literal + import pytest from statica import Field, Statica, TypeValidationError @@ -98,3 +100,39 @@ class UnsupportedGeneric(Statica): with pytest.raises(TypeValidationError): UnsupportedGeneric(data=frozenset([1, 2, 3])) + + +def test_validate_literal() -> None: + class LiteralTest(Statica): + data: Literal["a", "b", "c"] + number: Literal[1, 2, 3] + + i1 = LiteralTest.from_map({"data": "a", "number": 1}) + assert i1.data == "a" + assert i1.number == 1 + + with pytest.raises(TypeValidationError): + LiteralTest.from_map({"data": "d", "number": 1}) + + with pytest.raises(TypeValidationError): + LiteralTest.from_map({"data": "a", "number": 4}) + + +def test_validate_literal_optional() -> None: + class LiteralTest(Statica): + data: Literal["a", "b", "c"] | None + number: Literal[1, 2, 3] | None + + i1 = LiteralTest.from_map({"data": "a", "number": 1}) + assert i1.data == "a" + assert i1.number == 1 + + i2 = LiteralTest.from_map({"data": None, "number": None}) + assert i2.data is None + assert i2.number is None + + with pytest.raises(TypeValidationError): + LiteralTest.from_map({"data": "d", "number": 1}) + + with pytest.raises(TypeValidationError): + LiteralTest.from_map({"data": "a", "number": 4})