diff --git a/tests/test_schema_compatibility.py b/tests/test_schema_compatibility.py new file mode 100644 index 000000000..c60dafa2b --- /dev/null +++ b/tests/test_schema_compatibility.py @@ -0,0 +1,257 @@ +""" +Test that client (codecarbon/core/schemas.py) and server +(carbonserver/carbonserver/api/schemas.py) schemas are compatible. + +A mismatch between these schemas can cause silent data corruption or API errors. +This test was added to prevent regressions like the one fixed in PR #1189, +where `on_cloud` was typed as `str` on one side and `bool` on the other. + +Related issue: https://github.com/mlco2/codecarbon/issues/1190 +""" +import ast +from pathlib import Path + +import pytest + +# --------------------------------------------------------------------------- +# Paths +# --------------------------------------------------------------------------- +REPO_ROOT = Path(__file__).parent.parent +CLIENT_SCHEMA_PATH = REPO_ROOT / "codecarbon" / "core" / "schemas.py" +SERVER_SCHEMA_PATH = REPO_ROOT / "carbonserver" / "carbonserver" / "api" / "schemas.py" + + +# --------------------------------------------------------------------------- +# AST helpers — parse schema files without importing them +# --------------------------------------------------------------------------- + + +def _is_pydantic_required_field(value_node: ast.expr) -> bool: + """ + Return True if the node is a Pydantic Field(...) call whose first + argument is the Ellipsis sentinel, meaning the field has no default. + + Example: duration: int = Field(..., gt=0) → required + """ + if not isinstance(value_node, ast.Call): + return False + func = value_node.func + func_name = ( + func.id + if isinstance(func, ast.Name) + else (func.attr if isinstance(func, ast.Attribute) else "") + ) + if func_name != "Field": + return False + # Field(...) — first positional arg is the Ellipsis literal + return bool( + value_node.args + and isinstance(value_node.args[0], ast.Constant) + and value_node.args[0].value is ... + ) + + +def _parse_class_fields(filepath: Path, class_name: str) -> dict[str, dict]: + """ + Parse *filepath* with the ``ast`` module and return a dict of + annotated fields declared directly on *class_name*: + + { + "field_name": { + "annotation": "Optional[bool]", # raw annotation string + "required": True | False, + }, + ... + } + + Works for both plain Python dataclasses and Pydantic BaseModel subclasses. + Un-annotated class-level assignments (e.g. ``model_config = …``) are + intentionally ignored. + """ + source = filepath.read_text(encoding="utf-8") + tree = ast.parse(source, filename=str(filepath)) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == class_name: + fields: dict[str, dict] = {} + for item in node.body: + if not ( + isinstance(item, ast.AnnAssign) + and isinstance(item.target, ast.Name) + ): + continue + + field_name = item.target.id + annotation = ast.unparse(item.annotation) + + # A field is *required* when: + # 1. It has no right-hand side at all (pure annotation) + # 2. Its right-hand side is Pydantic's Field(...) sentinel + if item.value is None or _is_pydantic_required_field(item.value): + required = True + else: + required = False + + fields[field_name] = {"annotation": annotation, "required": required} + return fields + + return {} # class not found — tests that call this will get empty dicts + + +# --------------------------------------------------------------------------- +# Type-compatibility helpers +# --------------------------------------------------------------------------- + + +def _unwrap_optional(annotation: str) -> str: + """ + Strip a single ``Optional[X]`` wrapper to return the inner type ``X``. + + >>> _unwrap_optional("Optional[bool]") + 'bool' + >>> _unwrap_optional("bool") + 'bool' + """ + if annotation.startswith("Optional[") and annotation.endswith("]"): + return annotation[len("Optional[") : -1] + return annotation + + +# Pydantic coerces these client-side types to the server-side types at +# validation time, so they are considered wire-compatible. +# Key = client annotation (unwrapped) +# Value = set of acceptable server annotations (unwrapped) +_COMPATIBLE_CORE_TYPES: dict[str, set[str]] = { + # The client uses plain `str` for UUIDs and datetime strings; + # Pydantic on the server will parse those correctly. + "str": {"str", "UUID", "datetime"}, + "UUID": {"UUID", "str"}, + "datetime": {"datetime", "str"}, + # Scalar types must match exactly. + "bool": {"bool"}, + "int": {"int"}, + "float": {"float"}, +} + + +def _types_compatible(client_annotation: str, server_annotation: str) -> bool: + """ + Return True when *client_annotation* is safe to send to an endpoint + that expects *server_annotation*. + + Optional wrappers are stripped before comparison so that, for example, + ``bool`` (client) and ``Optional[bool]`` (server) are treated as + compatible — the server simply allows None in addition to a bool value. + + A ``bool`` vs ``str`` mismatch (the bug fixed in #1189) would return + False and cause the test to fail. + """ + if client_annotation == server_annotation: + return True + + client_core = _unwrap_optional(client_annotation) + server_core = _unwrap_optional(server_annotation) + + if client_core == server_core: + return True + + return server_core in _COMPATIBLE_CORE_TYPES.get(client_core, set()) + + +# --------------------------------------------------------------------------- +# Schema pairs under test +# (label, client class name, server class name) +# --------------------------------------------------------------------------- +SCHEMA_PAIRS = [ + ("EmissionBase", "EmissionBase", "EmissionBase"), + ("RunBase", "RunBase", "RunBase"), + ("ExperimentBase", "ExperimentBase", "ExperimentBase"), + ("ProjectBase", "ProjectBase", "ProjectBase"), + ("OrganizationBase", "OrganizationBase", "OrganizationBase"), +] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_client_fields_exist_in_server(label: str, client_cls: str, server_cls: str): + """ + Every field declared in the client schema must also exist in the server + schema. If the server drops a field the client sends, the payload will + be silently ignored or rejected. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + missing = set(client_fields) - set(server_fields) + assert not missing, ( + f"[{label}] Fields present in client but missing from server schema: {missing}\n" + f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n" + f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_required_server_fields_exist_in_client( + label: str, client_cls: str, server_cls: str +): + """ + Every *required* server field (one without a default value) must also + appear in the client schema. If the client never sends a required field, + every API call for that resource will fail validation. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + required_server_fields = { + name for name, meta in server_fields.items() if meta["required"] + } + missing = required_server_fields - set(client_fields) + assert not missing, ( + f"[{label}] Required server fields missing from client schema: {missing}\n" + f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n" + f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) + + +@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS) +def test_shared_field_types_are_compatible( + label: str, client_cls: str, server_cls: str +): + """ + For every field that appears in *both* schemas, the client-side type must + be wire-compatible with the server-side type. + + This test would have caught the ``on_cloud: str`` (server) vs + ``on_cloud: bool`` (client) mismatch that was fixed in PR #1189. + """ + client_fields = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls) + server_fields = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls) + + assert client_fields, f"[{label}] Could not parse client class '{client_cls}'" + assert server_fields, f"[{label}] Could not parse server class '{server_cls}'" + + shared = set(client_fields) & set(server_fields) + mismatches: list[str] = [] + + for field in sorted(shared): + c_type = client_fields[field]["annotation"] + s_type = server_fields[field]["annotation"] + if not _types_compatible(c_type, s_type): + mismatches.append(f" {field}: client={c_type!r} server={s_type!r}") + + assert not mismatches, ( + f"[{label}] Incompatible types between client and server schemas:\n" + + "\n".join(mismatches) + + f"\n client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}" + + f"\n server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}" + ) \ No newline at end of file