From 978011b502001201805507e11f2182cfd06a50ef Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Mon, 24 Nov 2025 10:12:31 -0800 Subject: [PATCH 1/5] integration test for jumpstart with mig profile --- .../test_cli_jumpstart_inference_with_mig.py | 120 ++++++++++++++++++ .../test_sdk_jumpstart_inference_with_mig.py | 120 ++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py create mode 100644 test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py diff --git a/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py new file mode 100644 index 00000000..cadb9582 --- /dev/null +++ b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py @@ -0,0 +1,120 @@ +import time +import pytest +import boto3 +from click.testing import CliRunner +from sagemaker.hyperpod.cli.commands.inference import ( + js_create, custom_invoke, js_list, js_describe, js_delete, js_get_operator_logs, js_list_pods +) +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from test.integration_tests.utils import get_time_str + +# --------- Test Configuration --------- +NAMESPACE = "integration" +VERSION = "1.1" +REGION = "us-east-2" +TIMEOUT_MINUTES = 20 +POLL_INTERVAL_SECONDS = 30 + +@pytest.fixture(scope="module") +def runner(): + return CliRunner() + +@pytest.fixture(scope="module") +def js_endpoint_name(): + return "js-mig-cli-integration-" + get_time_str() + +@pytest.fixture(scope="module") +def sagemaker_client(): + return boto3.client("sagemaker", region_name=REGION) + +# --------- JumpStart Endpoint Tests --------- +@pytest.mark.dependency(name="create") +def test_js_create(runner, js_endpoint_name): + result = runner.invoke(js_create, [ + "--namespace", NAMESPACE, + "--version", VERSION, + "--model-id", "deepseek-llm-r1-distill-qwen-1-5b", + "--instance-type", "ml.p4d.24xlarge", + "--endpoint-name", js_endpoint_name, + "--accelerator-partition-type", "mig-7g.40gb", + "--accelerator-partition-validation", "true", + ]) + assert result.exit_code == 0, result.output + + +@pytest.mark.dependency(depends=["create"]) +def test_js_list(runner, js_endpoint_name): + result = runner.invoke(js_list, ["--namespace", NAMESPACE]) + assert result.exit_code == 0 + assert js_endpoint_name in result.output + + +@pytest.mark.dependency(name="describe", depends=["create"]) +def test_js_describe(runner, js_endpoint_name): + result = runner.invoke(js_describe, [ + "--name", js_endpoint_name, + "--namespace", NAMESPACE, + "--full" + ]) + assert result.exit_code == 0 + assert js_endpoint_name in result.output + + +@pytest.mark.dependency(depends=["create", "describe"]) +def test_wait_until_inservice(js_endpoint_name): + """Poll SDK until specific JumpStart endpoint reaches DeploymentComplete""" + print(f"[INFO] Waiting for JumpStart endpoint '{js_endpoint_name}' to be DeploymentComplete...") + deadline = time.time() + (TIMEOUT_MINUTES * 60) + poll_count = 0 + + while time.time() < deadline: + poll_count += 1 + print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...") + + try: + ep = HPJumpStartEndpoint.get(name=js_endpoint_name, namespace=NAMESPACE) + state = ep.status.endpoints.sagemaker.state + print(f"[DEBUG] Current state: {state}") + if state == "CreationCompleted": + print("[INFO] Endpoint is in CreationCompleted state.") + return + + deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState + if deployment_state == "DeploymentFailed": + pytest.fail("Endpoint deployment failed.") + + except Exception as e: + print(f"[ERROR] Exception during polling: {e}") + + time.sleep(POLL_INTERVAL_SECONDS) + + pytest.fail("[ERROR] Timed out waiting for endpoint to be DeploymentComplete") + + +@pytest.mark.dependency(depends=["create"]) +def test_custom_invoke(runner, js_endpoint_name): + result = runner.invoke(custom_invoke, [ + "--endpoint-name", js_endpoint_name, + "--body", '{"inputs": "What is the capital of USA?"}' + ]) + assert result.exit_code == 0 + assert "error" not in result.output.lower() + + +def test_js_get_operator_logs(runner): + result = runner.invoke(js_get_operator_logs, ["--since-hours", "1"]) + assert result.exit_code == 0 + + +def test_js_list_pods(runner): + result = runner.invoke(js_list_pods, ["--namespace", NAMESPACE]) + assert result.exit_code == 0 + + +@pytest.mark.dependency(depends=["create"]) +def test_js_delete(runner, js_endpoint_name): + result = runner.invoke(js_delete, [ + "--name", js_endpoint_name, + "--namespace", NAMESPACE + ]) + assert result.exit_code == 0 diff --git a/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py new file mode 100644 index 00000000..f8645352 --- /dev/null +++ b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py @@ -0,0 +1,120 @@ +import time +import pytest +import boto3 +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import ( + Model, Server, SageMakerEndpoint, Validations +) +import sagemaker_core.main.code_injection.codec as codec +from test.integration_tests.utils import get_time_str +from sagemaker.hyperpod.common.config.metadata import Metadata + +# --------- Config --------- +NAMESPACE = "integration" +REGION = "us-east-2" +ENDPOINT_NAME = "js-mig-sdk-integration-" + get_time_str() + +INSTANCE_TYPE = "ml.p4d.24xlarge" +MODEL_ID = "deepseek-llm-r1-distill-qwen-1-5b" + +TIMEOUT_MINUTES = 20 +POLL_INTERVAL_SECONDS = 30 + +@pytest.fixture(scope="module") +def sagemaker_client(): + return boto3.client("sagemaker", region_name=REGION) + +@pytest.fixture(scope="module") +def endpoint_obj(): + model = Model(model_id=MODEL_ID) + validations = Validations(accelerator_partition_validation=True) + server = Server( + instance_type=INSTANCE_TYPE, + accelerator_partition_type="mig-7g.40gb", + validations=validations + ) + sm_endpoint = SageMakerEndpoint(name=ENDPOINT_NAME) + metadata = Metadata(name=ENDPOINT_NAME, namespace=NAMESPACE) + + return HPJumpStartEndpoint(metadata=metadata, model=model, server=server, sage_maker_endpoint=sm_endpoint) + +@pytest.mark.dependency(name="create") +def test_create_endpoint(endpoint_obj): + endpoint_obj.create() + assert endpoint_obj.metadata.name == ENDPOINT_NAME + +@pytest.mark.dependency(depends=["create"]) +def test_list_endpoint(): + endpoints = HPJumpStartEndpoint.list(namespace=NAMESPACE) + names = [ep.metadata.name for ep in endpoints] + assert ENDPOINT_NAME in names + +@pytest.mark.dependency(name="describe", depends=["create"]) +def test_get_endpoint(): + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + assert ep.metadata.name == ENDPOINT_NAME + assert ep.model.modelId == MODEL_ID + +@pytest.mark.dependency(depends=["create", "describe"]) +def test_wait_until_inservice(): + """Poll SDK until specific JumpStart endpoint reaches DeploymentComplete""" + print(f"[INFO] Waiting for JumpStart endpoint '{ENDPOINT_NAME}' to be DeploymentComplete...") + deadline = time.time() + (TIMEOUT_MINUTES * 60) + poll_count = 0 + + while time.time() < deadline: + poll_count += 1 + print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...") + + try: + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + state = ep.status.endpoints.sagemaker.state + print(f"[DEBUG] Current state: {state}") + if state == "CreationCompleted": + print("[INFO] Endpoint is in CreationCompleted state.") + return + + deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState + if deployment_state == "DeploymentFailed": + pytest.fail("Endpoint deployment failed.") + + except Exception as e: + print(f"[ERROR] Exception during polling: {e}") + + time.sleep(POLL_INTERVAL_SECONDS) + + pytest.fail("[ERROR] Timed out waiting for endpoint to be DeploymentComplete") + + +@pytest.mark.dependency(depends=["create"]) +def test_invoke_endpoint(monkeypatch): + original_transform = codec.transform # Save original + + def mock_transform(data, shape, object_instance=None): + if "Body" in data: + return {"body": data["Body"].read().decode("utf-8")} + return original_transform(data, shape, object_instance) # Call original + + monkeypatch.setattr("sagemaker_core.main.resources.transform", mock_transform) + + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + data = '{"inputs":"What is the capital of USA?"}' + response = ep.invoke(body=data) + + assert "error" not in response.body.lower() + + +def test_get_operator_logs(): + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + logs = ep.get_operator_logs(since_hours=1) + assert logs + +def test_list_pods(): + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + pods = ep.list_pods(NAMESPACE) + assert pods + +@pytest.mark.dependency(depends=["create"]) +def test_delete_endpoint(): + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + ep.delete() From edbf72cddc8fa7ef9f9734fbe7b93746123fe215 Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Tue, 25 Nov 2025 10:18:53 -0800 Subject: [PATCH 2/5] template fix for mig with jumpstart --- .../hyperpod_jumpstart_inference_template/v1_1/template.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py index 580cf514..cf6db67c 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py @@ -1,8 +1,8 @@ TEMPLATE_CONTENT = """ -apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1 +apiVersion: inference.sagemaker.aws.amazon.com/v1 kind: JumpStartModel metadata: - name: {{ model_id }} + name: {{ metadata_name or endpoint_name }} namespace: {{ namespace or "default" }} spec: model: @@ -18,4 +18,6 @@ {% if accelerator_partition_validation is not none %}validations: {% if accelerator_partition_validation is not none %} acceleratorPartitionValidation: {{ accelerator_partition_validation }}{% endif %} {% endif %} + tlsConfig: + tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }} """ \ No newline at end of file From bfd951427408427411494fbb0637f856d7428840 Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Tue, 25 Nov 2025 11:01:10 -0800 Subject: [PATCH 3/5] skipped mig tests until instances setup finished --- .../inference/cli/test_cli_jumpstart_inference_with_mig.py | 5 +++-- .../inference/sdk/test_sdk_jumpstart_inference_with_mig.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py index cadb9582..2b6f9bbc 100644 --- a/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py +++ b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py @@ -28,6 +28,7 @@ def sagemaker_client(): return boto3.client("sagemaker", region_name=REGION) # --------- JumpStart Endpoint Tests --------- +@pytest.mark.skip(reason="Temporarily disabled") @pytest.mark.dependency(name="create") def test_js_create(runner, js_endpoint_name): result = runner.invoke(js_create, [ @@ -100,12 +101,12 @@ def test_custom_invoke(runner, js_endpoint_name): assert result.exit_code == 0 assert "error" not in result.output.lower() - +@pytest.mark.skip(reason="Temporarily disabled") def test_js_get_operator_logs(runner): result = runner.invoke(js_get_operator_logs, ["--since-hours", "1"]) assert result.exit_code == 0 - +@pytest.mark.skip(reason="Temporarily disabled") def test_js_list_pods(runner): result = runner.invoke(js_list_pods, ["--namespace", NAMESPACE]) assert result.exit_code == 0 diff --git a/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py index f8645352..48b0abc5 100644 --- a/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py +++ b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py @@ -38,6 +38,7 @@ def endpoint_obj(): return HPJumpStartEndpoint(metadata=metadata, model=model, server=server, sage_maker_endpoint=sm_endpoint) +@pytest.mark.skip(reason="Temporarily disabled") @pytest.mark.dependency(name="create") def test_create_endpoint(endpoint_obj): endpoint_obj.create() @@ -104,11 +105,13 @@ def mock_transform(data, shape, object_instance=None): assert "error" not in response.body.lower() +@pytest.mark.skip(reason="Temporarily disabled") def test_get_operator_logs(): ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) logs = ep.get_operator_logs(since_hours=1) assert logs +@pytest.mark.skip(reason="Temporarily disabled") def test_list_pods(): ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) pods = ep.list_pods(NAMESPACE) From e2054ce508518c723f38a184acc209768448022b Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Mon, 1 Dec 2025 10:06:02 -0800 Subject: [PATCH 4/5] enable the mig integration tests --- .../inference/cli/test_cli_jumpstart_inference_with_mig.py | 3 --- .../inference/sdk/test_sdk_jumpstart_inference_with_mig.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py index 2b6f9bbc..bd385290 100644 --- a/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py +++ b/test/integration_tests/inference/cli/test_cli_jumpstart_inference_with_mig.py @@ -28,7 +28,6 @@ def sagemaker_client(): return boto3.client("sagemaker", region_name=REGION) # --------- JumpStart Endpoint Tests --------- -@pytest.mark.skip(reason="Temporarily disabled") @pytest.mark.dependency(name="create") def test_js_create(runner, js_endpoint_name): result = runner.invoke(js_create, [ @@ -101,12 +100,10 @@ def test_custom_invoke(runner, js_endpoint_name): assert result.exit_code == 0 assert "error" not in result.output.lower() -@pytest.mark.skip(reason="Temporarily disabled") def test_js_get_operator_logs(runner): result = runner.invoke(js_get_operator_logs, ["--since-hours", "1"]) assert result.exit_code == 0 -@pytest.mark.skip(reason="Temporarily disabled") def test_js_list_pods(runner): result = runner.invoke(js_list_pods, ["--namespace", NAMESPACE]) assert result.exit_code == 0 diff --git a/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py index 48b0abc5..f8645352 100644 --- a/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py +++ b/test/integration_tests/inference/sdk/test_sdk_jumpstart_inference_with_mig.py @@ -38,7 +38,6 @@ def endpoint_obj(): return HPJumpStartEndpoint(metadata=metadata, model=model, server=server, sage_maker_endpoint=sm_endpoint) -@pytest.mark.skip(reason="Temporarily disabled") @pytest.mark.dependency(name="create") def test_create_endpoint(endpoint_obj): endpoint_obj.create() @@ -105,13 +104,11 @@ def mock_transform(data, shape, object_instance=None): assert "error" not in response.body.lower() -@pytest.mark.skip(reason="Temporarily disabled") def test_get_operator_logs(): ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) logs = ep.get_operator_logs(since_hours=1) assert logs -@pytest.mark.skip(reason="Temporarily disabled") def test_list_pods(): ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) pods = ep.list_pods(NAMESPACE) From 4e789cab358f8fffa3fb3f4b660fb64c07246120 Mon Sep 17 00:00:00 2001 From: Chad Chiang Date: Mon, 9 Feb 2026 14:41:48 -0800 Subject: [PATCH 5/5] add crd format check for inference --- .../inference/test_crd_validation.py | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 test/unit_tests/inference/test_crd_validation.py diff --git a/test/unit_tests/inference/test_crd_validation.py b/test/unit_tests/inference/test_crd_validation.py new file mode 100644 index 00000000..98459367 --- /dev/null +++ b/test/unit_tests/inference/test_crd_validation.py @@ -0,0 +1,163 @@ +""" +Simplified unit tests for CRD format validation. + +This module contains essential tests to validate the basic format and structure +of the CRD YAML files used by the inference operator, focusing on core +Kubernetes CRD requirements. +""" + +import unittest +import yaml +from pathlib import Path + + +class TestCRDFormat(unittest.TestCase): + """Test class for validating essential CRD format requirements.""" + + def setUp(self): + """Set up test class with file paths.""" + self.base_path = Path(__file__).parent.parent.parent.parent + self.crd_path = self.base_path / "helm_chart" / "HyperPodHelmChart" / "charts" / "inference-operator" / "config" / "crd" + + self.crd_files = [ + self.crd_path / "inference.sagemaker.aws.amazon.com_inferenceendpointconfigs.yaml", + self.crd_path / "inference.sagemaker.aws.amazon.com_jumpstartmodels.yaml", + self.crd_path / "inference.sagemaker.aws.amazon.com_sagemakerendpointregistrations.yaml" + ] + + def test_crd_files_exist_and_valid_yaml(self): + """Test that all CRD files exist and have valid YAML syntax.""" + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + # Check file exists + self.assertTrue(file_path.exists(), f"CRD file does not exist: {file_path}") + + # Check for tab characters (not allowed in YAML) + with open(file_path, 'r', encoding='utf-8') as f: + content_text = f.read() + if '\t' in content_text: + self.fail(f"File {file_path.name} contains tab characters. YAML should use spaces for indentation.") + + # Check valid YAML + with open(file_path, 'r', encoding='utf-8') as f: + try: + content = yaml.safe_load(f) + self.assertIsNotNone(content, f"YAML content is empty in {file_path.name}") + except yaml.YAMLError as e: + self.fail(f"Invalid YAML syntax in {file_path.name}: {e}") + + def test_required_crd_structure(self): + """Test that all CRD files have the required Kubernetes CRD structure.""" + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + with open(file_path, 'r', encoding='utf-8') as f: + content = yaml.safe_load(f) + + # Check required top-level fields + required_fields = ['apiVersion', 'kind', 'metadata', 'spec'] + for field in required_fields: + self.assertIn(field, content, f"Missing required field '{field}' in {file_path.name}") + + # Verify this is a CustomResourceDefinition + self.assertEqual(content['apiVersion'], "apiextensions.k8s.io/v1", + f"Expected apiVersion 'apiextensions.k8s.io/v1' in {file_path.name}") + self.assertEqual(content['kind'], "CustomResourceDefinition", + f"Expected kind 'CustomResourceDefinition' in {file_path.name}") + + def test_crd_spec_structure(self): + """Test that CRD spec has required fields and basic structure.""" + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + with open(file_path, 'r', encoding='utf-8') as f: + content = yaml.safe_load(f) + + spec = content.get('spec', {}) + + # Check required spec fields + required_spec_fields = ['group', 'names', 'scope', 'versions'] + for field in required_spec_fields: + self.assertIn(field, spec, f"Missing required spec field '{field}' in {file_path.name}") + + # Validate spec.group + self.assertEqual(spec['group'], "inference.sagemaker.aws.amazon.com", + f"Expected group 'inference.sagemaker.aws.amazon.com' in {file_path.name}") + + # Validate spec.scope + self.assertEqual(spec['scope'], "Namespaced", + f"Expected scope 'Namespaced' in {file_path.name}") + + def test_crd_names_structure(self): + """Test that CRD names section has required fields.""" + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + with open(file_path, 'r', encoding='utf-8') as f: + content = yaml.safe_load(f) + + names = content.get('spec', {}).get('names', {}) + + # Check required names fields + required_names_fields = ['kind', 'listKind', 'plural', 'singular'] + for field in required_names_fields: + self.assertIn(field, names, f"Missing required names field '{field}' in {file_path.name}") + self.assertTrue(names[field], f"Empty value for names.{field} in {file_path.name}") + + def test_crd_versions_structure(self): + """Test that CRD versions are properly structured with required fields.""" + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + with open(file_path, 'r', encoding='utf-8') as f: + content = yaml.safe_load(f) + + versions = content.get('spec', {}).get('versions', []) + + # Validate versions is a non-empty list + self.assertIsInstance(versions, list, f"spec.versions should be a list in {file_path.name}") + self.assertGreater(len(versions), 0, f"spec.versions should not be empty in {file_path.name}") + + # Check each version has required fields + for i, version in enumerate(versions): + required_version_fields = ['name', 'served', 'storage', 'schema'] + for field in required_version_fields: + self.assertIn(field, version, + f"Missing required field '{field}' in version {i} of {file_path.name}") + + # Validate schema has openAPIV3Schema + schema = version.get('schema', {}) + self.assertIn('openAPIV3Schema', schema, + f"Missing 'openAPIV3Schema' in version {i} schema of {file_path.name}") + + openapi_schema = schema.get('openAPIV3Schema', {}) + self.assertIn('type', openapi_schema, + f"Missing 'type' in openAPIV3Schema for version {i} of {file_path.name}") + self.assertEqual(openapi_schema['type'], 'object', + f"Expected 'type: object' in openAPIV3Schema for version {i} of {file_path.name}") + + def test_metadata_name_format(self): + """Test that metadata.name follows the expected CRD naming convention.""" + expected_names = { + 'inferenceendpointconfigs': 'inferenceendpointconfigs.inference.sagemaker.aws.amazon.com', + 'jumpstartmodels': 'jumpstartmodels.inference.sagemaker.aws.amazon.com', + 'sagemakerendpointregistrations': 'sagemakerendpointregistrations.inference.sagemaker.aws.amazon.com' + } + + for file_path in self.crd_files: + with self.subTest(file=file_path.name): + with open(file_path, 'r', encoding='utf-8') as f: + content = yaml.safe_load(f) + + name = content.get('metadata', {}).get('name', '') + + # Find expected name based on filename + expected_name = None + for key, value in expected_names.items(): + if key in file_path.name: + expected_name = value + break + + self.assertIsNotNone(expected_name, f"Could not determine expected name for {file_path.name}") + self.assertEqual(name, expected_name, + f"Expected metadata.name '{expected_name}' in {file_path.name}, got '{name}'") + + +if __name__ == '__main__': + unittest.main()