diff --git a/src/runpod_flash/__init__.py b/src/runpod_flash/__init__.py index 723d385c..4a6e5b3c 100644 --- a/src/runpod_flash/__init__.py +++ b/src/runpod_flash/__init__.py @@ -1,9 +1,11 @@ __version__ = "1.9.0" # x-release-please-version # Load .env vars from file before everything else -from dotenv import load_dotenv +# usecwd=True walks up from CWD (user's project) instead of from the +# package source file location, which matters for editable installs. +from dotenv import find_dotenv, load_dotenv -load_dotenv() +load_dotenv(find_dotenv(usecwd=True)) from .logger import setup_logging # noqa: E402 diff --git a/src/runpod_flash/core/resources/environment.py b/src/runpod_flash/core/resources/environment.py index e0088972..4897e75d 100644 --- a/src/runpod_flash/core/resources/environment.py +++ b/src/runpod_flash/core/resources/environment.py @@ -1,5 +1,5 @@ from typing import Dict, Optional -from dotenv import dotenv_values +from dotenv import dotenv_values, find_dotenv class EnvironmentVars: @@ -16,7 +16,9 @@ def _load_env(self) -> Dict[str, str]: Dict[str, str]: Dictionary containing environment variables from .env file """ # Use dotenv_values instead of load_dotenv to get only variables from .env - return dict(dotenv_values()) + # usecwd=True walks up from CWD (user's project) instead of from the + # package source file location, which matters for editable installs. + return dict(dotenv_values(find_dotenv(usecwd=True))) def get_env(self) -> Dict[str, str]: """ diff --git a/src/runpod_flash/core/resources/load_balancer_sls_resource.py b/src/runpod_flash/core/resources/load_balancer_sls_resource.py index ee0cf4ef..d31e2669 100644 --- a/src/runpod_flash/core/resources/load_balancer_sls_resource.py +++ b/src/runpod_flash/core/resources/load_balancer_sls_resource.py @@ -255,10 +255,10 @@ async def _do_deploy(self) -> "LoadBalancerSlsResource": return self try: - # Mark this endpoint as load-balanced (triggers auto-provisioning on boot) - if self.env is None: - self.env = {} - self.env["FLASH_ENDPOINT_TYPE"] = "lb" + # NOTE: FLASH_ENDPOINT_TYPE is NOT injected here. For flash deploy, + # the runtime resource_provisioner sets it. For flash run (live + # serverless), the worker must NOT see it — otherwise it triggers + # artifact unpacking which doesn't exist for live endpoints. # Call parent deploy (creates endpoint via RunPod API) log.debug(f"Deploying LB endpoint: {self.name}") diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index 3300a741..381c5363 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -259,20 +259,19 @@ def validate_python_version(cls, v: Optional[str]) -> Optional[str]: @property def config_hash(self) -> str: - """Get config hash excluding env and runtime-assigned fields. + """Get config hash excluding runtime-assigned fields. Prevents false drift from: - - Dynamic env vars computed at runtime - Runtime-assigned fields (template, templateId, aiKey, userId, etc.) - Only hashes user-specified configuration, not server-assigned state. + Hashes user-specified configuration including env vars. """ import hashlib import json resource_type = self.__class__.__name__ - # Exclude runtime fields, env, and id from hash + # Exclude runtime fields and id from hash exclude_fields = ( self.__class__.RUNTIME_FIELDS | self.__class__.EXCLUDED_HASH_FIELDS ) @@ -534,12 +533,24 @@ def _payload_exclude(self) -> Set[str]: @staticmethod def _build_template_update_payload( - template: PodTemplate, template_id: str + template: PodTemplate, + template_id: str, + *, + skip_env: bool = False, ) -> Dict[str, Any]: """Build saveTemplate payload from template model. Keep this to fields supported by saveTemplate to avoid passing endpoint-only fields to the template mutation. + + Args: + template: Template model with desired configuration. + template_id: ID of the template to update. + skip_env: When True, omit ``env`` from the payload so + saveTemplate preserves the existing template env vars. + This prevents removing platform-injected vars (e.g. + PORT, PORT_HEALTH on LB endpoints) when the user's + env hasn't actually changed. """ template_data = template.model_dump(exclude_none=True, mode="json") allowed_fields = { @@ -550,6 +561,8 @@ def _build_template_update_payload( "env", "readme", } + if skip_env: + allowed_fields.discard("env") payload = { key: value for key, value in template_data.items() if key in allowed_fields } @@ -643,6 +656,56 @@ def _get_module_path(self) -> Optional[str]: except Exception: return None + def _inject_template_env(self, key: str, value: str) -> None: + """Append a KeyValuePair to self.template.env if the key isn't already present. + + This injects runtime env vars directly into the template without + mutating self.env, which would cause false config drift on subsequent + deploys. + """ + if self.template is None: + return + if self.template.env is None: + self.template.env = [] + existing_keys = {kv.key for kv in self.template.env} + if key not in existing_keys: + self.template.env.append(KeyValuePair(key=key, value=value)) + + def _inject_runtime_template_vars(self) -> None: + """Inject runtime env vars into template.env without mutating self.env. + + For QB endpoints making remote calls: injects RUNPOD_API_KEY. + For LB endpoints: injects FLASH_MODULE_PATH. + + Called by both _do_deploy (initial) and update (env changes) so + runtime vars survive template updates. + """ + env_dict = self.env or {} + + if self.type == ServerlessType.QB: + if self._check_makes_remote_calls(): + if "RUNPOD_API_KEY" not in env_dict: + from runpod_flash.core.credentials import get_api_key + + api_key = get_api_key() + if api_key: + self._inject_template_env("RUNPOD_API_KEY", api_key) + log.debug( + f"{self.name}: Injected RUNPOD_API_KEY for remote calls " + f"(makes_remote_calls=True)" + ) + else: + log.warning( + f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. " + f"Remote calls to other endpoints will fail." + ) + + elif self.type == ServerlessType.LB: + module_path = self._get_module_path() + if module_path and "FLASH_MODULE_PATH" not in env_dict: + self._inject_template_env("FLASH_MODULE_PATH", module_path) + log.debug(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}") + async def _do_deploy(self) -> "DeployableResource": """ Deploys the serverless resource using the provided configuration. @@ -658,43 +721,7 @@ async def _do_deploy(self) -> "DeployableResource": log.debug(f"{self} exists") return self - # Inject API key for queue-based endpoints that make remote calls - if self.type == ServerlessType.QB: - env_dict = self.env or {} - - # Check if this resource makes remote calls (from build manifest) - makes_remote_calls = self._check_makes_remote_calls() - - if makes_remote_calls: - # Inject RUNPOD_API_KEY if not already set - if "RUNPOD_API_KEY" not in env_dict: - from runpod_flash.core.credentials import get_api_key - - api_key = get_api_key() - if api_key: - env_dict["RUNPOD_API_KEY"] = api_key - log.debug( - f"{self.name}: Injected RUNPOD_API_KEY for remote calls " - f"(makes_remote_calls=True)" - ) - else: - log.warning( - f"{self.name}: makes_remote_calls=True but RUNPOD_API_KEY not set. " - f"Remote calls to other endpoints will fail." - ) - - self.env = env_dict - - # Inject module path for load-balanced endpoints - elif self.type == ServerlessType.LB: - env_dict = self.env or {} - - module_path = self._get_module_path() - if module_path and "FLASH_MODULE_PATH" not in env_dict: - env_dict["FLASH_MODULE_PATH"] = module_path - log.debug(f"{self.name}: Injected FLASH_MODULE_PATH={module_path}") - - self.env = env_dict + self._inject_runtime_template_vars() # Ensure network volume is deployed first await self._ensure_network_volume_deployed() @@ -764,8 +791,29 @@ async def update(self, new_config: "ServerlessResource") -> "ServerlessResource" if new_config.template: if resolved_template_id: + # Skip env in the template payload when the user's env + # hasn't changed. This lets the platform keep vars it + # injected (e.g. PORT, PORT_HEALTH on LB endpoints) + # and avoids a spurious rolling release. + # + # Also check template.env: if env is empty but the + # caller provided explicit template env entries, those + # must not be silently dropped. + env_unchanged = self.env == new_config.env + has_explicit_template_env = ( + not new_config.env and new_config.template.env is not None + ) + skip_env = env_unchanged and not has_explicit_template_env + + if not skip_env: + # Inject runtime vars (RUNPOD_API_KEY, FLASH_MODULE_PATH) + # so they survive the template env overwrite. + new_config._inject_runtime_template_vars() + template_payload = self._build_template_update_payload( - new_config.template, resolved_template_id + new_config.template, + resolved_template_id, + skip_env=skip_env, ) await client.update_template(template_payload) log.debug( diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 22bf4b34..581dfdc8 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -22,7 +22,7 @@ from runpod_flash.core.resources.gpu import GpuGroup from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter -from runpod_flash.core.resources.template import PodTemplate +from runpod_flash.core.resources.template import KeyValuePair, PodTemplate class TestServerlessResource: @@ -924,8 +924,6 @@ def test_serverless_endpoint_with_existing_template(self): def test_serverless_endpoint_template_env_override(self): """Test ServerlessEndpoint overrides template env vars.""" - from runpod_flash.core.resources.template import PodTemplate, KeyValuePair - template = PodTemplate( name="existing-template", imageName="test/image:v1", @@ -1421,3 +1419,526 @@ def test_python_version_in_hashed_fields(self): def test_python_version_in_input_only(self): input_only = self._get_class_set("_input_only") assert "python_version" in input_only + + +class TestInjectTemplateEnv: + """Test _inject_template_env helper and _do_deploy env non-mutation.""" + + def _make_resource_with_template(self, **overrides): + """Create a ServerlessEndpoint with a template for injection tests.""" + defaults = { + "name": "inject-test", + "imageName": "test:latest", + "env": {"USER_VAR": "user_value"}, + "flashboot": False, + } + defaults.update(overrides) + return ServerlessEndpoint(**defaults) + + def test_inject_template_env_adds_key_value_pair(self): + """_inject_template_env adds a KeyValuePair to template.env.""" + resource = self._make_resource_with_template() + assert resource.template is not None + + original_len = len(resource.template.env) + resource._inject_template_env("NEW_KEY", "new_value") + + assert len(resource.template.env) == original_len + 1 + added = resource.template.env[-1] + assert added.key == "NEW_KEY" + assert added.value == "new_value" + + def test_inject_template_env_is_idempotent(self): + """_inject_template_env does not add duplicate keys.""" + resource = self._make_resource_with_template() + assert resource.template is not None + + resource._inject_template_env("DEDUP_KEY", "first") + resource._inject_template_env("DEDUP_KEY", "second") + + matching = [kv for kv in resource.template.env if kv.key == "DEDUP_KEY"] + assert len(matching) == 1 + assert matching[0].value == "first" + + def test_inject_template_env_skips_when_no_template(self): + """_inject_template_env is a no-op when template is None.""" + resource = ServerlessResource(name="no-template") + resource.template = None + + # Should not raise + resource._inject_template_env("KEY", "value") + + def test_inject_template_env_initializes_empty_env_list(self): + """_inject_template_env handles template with None env list.""" + resource = self._make_resource_with_template() + resource.template.env = None + + resource._inject_template_env("INIT_KEY", "init_value") + + assert len(resource.template.env) == 1 + assert resource.template.env[0].key == "INIT_KEY" + + @pytest.mark.asyncio + async def test_do_deploy_does_not_mutate_self_env(self): + """_do_deploy should not modify self.env (prevents false config drift).""" + resource = self._make_resource_with_template( + env={"LOG_LEVEL": "INFO"}, + ) + env_before = dict(resource.env) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "endpoint-env-test", + "name": "inject-test", + "templateId": "tpl-env-test", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + with patch.object( + ServerlessResource, "is_deployed", return_value=False + ): + with patch.object( + ServerlessResource, + "_check_makes_remote_calls", + return_value=True, + ): + with patch.dict(os.environ, {"RUNPOD_API_KEY": "test-key-123"}): + await resource._do_deploy() + + assert resource.env == env_before + + @pytest.mark.asyncio + async def test_do_deploy_injects_api_key_into_template_env(self): + """_do_deploy should inject RUNPOD_API_KEY into template.env for QB endpoints.""" + resource = self._make_resource_with_template( + env={"LOG_LEVEL": "INFO"}, + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "endpoint-inject-test", + "name": "inject-test", + "templateId": "tpl-inject-test", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + with patch.object( + ServerlessResource, "is_deployed", return_value=False + ): + with patch.object( + ServerlessResource, + "_check_makes_remote_calls", + return_value=True, + ): + with patch.dict(os.environ, {"RUNPOD_API_KEY": "test-key-456"}): + await resource._do_deploy() + + # The API key should have been in the payload sent to save_endpoint + # via the template env, not via self.env + payload = mock_client.save_endpoint.call_args.args[0] + template_env = payload.get("template", {}).get("env", []) + api_key_entries = [e for e in template_env if e["key"] == "RUNPOD_API_KEY"] + assert len(api_key_entries) == 1 + assert api_key_entries[0]["value"] == "test-key-456" + + @pytest.mark.asyncio + async def test_do_deploy_lb_injects_module_path_into_template_env(self): + """_do_deploy should inject FLASH_MODULE_PATH into template.env for LB endpoints.""" + from runpod_flash.core.resources.load_balancer_sls_resource import ( + LoadBalancerSlsResource, + ) + + resource = LoadBalancerSlsResource( + name="lb-inject-test", + imageName="test:latest", + env={"LOG_LEVEL": "INFO"}, + flashboot=False, + ) + env_before = dict(resource.env) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "endpoint-lb-test", + "name": "lb-inject-test", + "templateId": "tpl-lb-test", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + with patch.object( + LoadBalancerSlsResource, "is_deployed", return_value=False + ): + with patch.object( + ServerlessResource, + "_get_module_path", + return_value="myapp.handler", + ): + await resource._do_deploy() + + # self.env should not be mutated + assert resource.env == env_before + + # FLASH_MODULE_PATH should be in template env + payload = mock_client.save_endpoint.call_args.args[0] + template_env = payload.get("template", {}).get("env", []) + module_entries = [e for e in template_env if e["key"] == "FLASH_MODULE_PATH"] + assert len(module_entries) == 1 + assert module_entries[0]["value"] == "myapp.handler" + + # FLASH_ENDPOINT_TYPE should NOT be injected here — it's set by the + # runtime resource_provisioner for flash deploy, not by _do_deploy + type_entries = [e for e in template_env if e["key"] == "FLASH_ENDPOINT_TYPE"] + assert len(type_entries) == 0 + + +class TestBuildTemplateUpdatePayload: + """Test _build_template_update_payload skip_env behavior.""" + + def test_payload_includes_env_by_default(self): + """Template update payload includes env when skip_env is False.""" + template = PodTemplate( + name="test-template", + imageName="test:latest", + env=[KeyValuePair(key="MY_VAR", value="my_val")], + ) + payload = ServerlessResource._build_template_update_payload(template, "tpl-123") + assert "env" in payload + assert payload["env"] == [{"key": "MY_VAR", "value": "my_val"}] + + def test_payload_excludes_env_when_skip_env_true(self): + """Template update payload omits env when skip_env is True. + + This preserves platform-injected vars (e.g. PORT, PORT_HEALTH) + on the existing template. + """ + template = PodTemplate( + name="test-template", + imageName="test:latest", + env=[KeyValuePair(key="MY_VAR", value="my_val")], + ) + payload = ServerlessResource._build_template_update_payload( + template, "tpl-123", skip_env=True + ) + assert "env" not in payload + # Other fields should still be present + assert payload["imageName"] == "test:latest" + assert payload["id"] == "tpl-123" + + @pytest.mark.asyncio + async def test_update_skips_env_when_unchanged(self): + """update() omits env from template payload when env hasn't changed.""" + env = {"LOG_LEVEL": "INFO"} + old_resource = ServerlessEndpoint( + name="update-test", + imageName="test:latest", + env=env, + flashboot=False, + ) + old_resource.id = "ep-123" + old_resource.templateId = "tpl-123" + + new_resource = ServerlessEndpoint( + name="update-test", + imageName="test:latest", + env=env, + flashboot=False, + workersMax=5, + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "ep-123", + "name": "update-test", + "templateId": "tpl-123", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + mock_client.update_template = AsyncMock(return_value={}) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + await old_resource.update(new_resource) + + # update_template was called, but env should NOT be in the payload + assert mock_client.update_template.called + template_payload = mock_client.update_template.call_args.args[0] + assert "env" not in template_payload + + @pytest.mark.asyncio + async def test_update_includes_env_when_changed(self): + """update() includes env in template payload when env changed.""" + old_resource = ServerlessEndpoint( + name="update-test", + imageName="test:latest", + env={"LOG_LEVEL": "INFO"}, + flashboot=False, + ) + old_resource.id = "ep-123" + old_resource.templateId = "tpl-123" + + new_resource = ServerlessEndpoint( + name="update-test", + imageName="test:latest", + env={"LOG_LEVEL": "DEBUG", "NEW_VAR": "new_val"}, + flashboot=False, + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "ep-123", + "name": "update-test", + "templateId": "tpl-123", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + mock_client.update_template = AsyncMock(return_value={}) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + await old_resource.update(new_resource) + + # update_template was called WITH env since it changed + assert mock_client.update_template.called + template_payload = mock_client.update_template.call_args.args[0] + assert "env" in template_payload + + @pytest.mark.asyncio + async def test_update_injects_runtime_vars_when_env_changed(self): + """update() injects RUNPOD_API_KEY into template.env when env changed. + + Without this, runtime-injected vars (set during _do_deploy) would be + lost when update() overwrites the template env. + """ + old_resource = ServerlessEndpoint( + name="update-inject-test", + imageName="test:latest", + env={"LOG_LEVEL": "INFO"}, + flashboot=False, + ) + old_resource.id = "ep-inject" + old_resource.templateId = "tpl-inject" + + new_resource = ServerlessEndpoint( + name="update-inject-test", + imageName="test:latest", + env={"LOG_LEVEL": "DEBUG"}, + flashboot=False, + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "ep-inject", + "name": "update-inject-test", + "templateId": "tpl-inject", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + mock_client.update_template = AsyncMock(return_value={}) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + with patch.object( + ServerlessResource, + "_check_makes_remote_calls", + return_value=True, + ): + with patch.dict(os.environ, {"RUNPOD_API_KEY": "inject-key"}): + await old_resource.update(new_resource) + + template_payload = mock_client.update_template.call_args.args[0] + env_entries = template_payload.get("env", []) + api_key_entries = [e for e in env_entries if e["key"] == "RUNPOD_API_KEY"] + assert len(api_key_entries) == 1 + assert api_key_entries[0]["value"] == "inject-key" + + @pytest.mark.asyncio + async def test_update_skips_runtime_injection_when_env_unchanged(self): + """update() does not inject runtime vars when env is unchanged. + + When skip_env=True, the template env payload is omitted entirely, + so runtime vars already on the platform are preserved as-is. + """ + env = {"LOG_LEVEL": "INFO"} + old_resource = ServerlessEndpoint( + name="update-no-inject", + imageName="test:latest", + env=env, + flashboot=False, + ) + old_resource.id = "ep-no-inject" + old_resource.templateId = "tpl-no-inject" + + new_resource = ServerlessEndpoint( + name="update-no-inject", + imageName="test:latest", + env=env, + flashboot=False, + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "ep-no-inject", + "name": "update-no-inject", + "templateId": "tpl-no-inject", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + mock_client.update_template = AsyncMock(return_value={}) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + with patch.object( + ServerlessResource, + "_check_makes_remote_calls", + return_value=True, + ): + with patch.dict(os.environ, {"RUNPOD_API_KEY": "inject-key"}): + await old_resource.update(new_resource) + + # env should be omitted from template payload (skip_env=True) + template_payload = mock_client.update_template.call_args.args[0] + assert "env" not in template_payload + + @pytest.mark.asyncio + async def test_update_includes_env_for_explicit_template_env(self): + """update() sends env when caller provides explicit template.env with empty env. + + Even if self.env == new_config.env (both empty), explicit template.env + entries must not be silently dropped. + """ + old_resource = ServerlessEndpoint( + name="update-tpl-env", + imageName="test:latest", + env={}, + flashboot=False, + ) + old_resource.id = "ep-tpl-env" + old_resource.templateId = "tpl-tpl-env" + + new_resource = ServerlessEndpoint( + name="update-tpl-env", + imageName="test:latest", + env={}, + flashboot=False, + template=PodTemplate( + name="explicit-tpl", + imageName="test:latest", + env=[KeyValuePair(key="EXPLICIT_VAR", value="explicit_val")], + ), + ) + + mock_client = AsyncMock() + mock_client.save_endpoint = AsyncMock( + return_value={ + "id": "ep-tpl-env", + "name": "update-tpl-env", + "templateId": "tpl-tpl-env", + "gpuIds": "AMPERE_48", + "allowedCudaVersions": "", + } + ) + mock_client.update_template = AsyncMock(return_value={}) + + with patch( + "runpod_flash.core.resources.serverless.RunpodGraphQLClient" + ) as mock_client_class: + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + with patch.object( + ServerlessResource, + "_ensure_network_volume_deployed", + new=AsyncMock(), + ): + await old_resource.update(new_resource) + + template_payload = mock_client.update_template.call_args.args[0] + assert "env" in template_payload + env_entries = template_payload["env"] + explicit = [e for e in env_entries if e["key"] == "EXPLICIT_VAR"] + assert len(explicit) == 1 diff --git a/tests/unit/test_dotenv_loading.py b/tests/unit/test_dotenv_loading.py index 90e2c903..1a080b1f 100644 --- a/tests/unit/test_dotenv_loading.py +++ b/tests/unit/test_dotenv_loading.py @@ -32,9 +32,9 @@ def test_dotenv_loads_before_imports(self): logger_import_line = None for i, line in enumerate(lines): - if "from dotenv import load_dotenv" in line: + if "from dotenv import" in line and "load_dotenv" in line: dotenv_import_line = i - elif line.strip() == "load_dotenv()": + elif "load_dotenv(" in line.strip() and "import" not in line: dotenv_call_line = i elif "from .logger import setup_logging" in line: logger_import_line = i @@ -331,8 +331,8 @@ def test_dotenv_import_present_in_init(self): content = init_file.read_text() # Verify dotenv is imported and called - assert "from dotenv import load_dotenv" in content - assert "load_dotenv()" in content + assert "from dotenv import find_dotenv, load_dotenv" in content + assert "load_dotenv(find_dotenv(usecwd=True))" in content # Verify dotenv is imported before any other module imports lines = content.split("\n") @@ -341,7 +341,7 @@ def test_dotenv_import_present_in_init(self): ] # First import line should be the dotenv import - assert "from dotenv import load_dotenv" in import_lines[0] + assert "from dotenv import find_dotenv, load_dotenv" in import_lines[0] @patch.dict(os.environ, {}, clear=True) def test_clean_environment_dotenv_loading(self):