
Commit 2222a63

Add API for SSH proxy (#3646)
Part-of: #3644
1 parent 808ff64 commit 2222a63

23 files changed, +827 -87 lines changed


pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -126,6 +126,7 @@ markers = [
 ]
 env = [
     "DSTACK_CLI_RICH_FORCE_TERMINAL=0",
+    "DSTACK_SSHPROXY_API_TOKEN=test-token",
 ]
 filterwarnings = [
     # testcontainers modules use deprecated decorators – nothing we can do:
@@ -142,6 +143,7 @@ dev = [
     "pytest-httpbin>=2.1.0",
     "pytest-socket>=0.7.0",
     "pytest-env>=1.1.0",
+    "pytest-unordered>=0.7.0",
     "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
     "requests-mock>=1.12.1",
     "openai>=1.68.2",

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 2 additions & 2 deletions

@@ -130,8 +130,8 @@ def run_job(
         commands = get_docker_commands(
             [run.run_spec.ssh_key_pub.strip(), project_ssh_public_key.strip()]
         )
-        # There is a one jump pod per Kubernetes backend that is used
-        # as an ssh proxy jump to connect to all other services in Kubernetes.
+        # There is one jump pod per project that is used as an ssh proxy jump to connect
+        # to all job pods of the same project.
         # The service is created here and configured later in update_provisioning_data()
         jump_pod_name = f"dstack-{run.project_name}-ssh-jump-pod"
         jump_pod_service_name = _get_pod_service_name(jump_pod_name)

src/dstack/_internal/server/app.py

Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,7 @@
     runs,
     secrets,
     server,
+    sshproxy,
     templates,
     users,
     volumes,
@@ -255,6 +256,7 @@ def register_routes(app: FastAPI, ui: bool = True):
     app.include_router(events.root_router)
     app.include_router(templates.router)
     app.include_router(exports.project_router)
+    app.include_router(sshproxy.router)
 
     @app.exception_handler(ForbiddenError)
     async def forbidden_error_handler(request: Request, exc: ForbiddenError):

src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py

Lines changed: 2 additions & 1 deletion

@@ -59,6 +59,7 @@
     emit_job_status_change_event,
     get_job_provisioning_data,
     get_job_runtime_data,
+    get_job_spec,
 )
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.logging import fmt
@@ -803,7 +804,7 @@ async def _detach_volumes_from_job_instance(
     jpd: JobProvisioningData,
     run_termination_reason: Optional[RunTerminationReason],
 ) -> _VolumeDetachResult:
-    job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
+    job_spec = get_job_spec(job_model)
     backend = await backends_services.get_project_backend_by_type(
         project=instance_model.project,
         backend_type=jpd.backend,
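
This and the following background-task changes swap the repeated JobSpec.__response__.parse_raw(job_model.job_spec_data) expression for a shared get_job_spec helper imported from dstack._internal.server.services.jobs. The helper's implementation is not part of this excerpt; judging from the call sites it replaces, it is presumably equivalent to the sketch below, which also explains why the modules can drop their direct JobSpec imports.

# Hypothetical sketch inferred from the replaced call sites; the actual helper
# lives in dstack._internal.server.services.jobs and may differ in details.
from dstack._internal.core.models.runs import JobSpec
from dstack._internal.server.models import JobModel


def get_job_spec(job_model: JobModel) -> JobSpec:
    # Deserialize the JobSpec stored as raw JSON in the job_spec_data column.
    return JobSpec.__response__.parse_raw(job_model.job_spec_data)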

src/dstack/_internal/server/background/scheduled_tasks/probes.py

Lines changed: 4 additions & 3 deletions

@@ -12,7 +12,7 @@
 from sqlalchemy.orm import joinedload
 
 from dstack._internal.core.errors import SSHError
-from dstack._internal.core.models.runs import JobSpec, JobStatus, ProbeSpec
+from dstack._internal.core.models.runs import JobStatus, ProbeSpec
 from dstack._internal.core.services.ssh.tunnel import (
     SSH_DEFAULT_OPTIONS,
     IPSocket,
@@ -21,6 +21,7 @@
 )
 from dstack._internal.server.db import get_db, get_session_ctx
 from dstack._internal.server.models import InstanceModel, JobModel, ProbeModel
+from dstack._internal.server.services.jobs import get_job_spec
 from dstack._internal.server.services.locking import get_locker
 from dstack._internal.server.services.logging import fmt
 from dstack._internal.server.services.ssh import container_ssh_tunnel
@@ -71,7 +72,7 @@ async def process_probes():
         if probe.job.status != JobStatus.RUNNING:
             probe.active = False
         else:
-            job_spec: JobSpec = JobSpec.__response__.parse_raw(probe.job.job_spec_data)
+            job_spec = get_job_spec(probe.job)
             probe_spec = job_spec.probes[probe.probe_num]
             if probe_spec.until_ready and probe.success_streak >= probe_spec.ready_after:
                 probe.active = False
@@ -148,7 +149,7 @@ async def _get_service_replica_client(job: JobModel) -> AsyncGenerator[AsyncClie
         **SSH_DEFAULT_OPTIONS,
         "ConnectTimeout": str(int(SSH_CONNECT_TIMEOUT.total_seconds())),
     }
-    job_spec: JobSpec = JobSpec.__response__.parse_raw(job.job_spec_data)
+    job_spec = get_job_spec(job)
     with TemporaryDirectory() as temp_dir:
         app_socket_path = (Path(temp_dir) / "replica.sock").absolute()
         async with container_ssh_tunnel(

src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py

Lines changed: 4 additions & 3 deletions

@@ -34,7 +34,6 @@
     JobTerminationReason,
     ProbeSpec,
     Run,
-    RunSpec,
     RunStatus,
 )
 from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
@@ -67,6 +66,7 @@
     find_job,
     get_job_attached_volumes,
     get_job_runtime_data,
+    get_job_spec,
     is_master_job,
     job_model_to_job_submission,
     switch_job_status,
@@ -82,6+82,7 @@
 from dstack._internal.server.services.runner import client
 from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel
 from dstack._internal.server.services.runs import (
+    get_run_spec,
     is_job_ready,
     run_model_to_run,
 )
@@ -734,7 +735,7 @@ def _process_provisioning_with_shim(
     Returns:
         is successful
     """
-    job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
+    job_spec = get_job_spec(job_model)
 
     shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
 
@@ -982,7 +983,7 @@ def _terminate_if_inactivity_duration_exceeded(
     job_model: JobModel,
     no_connections_secs: Optional[int],
 ) -> None:
-    conf = RunSpec.__response__.parse_raw(run_model.run_spec).configuration
+    conf = get_run_spec(run_model).configuration
     if not isinstance(conf, DevEnvironmentConfiguration) or not isinstance(
         conf.inactivity_duration, int
     ):
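
running_jobs.py applies the same treatment to run specs: RunSpec.__response__.parse_raw(run_model.run_spec) becomes get_run_spec(run_model), imported from dstack._internal.server.services.runs. By analogy with get_job_spec, the helper presumably amounts to the following sketch.

# Hypothetical sketch; the RunModel class name and the helper body are
# assumptions based on the call it replaces in this diff.
from dstack._internal.core.models.runs import RunSpec
from dstack._internal.server.models import RunModel


def get_run_spec(run_model: RunModel) -> RunSpec:
    # Deserialize the RunSpec stored as raw JSON in the run_spec column.
    return RunSpec.__response__.parse_raw(run_model.run_spec)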

src/dstack/_internal/server/background/scheduled_tasks/runs.py

Lines changed: 4 additions & 4 deletions

@@ -13,7 +13,6 @@
 from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
-    JobSpec,
     JobStatus,
     JobTerminationReason,
     Run,
@@ -33,6 +32,7 @@
 from dstack._internal.server.services import events
 from dstack._internal.server.services.jobs import (
     find_job,
+    get_job_spec,
     get_job_specs_from_run_spec,
     group_jobs_by_replica_latest,
     is_master_job,
@@ -531,7 +531,7 @@ async def _handle_run_replicas(
             if job.status.is_finished():
                 continue
             try:
-                job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
+                job_spec = get_job_spec(job)
                 existing_group_names.add(job_spec.replica_group)
             except Exception:
                 continue
@@ -647,7 +647,7 @@ async def _update_jobs_to_new_deployment_in_place(
     replica_group_name = None
 
     if replicas:
-        job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
+        job_spec = get_job_spec(job_models[0])
         replica_group_name = job_spec.replica_group
 
     # FIXME: Handle getting image configuration errors or skip it.
@@ -662,7 +662,7 @@
     )
     can_update_all_jobs = True
     for old_job_model, new_job_spec in zip(job_models, new_job_specs):
-        old_job_spec = JobSpec.__response__.parse_raw(old_job_model.job_spec_data)
+        old_job_spec = get_job_spec(old_job_model)
         if new_job_spec != old_job_spec:
            can_update_all_jobs = False
            break

src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py

Lines changed: 2 additions & 1 deletion

@@ -39,6 +39,7 @@
 from dstack._internal.server.services.jobs import (
     get_job_provisioning_data,
     get_job_runtime_data,
+    get_job_spec,
     switch_job_status,
 )
 from dstack._internal.server.services.locking import get_locker
@@ -356,7 +357,7 @@ async def _detach_volumes_from_job_instance(
     instance_model: InstanceModel,
     volume_models: list[VolumeModel],
 ) -> bool:
-    job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
+    job_spec = get_job_spec(job_model)
     backend = await backends_services.get_project_backend_by_type(
         project=project,
         backend_type=jpd.backend,

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import os
+from typing import Annotated
+
+from fastapi import APIRouter, Depends
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.core.errors import ResourceNotExistsError
+from dstack._internal.server.db import get_session
+from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse
+from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount
+from dstack._internal.server.services.sshproxy import get_upstream_response
+from dstack._internal.server.utils.routers import (
+    CustomORJSONResponse,
+    get_base_api_additional_responses,
+)
+
+if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"):
+    _auth = ServiceAccount(_token)
+else:
+    _auth = AlwaysForbidden()
+
+
+router = APIRouter(
+    prefix="/api/sshproxy",
+    tags=["sshproxy"],
+    responses=get_base_api_additional_responses(),
+    dependencies=[Depends(_auth)],
+)
+
+
+@router.post("/get_upstream", response_model=GetUpstreamResponse)
+async def get_upstream(
+    body: GetUpstreamRequest,
+    session: Annotated[AsyncSession, Depends(get_session)],
+):
+    response = await get_upstream_response(session=session, upstream_id=body.id)
+    if response is None:
+        raise ResourceNotExistsError()
+    return CustomORJSONResponse(response)
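
The new module above (presumably the routers/sshproxy.py router that app.py now includes) exposes a single endpoint, POST /api/sshproxy/get_upstream, which is only enabled when DSTACK_SSHPROXY_API_TOKEN is set; otherwise every request is rejected. A rough client-side sketch, assuming the service-account token is passed as a standard Authorization: Bearer header (the server URL, token, and upstream id below are placeholders):

# Illustrative only; URL, token value, and upstream id are placeholders.
import requests

resp = requests.post(
    "http://localhost:3000/api/sshproxy/get_upstream",
    headers={"Authorization": "Bearer test-token"},  # the DSTACK_SSHPROXY_API_TOKEN value
    json={"id": "<upstream-id>"},
    timeout=10,
)
resp.raise_for_status()  # raises if the server reported an error, e.g. an unknown upstream id
upstream = resp.json()
# upstream["hosts"]: the SSH hop chain, jump host(s) first, target host last
# upstream["authorized_keys"]: public keys allowed to connect to the upstream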

Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+from typing import Annotated
+
+from pydantic import Field
+
+from dstack._internal.core.models.common import CoreModel
+
+
+class GetUpstreamRequest(CoreModel):
+    # The format of id is intentionally not limited to UUID to allow further extensions
+    id: str
+
+
+class UpstreamHost(CoreModel):
+    host: Annotated[str, Field(description="The hostname or IP address")]
+    port: Annotated[int, Field(description="The SSH port")]
+    user: Annotated[str, Field(description="The user to log in")]
+    private_key: Annotated[str, Field(description="The private key in OpenSSH file format")]
+
+
+class GetUpstreamResponse(CoreModel):
+    hosts: Annotated[
+        list[UpstreamHost],
+        Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"),
+    ]
+    authorized_keys: Annotated[
+        list[str], Field(description="The list of authorized public keys in OpenSSH file format")
+    ]
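
For a concrete picture of the response shape defined above, a reply for a job host reached through one jump host could look roughly like this (all values are illustrative, and in practice the private keys are full OpenSSH-format key bodies):

# Illustrative values only.
GetUpstreamResponse(
    hosts=[
        UpstreamHost(host="203.0.113.10", port=22, user="root", private_key="<jump host key>"),
        UpstreamHost(host="10.0.0.7", port=10022, user="root", private_key="<target host key>"),
    ],
    authorized_keys=["ssh-ed25519 AAAA... dstack-client"],
)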
