Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions assets/coverage.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
101 changes: 93 additions & 8 deletions statica/validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""
The backus naur grammar for types is as follows:
T ::= Statica | int | float | str | None | (T1 | T2) | list[T] | set[T] | dict[T1, T2]
T ::= Statica
| int
| float
| str
| None
| (T1 | T2)
| list[T]
| set[T]
| dict[T1, T2]
| Literal[V1, ...]


Where:
- Statica: A class that inherits from Statica
Expand All @@ -16,11 +26,37 @@
from __future__ import annotations

from types import GenericAlias, UnionType
from typing import Any
from typing import Any, Literal, TypeGuard, Union

from statica.config import StaticaConfig, default_config
from statica.exceptions import ConstraintValidationError, TypeValidationError

########################################################################################
#### MARK: Types


class LiteralGenericAlias:
"""A type used in place of typing._LiteralGenericAlias to avoid private imports."""

__origin__ = Literal
__args__: tuple[Any, ...]


def is_literal_generic_alias(expected_type: Any) -> TypeGuard[LiteralGenericAlias]:
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Literal


class UnionGenericAlias:
"""A type used in place of typing._UnionGenericAlias to avoid private imports."""

__origin__ = Union
__args__: tuple[Any, ...]


def is_union_generic_alias(expected_type: Any) -> TypeGuard[UnionGenericAlias]:
return hasattr(expected_type, "__origin__") and expected_type.__origin__ is Union


########################################################################################
#### MARK: Type Validation

Expand All @@ -35,16 +71,28 @@ def validate_or_raise(
are already initialized Statica objects.
"""

# Handle union types
# Handle generic aliases if native python types, e.g. list[int], dict[str, int]

if isinstance(expected_type, UnionType):
validate_type_union(value, expected_type, config)
if isinstance(expected_type, GenericAlias):
validate_type_generic_alias(value, expected_type, config)
return

# Handle generic aliases
# Handle parameterized generic types

if isinstance(expected_type, GenericAlias):
validate_type_generic_alias(value, expected_type, config)
if is_union_generic_alias(expected_type):
validate_type_union_generic_alias(value, expected_type, config)
return

# Handle Literal (e.g. Literal["a", "b"], with any number and type of values)

if is_literal_generic_alias(expected_type):
validate_literal(value, expected_type)
return

# Handle union types

if isinstance(expected_type, UnionType):
validate_type_union(value, expected_type, config)
return

# Handle all other types
Expand All @@ -59,6 +107,19 @@ def validate_or_raise(
raise TypeValidationError(msg)


def validate_literal(
value: Any,
expected_type: LiteralGenericAlias,
) -> None:
"""
Validate that the value matches one of the literals in the expected_type.
Throws TypeValidationError if the value is not one of the literals.
"""
if value not in expected_type.__args__:
msg = f"expected one of {expected_type.__args__}, got '{value}'"
raise TypeValidationError(msg)


def validate_type_union(
value: Any,
expected_type: UnionType,
Expand All @@ -83,6 +144,30 @@ def validate_type_union(
raise TypeValidationError(msg)


def validate_type_union_generic_alias(
value: Any,
expected_type: UnionGenericAlias,
config: StaticaConfig = default_config,
) -> None:
"""
Validate that the value matches one of the types in the UnionGenericAlias.
Throws TypeValidationError if the type does not match any of the union types.
"""
for sub_type in expected_type.__args__:
try:
validate_or_raise(value, sub_type, config)
except TypeValidationError:
continue # Try the next sub-type
else:
return # Exit if one of the sub-types matches

msg = config.type_error_message.format(
expected_type=expected_type.__args__,
found_type=type(value).__name__,
)
raise TypeValidationError(msg)


def validate_type_generic_alias(
value: Any,
expected_type: GenericAlias,
Expand Down
38 changes: 38 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import pytest

from statica import Field, Statica, TypeValidationError
Expand Down Expand Up @@ -98,3 +100,39 @@ class UnsupportedGeneric(Statica):

with pytest.raises(TypeValidationError):
UnsupportedGeneric(data=frozenset([1, 2, 3]))


def test_validate_literal() -> None:
class LiteralTest(Statica):
data: Literal["a", "b", "c"]
number: Literal[1, 2, 3]

i1 = LiteralTest.from_map({"data": "a", "number": 1})
assert i1.data == "a"
assert i1.number == 1

with pytest.raises(TypeValidationError):
LiteralTest.from_map({"data": "d", "number": 1})

with pytest.raises(TypeValidationError):
LiteralTest.from_map({"data": "a", "number": 4})


def test_validate_literal_optional() -> None:
class LiteralTest(Statica):
data: Literal["a", "b", "c"] | None
number: Literal[1, 2, 3] | None

i1 = LiteralTest.from_map({"data": "a", "number": 1})
assert i1.data == "a"
assert i1.number == 1

i2 = LiteralTest.from_map({"data": None, "number": None})
assert i2.data is None
assert i2.number is None

with pytest.raises(TypeValidationError):
LiteralTest.from_map({"data": "d", "number": 1})

with pytest.raises(TypeValidationError):
LiteralTest.from_map({"data": "a", "number": 4})