Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 31 additions & 38 deletions infrahub_sdk/schema/repository.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel, ConfigDict, Field, field_validator

Expand All @@ -20,8 +20,6 @@

InfrahubNodeTypes = InfrahubNode | InfrahubNodeSync

ResourceClass = TypeVar("ResourceClass")


class InfrahubRepositoryConfigElement(BaseModel):
"""Class to regroup all elements of the Infrahub configuration for a repository for typing purpose."""
Expand Down Expand Up @@ -166,18 +164,6 @@ class InfrahubMenuConfig(InfrahubRepositoryConfigElement):
file_path: Path = Field(..., description="The file within the repository containing menu data.")


RESOURCE_MAP: dict[Any, str] = {
InfrahubJinja2TransformConfig: "jinja2_transforms",
InfrahubCheckDefinitionConfig: "check_definitions",
InfrahubRepositoryArtifactDefinitionConfig: "artifact_definitions",
InfrahubPythonTransformConfig: "python_transforms",
InfrahubGeneratorDefinitionConfig: "generator_definitions",
InfrahubRepositoryGraphQLConfig: "queries",
InfrahubObjectConfig: "objects",
InfrahubMenuConfig: "menus",
}


class InfrahubRepositoryConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
check_definitions: list[InfrahubCheckDefinitionConfig] = Field(
Expand Down Expand Up @@ -215,49 +201,56 @@ def unique_items(cls, v: list[Any]) -> list[Any]:
raise ValueError(f"Found multiples element with the same names: {dups}")
return v

def _has_resource(self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name") -> bool:
return any(getattr(item, resource_field) == resource_id for item in getattr(self, RESOURCE_MAP[resource_type]))

def _get_resource(
self, resource_id: str, resource_type: type[ResourceClass], resource_field: str = "name"
) -> ResourceClass:
for item in getattr(self, RESOURCE_MAP[resource_type]):
if getattr(item, resource_field) == resource_id:
return item
raise ResourceNotDefinedError(f"Unable to find {resource_id!r} in {RESOURCE_MAP[resource_type]!r}")

def has_jinja2_transform(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig)
return any(item.name == name for item in self.jinja2_transforms)

def get_jinja2_transform(self, name: str) -> InfrahubJinja2TransformConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubJinja2TransformConfig)
for item in self.jinja2_transforms:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'jinja2_transforms'")

def has_check_definition(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig)
return any(item.name == name for item in self.check_definitions)

def get_check_definition(self, name: str) -> InfrahubCheckDefinitionConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubCheckDefinitionConfig)
for item in self.check_definitions:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'check_definitions'")

def has_artifact_definition(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig)
return any(item.name == name for item in self.artifact_definitions)

def get_artifact_definition(self, name: str) -> InfrahubRepositoryArtifactDefinitionConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryArtifactDefinitionConfig)
for item in self.artifact_definitions:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'artifact_definitions'")

def has_generator_definition(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig)
return any(item.name == name for item in self.generator_definitions)

def get_generator_definition(self, name: str) -> InfrahubGeneratorDefinitionConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubGeneratorDefinitionConfig)
for item in self.generator_definitions:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'generator_definitions'")

def has_python_transform(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig)
return any(item.name == name for item in self.python_transforms)

def get_python_transform(self, name: str) -> InfrahubPythonTransformConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubPythonTransformConfig)
for item in self.python_transforms:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'python_transforms'")

def has_query(self, name: str) -> bool:
return self._has_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig)
return any(item.name == name for item in self.queries)

def get_query(self, name: str) -> InfrahubRepositoryGraphQLConfig:
return self._get_resource(resource_id=name, resource_type=InfrahubRepositoryGraphQLConfig)
for item in self.queries:
if item.name == name:
return item
raise ResourceNotDefinedError(f"Unable to find {name!r} in 'queries'")
239 changes: 239 additions & 0 deletions tests/unit/sdk/test_schema_repository.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import pytest

from infrahub_sdk.exceptions import ResourceNotDefinedError
from infrahub_sdk.schema.repository import InfrahubRepositoryConfig


@pytest.fixture
def repo_config() -> InfrahubRepositoryConfig:
return InfrahubRepositoryConfig.model_validate(
{
"jinja2_transforms": [{"name": "j2_transform", "query": "q1", "template_path": "templates/foo.j2"}],
"check_definitions": [{"name": "my_check", "file_path": "check.py"}],
"artifact_definitions": [
{
"name": "my_artifact",
"parameters": {},
"content_type": "text/plain",
"targets": "group",
"transformation": "t",
}
],
"generator_definitions": [{"name": "my_generator", "file_path": "g.py", "query": "q", "targets": "grp"}],
"python_transforms": [{"name": "my_python_transform", "file_path": "pt.py"}],
"queries": [{"name": "my_query", "file_path": "q.gql"}],
}
)


# --- Duplicate name validation ---


def test_duplicate_jinja2_transforms_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"jinja2_transforms": [
{"name": "dup", "query": "q", "template_path": "t.j2"},
{"name": "dup", "query": "q2", "template_path": "t2.j2"},
]
}
)


def test_duplicate_check_definitions_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"check_definitions": [
{"name": "dup", "file_path": "check.py"},
{"name": "dup", "file_path": "check2.py"},
]
}
)


def test_duplicate_artifact_definitions_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"artifact_definitions": [
{
"name": "dup",
"parameters": {},
"content_type": "text/plain",
"targets": "g",
"transformation": "t",
},
{
"name": "dup",
"parameters": {},
"content_type": "text/plain",
"targets": "g",
"transformation": "t",
},
]
}
)


def test_duplicate_python_transforms_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"python_transforms": [
{"name": "dup", "file_path": "t.py"},
{"name": "dup", "file_path": "t2.py"},
]
}
)


def test_duplicate_generator_definitions_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"generator_definitions": [
{"name": "dup", "file_path": "g.py", "query": "q", "targets": "grp"},
{"name": "dup", "file_path": "g2.py", "query": "q", "targets": "grp"},
]
}
)


def test_duplicate_queries_raises() -> None:
with pytest.raises(ValueError, match="same names"):
InfrahubRepositoryConfig.model_validate(
{
"queries": [
{"name": "dup", "file_path": "q.gql"},
{"name": "dup", "file_path": "q2.gql"},
]
}
)


# --- has_jinja2_transform / get_jinja2_transform ---


def test_has_jinja2_transform_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_jinja2_transform("j2_transform") is True


def test_has_jinja2_transform_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_jinja2_transform("missing") is False


def test_get_jinja2_transform_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_jinja2_transform("j2_transform")
assert result.name == "j2_transform"


def test_get_jinja2_transform_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_jinja2_transform("missing")


# --- has_check_definition / get_check_definition ---


def test_has_check_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_check_definition("my_check") is True


def test_has_check_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_check_definition("missing") is False


def test_get_check_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_check_definition("my_check")
assert result.name == "my_check"


def test_get_check_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_check_definition("missing")


# --- has_artifact_definition / get_artifact_definition ---


def test_has_artifact_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_artifact_definition("my_artifact") is True


def test_has_artifact_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_artifact_definition("missing") is False


def test_get_artifact_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_artifact_definition("my_artifact")
assert result.name == "my_artifact"


def test_get_artifact_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_artifact_definition("missing")


# --- has_generator_definition / get_generator_definition ---


def test_has_generator_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_generator_definition("my_generator") is True


def test_has_generator_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_generator_definition("missing") is False


def test_get_generator_definition_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_generator_definition("my_generator")
assert result.name == "my_generator"


def test_get_generator_definition_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_generator_definition("missing")


# --- has_python_transform / get_python_transform ---


def test_has_python_transform_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_python_transform("my_python_transform") is True


def test_has_python_transform_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_python_transform("missing") is False


def test_get_python_transform_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_python_transform("my_python_transform")
assert result.name == "my_python_transform"


def test_get_python_transform_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_python_transform("missing")


# --- has_query / get_query ---


def test_has_query_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_query("my_query") is True


def test_has_query_not_found(repo_config: InfrahubRepositoryConfig) -> None:
assert repo_config.has_query("missing") is False


def test_get_query_found(repo_config: InfrahubRepositoryConfig) -> None:
result = repo_config.get_query("my_query")
assert result.name == "my_query"


def test_get_query_not_found(repo_config: InfrahubRepositoryConfig) -> None:
with pytest.raises(ResourceNotDefinedError):
repo_config.get_query("missing")
Loading