Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions tests/test_schema_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
"""
Test that client (codecarbon/core/schemas.py) and server
(carbonserver/carbonserver/api/schemas.py) schemas are compatible.

A mismatch between these schemas can cause silent data corruption or API errors.
This test was added to prevent regressions like the one fixed in PR #1189,
where `on_cloud` was typed as `str` on one side and `bool` on the other.

Related issue: https://github.com/mlco2/codecarbon/issues/1190
"""
import ast
from pathlib import Path

import pytest

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Repository root = the directory two levels above this test module.
REPO_ROOT = Path(__file__).parent.parent
# Both schema modules are addressed relative to that root.
CLIENT_SCHEMA_PATH = REPO_ROOT.joinpath("codecarbon", "core", "schemas.py")
SERVER_SCHEMA_PATH = REPO_ROOT.joinpath(
    "carbonserver", "carbonserver", "api", "schemas.py"
)


# ---------------------------------------------------------------------------
# AST helpers — parse schema files without importing them
# ---------------------------------------------------------------------------


def _is_pydantic_required_field(value_node: ast.expr) -> bool:
    """
    Return True if the node is a Pydantic ``Field`` call whose default is the
    Ellipsis sentinel, meaning the field has no default.

    The sentinel is recognised both positionally and via the ``default``
    keyword, and the callee may be a bare name or an attribute access:

        duration: int = Field(..., gt=0)            → required
        duration: int = Field(default=..., gt=0)    → required
        duration: int = pydantic.Field(...)         → required
        duration: int = Field(0, gt=0)              → not required

    Any expression that is not a call to something named ``Field`` returns
    False.

    NOTE(review): a ``Field`` call that passes no default at all, e.g.
    ``Field(gt=0)``, returns False here. Pydantic appears to treat that as
    required too — confirm whether it should be counted.
    """
    if not isinstance(value_node, ast.Call):
        return False

    func = value_node.func
    if isinstance(func, ast.Name):
        func_name = func.id
    elif isinstance(func, ast.Attribute):
        func_name = func.attr
    else:
        func_name = ""
    if func_name != "Field":
        return False

    # The default is Field's first positional parameter; when it is not given
    # positionally it may still be supplied as ``default=<value>``.
    if value_node.args:
        default_node = value_node.args[0]
    else:
        default_node = next(
            (kw.value for kw in value_node.keywords if kw.arg == "default"),
            None,
        )

    return isinstance(default_node, ast.Constant) and default_node.value is ...


def _parse_class_fields(filepath: Path, class_name: str) -> dict[str, dict]:
    """
    Collect the annotated attributes declared in the body of *class_name*.

    *filepath* is read as UTF-8 and parsed with ``ast`` instead of being
    imported. The first class definition named *class_name* yielded by
    ``ast.walk`` is used, and only statements directly in its body are
    inspected (base classes are not followed).

    The result maps each field name to its metadata, for example::

        {
            "field_name": {
                "annotation": "Optional[bool]",   # unparsed annotation
                "required": True | False,
            },
        }

    A field counts as required when it is a bare annotation with no assigned
    value, or when its value is Pydantic's ``Field(...)`` sentinel.
    Assignments without an annotation (e.g. ``model_config = …``) are
    skipped. An empty dict is returned when no class matches.
    """
    tree = ast.parse(filepath.read_text(encoding="utf-8"), filename=str(filepath))

    class_node = next(
        (
            candidate
            for candidate in ast.walk(tree)
            if isinstance(candidate, ast.ClassDef) and candidate.name == class_name
        ),
        None,
    )
    if class_node is None:
        # Class not found — callers assert on the empty result.
        return {}

    annotated_names = (
        stmt
        for stmt in class_node.body
        if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name)
    )
    return {
        stmt.target.id: {
            "annotation": ast.unparse(stmt.annotation),
            # No right-hand side at all, or an explicit Field(...) sentinel.
            "required": bool(
                stmt.value is None or _is_pydantic_required_field(stmt.value)
            ),
        }
        for stmt in annotated_names
    }


# ---------------------------------------------------------------------------
# Type-compatibility helpers
# ---------------------------------------------------------------------------


def _unwrap_optional(annotation: str) -> str:
    """
    Strip a single "or None" wrapper and return the inner type as a string.

    The annotation is parsed as an expression, so the wrapper is only removed
    when it really is the outermost construct. Recognised spellings:

    * ``Optional[X]`` and ``typing.Optional[X]``
    * ``Union[X, None]`` / ``typing.Union[X, None]`` with exactly one other member
    * ``X | None`` and ``None | X``

    Anything else — including text that cannot be parsed as an expression —
    is returned unchanged. When a wrapper is removed the inner type is
    rendered with ``ast.unparse``, so its spacing and quoting are normalised.

    >>> _unwrap_optional("Optional[bool]")
    'bool'
    >>> _unwrap_optional("bool | None")
    'bool'
    >>> _unwrap_optional("bool")
    'bool'
    """

    def is_none(node: ast.expr) -> bool:
        return isinstance(node, ast.Constant) and node.value is None

    try:
        node = ast.parse(annotation.strip(), mode="eval").body
    except (SyntaxError, ValueError):
        # Not a parseable expression: leave the annotation untouched.
        return annotation

    if isinstance(node, ast.Subscript):
        wrapper = ast.unparse(node.value)
        if wrapper in ("Optional", "typing.Optional"):
            return ast.unparse(node.slice)
        if wrapper in ("Union", "typing.Union") and isinstance(node.slice, ast.Tuple):
            members = node.slice.elts
            non_none = [member for member in members if not is_none(member)]
            # Only a two-sided "X or None" union collapses to a single type.
            if len(non_none) == 1 and len(members) > 1:
                return ast.unparse(non_none[0])

    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        if is_none(node.right):
            return ast.unparse(node.left)
        if is_none(node.left):
            return ast.unparse(node.right)

    return annotation


# Wire-compatibility table used after Optional wrappers have been removed.
#   key   = client annotation (unwrapped)
#   value = set of server annotations (unwrapped) accepted for that key
# The rationale is that Pydantic coerces these client-side types to the
# server-side types at validation time.
_COMPATIBLE_CORE_TYPES: dict[str, set[str]] = {
    # The client uses plain `str` for UUIDs and datetime strings;
    # Pydantic on the server will parse those correctly.
    "str": {"str", "UUID", "datetime"},
    "UUID": {"UUID", "str"},
    "datetime": {"datetime", "str"},
    # Scalar types must match exactly, so each one maps only to itself.
    **{scalar: {scalar} for scalar in ("bool", "int", "float")},
}


def _types_compatible(client_annotation: str, server_annotation: str) -> bool:
    """
    Decide whether a value annotated *client_annotation* may be sent to an
    endpoint whose schema declares *server_annotation*.

    Identical annotation strings are compatible outright. Otherwise a single
    Optional wrapper is removed from **both** sides and the inner types are
    compared: they match when equal or when the server type is listed for the
    client type in ``_COMPATIBLE_CORE_TYPES``. So ``bool`` (client) against
    ``Optional[bool]`` (server) passes, and — because both sides are
    unwrapped — so does the reverse.

    A ``bool`` vs ``str`` pairing (the bug fixed in #1189) yields False,
    which makes the calling test fail.
    """
    # Fast path: nothing to unwrap when the raw strings already agree.
    if client_annotation == server_annotation:
        return True

    client_core = _unwrap_optional(client_annotation)
    server_core = _unwrap_optional(server_annotation)

    return (
        client_core == server_core
        or server_core in _COMPATIBLE_CORE_TYPES.get(client_core, set())
    )


# ---------------------------------------------------------------------------
# Schema pairs under test
# (label, client class name, server class name)
# ---------------------------------------------------------------------------
# Every pair currently uses the same name for the label, the client class and
# the server class, so each (label, client, server) triple is generated from
# one name.
SCHEMA_PAIRS = [
    (schema_name, schema_name, schema_name)
    for schema_name in (
        "EmissionBase",
        "RunBase",
        "ExperimentBase",
        "ProjectBase",
        "OrganizationBase",
    )
]


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS)
def test_client_fields_exist_in_server(label: str, client_cls: str, server_cls: str):
    """
    The server schema must declare every field that the client schema
    declares.

    A field known only to the client has no counterpart on the server side,
    so (per the original rationale) the payload would be silently ignored or
    rejected.
    """
    parsed_client = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls)
    parsed_server = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls)

    # An empty mapping means the class was not found or declares no annotated
    # fields; fail here rather than let the comparison pass vacuously.
    assert parsed_client, f"[{label}] Could not parse client class '{client_cls}'"
    assert parsed_server, f"[{label}] Could not parse server class '{server_cls}'"

    client_names = set(parsed_client)
    server_names = set(parsed_server)
    missing = client_names - server_names
    assert not missing, (
        f"[{label}] Fields present in client but missing from server schema: {missing}\n"
        f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n"
        f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}"
    )


@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS)
def test_required_server_fields_exist_in_client(
    label: str, client_cls: str, server_cls: str
):
    """
    The client schema must declare every server field that is flagged as
    required (i.e. parsed as having no default value).

    If the client never sends a required field, every API call for that
    resource would fail validation.
    """
    parsed_client = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls)
    parsed_server = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls)

    # Guard against a vacuous pass when a class could not be located.
    assert parsed_client, f"[{label}] Could not parse client class '{client_cls}'"
    assert parsed_server, f"[{label}] Could not parse server class '{server_cls}'"

    # Names of the server fields whose metadata marks them as required.
    server_required = set()
    for name, meta in parsed_server.items():
        if meta["required"]:
            server_required.add(name)

    missing = server_required - set(parsed_client)
    assert not missing, (
        f"[{label}] Required server fields missing from client schema: {missing}\n"
        f" client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}\n"
        f" server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}"
    )


@pytest.mark.parametrize("label,client_cls,server_cls", SCHEMA_PAIRS)
def test_shared_field_types_are_compatible(
    label: str, client_cls: str, server_cls: str
):
    """
    Each field declared by *both* schemas must have a client-side type that
    ``_types_compatible`` accepts for the server-side type.

    This is the check aimed at the ``on_cloud: str`` (server) vs
    ``on_cloud: bool`` (client) mismatch that was fixed in PR #1189.
    """
    parsed_client = _parse_class_fields(CLIENT_SCHEMA_PATH, client_cls)
    parsed_server = _parse_class_fields(SERVER_SCHEMA_PATH, server_cls)

    # Guard against a vacuous pass when a class could not be located.
    assert parsed_client, f"[{label}] Could not parse client class '{client_cls}'"
    assert parsed_server, f"[{label}] Could not parse server class '{server_cls}'"

    # Walk the common field names in sorted order so the report is stable.
    annotation_pairs = (
        (
            field,
            parsed_client[field]["annotation"],
            parsed_server[field]["annotation"],
        )
        for field in sorted(parsed_client.keys() & parsed_server.keys())
    )
    mismatches: list[str] = [
        f" {field}: client={c_type!r} server={s_type!r}"
        for field, c_type, s_type in annotation_pairs
        if not _types_compatible(c_type, s_type)
    ]

    assert not mismatches, (
        f"[{label}] Incompatible types between client and server schemas:\n"
        + "\n".join(mismatches)
        + f"\n client : {CLIENT_SCHEMA_PATH.relative_to(REPO_ROOT)}"
        + f"\n server : {SERVER_SCHEMA_PATH.relative_to(REPO_ROOT)}"
    )