Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/runpod_flash/core/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
)
from .serverless_cpu import CpuServerlessEndpoint
from .template import PodTemplate
from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS
from .network_volume import NetworkVolume
from .datacenter import DataCenter, CPU_DATACENTERS
from .load_balancer_sls_resource import (
CpuLoadBalancerSlsResource,
LoadBalancerSlsResource,
Expand Down
52 changes: 52 additions & 0 deletions src/runpod_flash/core/resources/datacenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from enum import Enum


class DataCenter(str, Enum):
    """Enum representing available RunPod data centers.

    NOTE: only datacenters with both storage support and S3 API support
    are listed here.

    Members also subclass ``str``, so each member compares equal to its
    canonical ID string (e.g. ``DataCenter.EU_RO_1 == "EU-RO-1"``).
    """

    # north america
    US_CA_2 = "US-CA-2"
    US_IL_1 = "US-IL-1"
    US_KS_2 = "US-KS-2"
    US_MO_1 = "US-MO-1"
    US_MO_2 = "US-MO-2"
    US_NC_2 = "US-NC-2"
    US_NE_1 = "US-NE-1"
    US_WA_1 = "US-WA-1"

    # europe
    EU_CZ_1 = "EU-CZ-1"
    EU_RO_1 = "EU-RO-1"
    EUR_NO_1 = "EUR-NO-1"

    @classmethod
    def from_string(cls, value: str) -> "DataCenter":
        """Parse a datacenter ID string into a DataCenter enum.

        Accepts the canonical form (e.g. "EU-RO-1") as well as common
        variations like lowercase or underscore-separated. Surrounding
        whitespace is ignored.

        Args:
            value: the datacenter ID to parse.

        Returns:
            The matching DataCenter member.

        Raises:
            ValueError: if the normalized value is not a known datacenter.
                The message lists every valid datacenter ID.
        """
        normalized = value.strip().upper().replace("_", "-")
        try:
            return cls(normalized)
        except ValueError:
            valid = ", ".join(dc.value for dc in cls)
            # The enum's own lookup error adds nothing beyond the message
            # built here, so suppress it instead of chaining a second,
            # noisier traceback onto the one callers actually see.
            raise ValueError(
                f"Unknown datacenter '{value}'. Valid datacenters: {valid}"
            ) from None

    @classmethod
    def all(cls) -> list["DataCenter"]:
        """Return all datacenters as a new list, in definition order."""
        return list(cls)


# the subset of datacenters where CPU serverless endpoints are supported
CPU_DATACENTERS: frozenset[DataCenter] = frozenset((DataCenter.EU_RO_1,))
49 changes: 1 addition & 48 deletions src/runpod_flash/core/resources/network_volume.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import hashlib
import logging
from enum import Enum
from typing import Optional, Dict, Any

from pydantic import (
Expand All @@ -10,6 +9,7 @@
field_serializer,
model_validator,
)
from .datacenter import DataCenter

from ..api.runpod import RunpodRestClient
from ..urls import RUNPOD_CONSOLE_URL
Expand All @@ -19,53 +19,6 @@
log = logging.getLogger(__name__)


class DataCenter(str, Enum):
    """Enum representing available RunPod data centers.

    Members also subclass ``str``, so each member compares equal to its
    canonical ID string (e.g. ``DataCenter.EU_RO_1 == "EU-RO-1"``).
    """

    # north america
    US_CA_2 = "US-CA-2"
    US_GA_2 = "US-GA-2"
    US_IL_1 = "US-IL-1"
    US_KS_2 = "US-KS-2"
    US_MD_1 = "US-MD-1"
    US_MO_1 = "US-MO-1"
    US_MO_2 = "US-MO-2"
    US_NC_1 = "US-NC-1"
    US_NC_2 = "US-NC-2"
    US_NE_1 = "US-NE-1"
    US_WA_1 = "US-WA-1"

    # europe
    EU_CZ_1 = "EU-CZ-1"
    EU_RO_1 = "EU-RO-1"
    EUR_IS_1 = "EUR-IS-1"
    EUR_NO_1 = "EUR-NO-1"

    @classmethod
    def from_string(cls, value: str) -> "DataCenter":
        """Parse a datacenter ID string into a DataCenter enum.

        Accepts the canonical form (e.g. "EU-RO-1") as well as common
        variations like lowercase or underscore-separated. Surrounding
        whitespace is ignored.

        Args:
            value: the datacenter ID to parse.

        Returns:
            The matching DataCenter member.

        Raises:
            ValueError: if the normalized value is not a known datacenter.
                The message lists every valid datacenter ID.
        """
        normalized = value.strip().upper().replace("_", "-")
        try:
            return cls(normalized)
        except ValueError:
            valid = ", ".join(dc.value for dc in cls)
            # The enum's own lookup error adds nothing beyond the message
            # built here, so suppress it instead of chaining a second,
            # noisier traceback onto the one callers actually see.
            raise ValueError(
                f"Unknown datacenter '{value}'. Valid datacenters: {valid}"
            ) from None


# the subset of datacenters where CPU serverless endpoints are supported
CPU_DATACENTERS: frozenset[DataCenter] = frozenset((DataCenter.EU_RO_1,))


class NetworkVolume(DeployableResource):
"""
NetworkVolume resource for creating and managing Runpod network volumes.
Expand Down
3 changes: 2 additions & 1 deletion src/runpod_flash/core/resources/serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
)
from .cpu import CpuInstanceType
from .gpu import GpuGroup, GpuType
from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS
from .network_volume import NetworkVolume
from .datacenter import DataCenter, CPU_DATACENTERS
from .request_logs import QBRequestLogBatch, QBRequestLogFetcher, QBRequestLogPhase
from .worker_availability_diagnostic import WorkerAvailabilityDiagnostic
from .template import KeyValuePair, PodTemplate
Expand Down
8 changes: 7 additions & 1 deletion src/runpod_flash/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .core.resources.constants import DEFAULT_WORKERS_MAX, DEFAULT_WORKERS_MIN
from .core.resources.cpu import CpuInstanceType
from .core.resources.gpu import GpuGroup, GpuType
from .core.resources.network_volume import DataCenter, NetworkVolume
from .core.resources.network_volume import NetworkVolume
from .core.resources.datacenter import DataCenter
from .core.resources.serverless import CudaVersion, ServerlessScalerType
from .core.resources.template import PodTemplate

Expand Down Expand Up @@ -413,6 +414,11 @@ def __init__(
if not self._is_cpu and self._gpu is None and not self.is_client:
self._gpu = [GpuGroup.ANY]

# if not in pure client mode, make sure default datacenters are set
# not CPU though, that gets pinned to specific datacenters
if not self._is_cpu and not self.is_client and not self.datacenter:
self.datacenter = DataCenter.all()

# lb routes registered via .get()/.post()/etc (decorator mode only)
self._routes: List[Dict[str, Any]] = []

Expand Down
10 changes: 4 additions & 6 deletions tests/unit/cli/commands/build_utils/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,7 @@ def test_extract_deployment_config_includes_network_volume():

resource_py = project_dir / "resource.py"
resource_py.write_text(
"from runpod_flash import NetworkVolume\n"
"from runpod_flash.core.resources.network_volume import DataCenter\n"
"from runpod_flash import NetworkVolume, DataCenter\n"
"\n"
"class gpu_config:\n"
' imageName = "test-image"\n'
Expand Down Expand Up @@ -631,8 +630,7 @@ def test_extract_deployment_config_includes_network_volumes():

resource_py = project_dir / "resource.py"
resource_py.write_text(
"from runpod_flash import NetworkVolume\n"
"from runpod_flash.core.resources.network_volume import DataCenter\n"
"from runpod_flash import NetworkVolume, DataCenter\n"
"\n"
"class gpu_config:\n"
' imageName = "test-image"\n'
Expand All @@ -645,7 +643,7 @@ def test_extract_deployment_config_includes_network_volumes():
" NetworkVolume(\n"
' name="vol-us",\n'
" size=200,\n"
" dataCenterId=DataCenter.US_GA_2,\n"
" dataCenterId=DataCenter.US_CA_2,\n"
" ),\n"
" ]\n"
)
Expand Down Expand Up @@ -675,7 +673,7 @@ def test_extract_deployment_config_includes_network_volumes():
assert config["networkVolumes"][0]["dataCenterId"] == "EU-RO-1"
assert config["networkVolumes"][1]["name"] == "vol-us"
assert config["networkVolumes"][1]["size"] == 200
assert config["networkVolumes"][1]["dataCenterId"] == "US-GA-2"
assert config["networkVolumes"][1]["dataCenterId"] == "US-CA-2"
assert "networkVolume" not in config


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/resources/test_network_volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest
from pydantic import ValidationError

from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter
from runpod_flash.core.resources.network_volume import NetworkVolume
from runpod_flash.core.resources.datacenter import DataCenter


class TestNetworkVolumeIdempotent:
Expand Down
53 changes: 27 additions & 26 deletions tests/unit/resources/test_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from runpod_flash.core.resources.serverless_cpu import CpuServerlessEndpoint
from runpod_flash.core.resources.gpu import GpuGroup
from runpod_flash.core.resources.cpu import CpuInstanceType
from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter
from runpod_flash.core.resources.network_volume import NetworkVolume
from runpod_flash.core.resources.datacenter import DataCenter
from runpod_flash.core.resources.request_logs import (
QBRequestLogBatch,
QBRequestLogPhase,
Expand Down Expand Up @@ -220,12 +221,12 @@ class TestMultiVolumeDeployPath:
@pytest.mark.asyncio
async def test_multi_volume_deploys_all_and_collects_ids(self):
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)

serverless = ServerlessResource(
name="test",
networkVolumes=[vol_a, vol_b],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)

async def fake_deploy(self_vol):
Expand All @@ -242,12 +243,12 @@ async def fake_deploy(self_vol):
async def test_multi_volume_skips_already_created(self):
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
vol_a.id = "vol-aaa"
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)

serverless = ServerlessResource(
name="test",
networkVolumes=[vol_a, vol_b],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)

deploy_calls = []
Expand Down Expand Up @@ -330,7 +331,7 @@ def test_single_volume_payload_uses_singular_field(self):
def test_multi_volume_drift_detection(self):
"""Changing networkVolumes changes the config hash."""
vol_a = NetworkVolume(name="vol-a", size=50, dataCenterId=DataCenter.EU_RO_1)
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_GA_2)
vol_b = NetworkVolume(name="vol-b", size=50, dataCenterId=DataCenter.US_CA_2)

s1 = ServerlessResource(
name="test",
Expand All @@ -340,7 +341,7 @@ def test_multi_volume_drift_detection(self):
s2 = ServerlessResource(
name="test",
networkVolumes=[vol_a, vol_b],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)

assert s1.config_hash != s2.config_hash
Expand Down Expand Up @@ -470,9 +471,9 @@ def test_datacenter_multiple_values(self):
"""Test datacenter accepts a list of DataCenter values."""
serverless = ServerlessResource(
name="test",
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_2]
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_CA_2]

def test_datacenter_string_value(self):
"""Test datacenter accepts string values."""
Expand All @@ -481,8 +482,8 @@ def test_datacenter_string_value(self):

def test_datacenter_string_list(self):
"""Test datacenter accepts list of strings."""
serverless = ServerlessResource(name="test", datacenter=["EU-RO-1", "US-GA-2"])
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_GA_2]
serverless = ServerlessResource(name="test", datacenter=["EU-RO-1", "US-CA-2"])
assert serverless.datacenter == [DataCenter.EU_RO_1, DataCenter.US_CA_2]

def test_datacenter_invalid_string_raises(self):
"""Test that an invalid datacenter string raises ValueError."""
Expand All @@ -498,9 +499,9 @@ def test_locations_synced_from_multi_datacenter(self):
"""Test locations field gets synced from multiple datacenters."""
serverless = ServerlessResource(
name="test",
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)
assert serverless.locations == "EU-RO-1,US-GA-2"
assert serverless.locations == "EU-RO-1,US-CA-2"

def test_no_datacenter_no_locations(self):
"""Test that no datacenter means no locations restriction."""
Expand All @@ -509,9 +510,9 @@ def test_no_datacenter_no_locations(self):

def test_explicit_locations_not_overridden(self):
"""Test explicit locations field is not overridden."""
serverless = ServerlessResource(name="test", locations="US-GA-2")
serverless = ServerlessResource(name="test", locations="US-CA-2")

assert serverless.locations == "US-GA-2"
assert serverless.locations == "US-CA-2"

def test_datacenter_validation_matching_datacenters(self):
"""Test that matching datacenters between endpoint and volume work."""
Expand All @@ -525,7 +526,7 @@ def test_datacenter_validation_matching_datacenters(self):

def test_datacenter_validation_volume_not_in_dc_list(self):
"""Test that a volume DC not in endpoint's DC list raises an error."""
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_2)
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_CA_2)
with pytest.raises(
ValueError,
match="Network volume datacenter.*is not in the endpoint's datacenter list",
Expand All @@ -536,9 +537,9 @@ def test_datacenter_validation_volume_not_in_dc_list(self):

def test_volume_dc_allowed_when_no_datacenter_set(self):
"""Test that any volume DC is allowed when no datacenter restriction is set."""
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_GA_2)
volume = NetworkVolume(name="test-volume", dataCenterId=DataCenter.US_CA_2)
serverless = ServerlessResource(name="test", networkVolume=volume)
assert serverless.networkVolume.dataCenterId == DataCenter.US_GA_2
assert serverless.networkVolume.dataCenterId == DataCenter.US_CA_2

def test_no_flashboot_keeps_name(self):
"""Test flashboot=False keeps original name."""
Expand Down Expand Up @@ -629,7 +630,7 @@ def test_single_volume_compat(self):
def test_multiple_volumes_via_list(self):
"""Test networkVolumes accepts multiple volumes."""
v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1)
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_2)
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_CA_2)
s = ServerlessResource(name="test", networkVolumes=[v1, v2])
assert len(s.networkVolumes) == 2
assert s.networkVolume is v1
Expand All @@ -643,7 +644,7 @@ def test_duplicate_dc_raises(self):

def test_volumes_dc_outside_endpoint_dc_raises(self):
"""Test volume DC not in endpoint's DC list raises."""
vol = NetworkVolume(name="v1", dataCenterId=DataCenter.US_GA_2)
vol = NetworkVolume(name="v1", dataCenterId=DataCenter.US_CA_2)
with pytest.raises(
ValueError,
match="is not in the endpoint's datacenter list",
Expand All @@ -657,10 +658,10 @@ def test_volumes_dc_outside_endpoint_dc_raises(self):
def test_volumes_dc_within_endpoint_dc_list(self):
"""Test volume DCs all within endpoint DC list works."""
v1 = NetworkVolume(name="v1", dataCenterId=DataCenter.EU_RO_1)
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_GA_2)
v2 = NetworkVolume(name="v2", dataCenterId=DataCenter.US_CA_2)
s = ServerlessResource(
name="test",
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
networkVolumes=[v1, v2],
)
assert len(s.networkVolumes) == 2
Expand All @@ -684,7 +685,7 @@ def test_cpu_endpoint_in_unsupported_dc_raises(self):
CpuServerlessEndpoint(
name="test-cpu",
imageName="test/cpu:latest",
datacenter=DataCenter.US_GA_2,
datacenter=DataCenter.US_CA_2,
)

def test_cpu_endpoint_mixed_dcs_raises(self):
Expand All @@ -693,7 +694,7 @@ def test_cpu_endpoint_mixed_dcs_raises(self):
CpuServerlessEndpoint(
name="test-cpu",
imageName="test/cpu:latest",
datacenter=[DataCenter.EU_RO_1, DataCenter.US_GA_2],
datacenter=[DataCenter.EU_RO_1, DataCenter.US_CA_2],
)

def test_cpu_endpoint_no_datacenter_ok(self):
Expand All @@ -708,9 +709,9 @@ def test_gpu_endpoint_any_dc_ok(self):
"""Test GPU endpoint in any datacenter is allowed."""
serverless = ServerlessResource(
name="test-gpu",
datacenter=DataCenter.US_GA_2,
datacenter=DataCenter.US_CA_2,
)
assert serverless.datacenter == [DataCenter.US_GA_2]
assert serverless.datacenter == [DataCenter.US_CA_2]


class TestMinCudaVersion:
Expand Down
Loading
Loading