Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions packages/prime-sandboxes/src/prime_sandboxes/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,12 +1311,18 @@ def create_ssh_session(
self,
sandbox_id: str,
ttl_seconds: Optional[int] = None,
public_key: Optional[str] = None,
) -> SSHSession:
"""Create an SSH session"""
self._guard_vm_unsupported(sandbox_id, "SSH")
"""Create an SSH session.

Pass ``public_key`` for VM sandboxes; container sandboxes authorize
keys separately after session creation.
"""
payload: Dict[str, Any] = {}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
if public_key is not None:
payload["public_key"] = public_key
response = self.client.request(
"POST",
f"/sandbox/{sandbox_id}/ssh-session",
Expand All @@ -1325,8 +1331,7 @@ def create_ssh_session(
return SSHSession.model_validate(response)

def close_ssh_session(self, sandbox_id: str, session_id: str) -> None:
"""Close an SSH session and remove its exposure"""
self._guard_vm_unsupported(sandbox_id, "SSH")
"""Close an SSH session."""
self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}")


Expand Down Expand Up @@ -2233,12 +2238,18 @@ async def create_ssh_session(
self,
sandbox_id: str,
ttl_seconds: Optional[int] = None,
public_key: Optional[str] = None,
) -> SSHSession:
"""Create an SSH session"""
await self._guard_vm_unsupported(sandbox_id, "SSH")
"""Create an SSH session.

Pass ``public_key`` for VM sandboxes; container sandboxes authorize
keys separately after session creation.
"""
payload: Dict[str, Any] = {}
if ttl_seconds is not None:
payload["ttl_seconds"] = ttl_seconds
if public_key is not None:
payload["public_key"] = public_key
response = await self.client.request(
"POST",
f"/sandbox/{sandbox_id}/ssh-session",
Expand All @@ -2247,8 +2258,7 @@ async def create_ssh_session(
return SSHSession.model_validate(response)

async def close_ssh_session(self, sandbox_id: str, session_id: str) -> None:
"""Close an SSH session and remove its exposure"""
await self._guard_vm_unsupported(sandbox_id, "SSH")
"""Close an SSH session."""
await self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}")


Expand Down
71 changes: 51 additions & 20 deletions packages/prime-sandboxes/tests/test_vm_guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,35 @@ def test_sync_list_exposed_ports_blocked_for_vm():
assert recording.calls == []


def test_sync_create_ssh_session_blocked_for_vm():
def test_sync_create_ssh_session_allowed_for_vm():
client, recording = _make_sync_client(is_vm=True)
with pytest.raises(APIError) as exc_info:
client.create_ssh_session("sbx-vm")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
cast(Any, recording)._response = {
"session_id": "s",
"exposure_id": "s",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
"gateway_url": "",
"user_ns": "",
"job_id": "sbx-vm",
"token": "",
}
client.create_ssh_session("sbx-vm", public_key="ssh-ed25519 AAAA...")
assert any(
method == "POST" and path.endswith("/ssh-session") for method, path, _ in recording.calls
)


def test_sync_close_ssh_session_blocked_for_vm():
def test_sync_close_ssh_session_allowed_for_vm():
client, recording = _make_sync_client(is_vm=True)
with pytest.raises(APIError) as exc_info:
client.close_ssh_session("sbx-vm", "sess-1")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
client.close_ssh_session("sbx-vm", "sess-1")
assert any(
method == "DELETE" and path.endswith("/ssh-session/sess-1")
for method, path, _ in recording.calls
)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -327,25 +342,41 @@ async def test_async_list_exposed_ports_blocked_for_vm():


@pytest.mark.asyncio
async def test_async_create_ssh_session_blocked_for_vm():
async def test_async_create_ssh_session_allowed_for_vm():
client, recording = _make_async_client(is_vm=True)
cast(Any, recording)._response = {
"session_id": "s",
"exposure_id": "s",
"sandbox_id": "sbx-vm",
"host": "h",
"port": 2222,
"external_endpoint": "h:2222",
"expires_at": datetime.now(timezone.utc).isoformat(),
"ttl_seconds": 300,
"gateway_url": "",
"user_ns": "",
"job_id": "sbx-vm",
"token": "",
}
try:
with pytest.raises(APIError) as exc_info:
await client.create_ssh_session("sbx-vm")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
await client.create_ssh_session("sbx-vm", public_key="ssh-ed25519 AAAA...")
assert any(
method == "POST" and path.endswith("/ssh-session")
for method, path, _ in recording.calls
)
finally:
await client.aclose()


@pytest.mark.asyncio
async def test_async_close_ssh_session_blocked_for_vm():
async def test_async_close_ssh_session_allowed_for_vm():
client, recording = _make_async_client(is_vm=True)
try:
with pytest.raises(APIError) as exc_info:
await client.close_ssh_session("sbx-vm", "sess-1")
assert "SSH" in str(exc_info.value)
assert recording.calls == []
await client.close_ssh_session("sbx-vm", "sess-1")
assert any(
method == "DELETE" and path.endswith("/ssh-session/sess-1")
for method, path, _ in recording.calls
)
finally:
await client.aclose()

Expand Down
88 changes: 68 additions & 20 deletions packages/prime/src/prime_cli/commands/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
import string
import subprocess
import sys
import tempfile
import time
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -1367,7 +1368,7 @@ def cleanup() -> None:
with console.status("[bold blue]Checking sandbox status...", spinner="dots"):
sandbox = sandbox_client.get(sandbox_id)

_guard_vm_unsupported(sandbox, "SSH")
is_vm_sandbox = bool(getattr(sandbox, "vm", False))

if sandbox.status != "RUNNING":
console.print(f"[red]Error:[/red] Sandbox is not running (status: {sandbox.status})")
Expand All @@ -1387,29 +1388,33 @@ def cleanup() -> None:
with open(f"{key_path}.pub", "r") as f:
public_key = f.read().strip()

# Create SSH session
# VM sandboxes take the public key here; containers authorize it below.
console.print("[bold blue]Creating SSH session...[/bold blue]")
with console.status("[bold blue]Setting up SSH session...", spinner="dots"):
session = sandbox_client.create_ssh_session(sandbox_id)
if is_vm_sandbox:
session = sandbox_client.create_ssh_session(sandbox_id, public_key=public_key)
else:
session = sandbox_client.create_ssh_session(sandbox_id)
session_id = session.session_id

# Authorize the key
authorize_url = (
f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize"
)
headers = {"Authorization": f"Bearer {session.token}"}
payload = {
"session_id": session.session_id,
"public_key": public_key,
"ttl_seconds": session.ttl_seconds,
}
try:
with httpx.Client(timeout=30) as client:
client.post(authorize_url, json=payload, headers=headers).raise_for_status()
except Exception as e:
console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}")
cleanup()
raise typer.Exit(1)
if not is_vm_sandbox:
# Containers authorize the key through the sidecar endpoint.
authorize_url = (
f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize"
)
headers = {"Authorization": f"Bearer {session.token}"}
payload = {
"session_id": session.session_id,
"public_key": public_key,
"ttl_seconds": session.ttl_seconds,
}
try:
with httpx.Client(timeout=30) as client:
client.post(authorize_url, json=payload, headers=headers).raise_for_status()
except Exception as e:
console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}")
cleanup()
raise typer.Exit(1)

ssh_host = session.host
ssh_port = session.port
Expand All @@ -1434,6 +1439,49 @@ def cleanup() -> None:
ssh_cmd.extend(["-o", "UserKnownHostsFile=/dev/null"])
ssh_cmd.extend(["-o", "LogLevel=ERROR"])

# VM SSH connections need a session prefix before the handshake, so
# route them through a small ProxyCommand that writes it first.
if is_vm_sandbox:
prefix = f"PRIME-SSH-SESSION {session.session_id}\n"
python_exec = sys.executable or "python3"
# Use `read1` for responsive stdin forwarding and flush each
# chunk received from the remote side. Relay loops are plain
# while-loops (not list comprehensions) so forwarded bytes are
# not accumulated in memory for the lifetime of the session.
proxy_script = (
"import socket, sys, threading\n"
f"s = socket.create_connection(({ssh_host!r}, {int(ssh_port)}))\n"
f"s.sendall({prefix!r}.encode())\n"
"def _reader():\n"
" try:\n"
" while True:\n"
" b = s.recv(4096)\n"
" if not b:\n"
" break\n"
" sys.stdout.buffer.write(b)\n"
" sys.stdout.buffer.flush()\n"
" except OSError:\n"
" pass\n"
"t = threading.Thread(target=_reader, daemon=True)\n"
"t.start()\n"
"try:\n"
" while True:\n"
" b = sys.stdin.buffer.read1(4096)\n"
" if not b:\n"
" break\n"
" s.sendall(b)\n"
"except OSError:\n"
" pass\n"
"finally:\n"
" try:\n"
" s.shutdown(socket.SHUT_WR)\n"
" except OSError:\n"
" pass\n"
" s.close()\n"
)
proxy_cmd = f"{shlex.quote(python_exec)} -c {shlex.quote(proxy_script)}"
ssh_cmd.extend(["-o", f"ProxyCommand={proxy_cmd}"])

# Add identity file if specified
if key_path:
ssh_cmd.extend(["-i", key_path])
Expand Down
83 changes: 83 additions & 0 deletions packages/prime/tests/test_sandbox_cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import shlex
import subprocess
from pathlib import Path
from types import SimpleNamespace
from typing import Any

Expand Down Expand Up @@ -347,3 +350,83 @@ def mock_bulk_delete(self: Any, **kwargs: Any) -> Any:
assert bulk_kwargs["user_id"] is None

assert "Processed 1 sandbox(es)" in output


def test_sandbox_ssh_vm_proxy_command_quotes_python_and_streams_without_list_comprehensions(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
monkeypatch.setenv("PRIME_API_KEY", "dummy")
monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1")

temp_dir = tmp_path / "prime-ssh"
temp_dir.mkdir()
pub_key_path = temp_dir / "id_ed25519.pub"
pub_key_path.write_text("ssh-ed25519 AAAA-test-key user@test\n")

captured: dict[str, Any] = {"ssh_cmd": None, "public_key": None}

class FakeSandboxClient:
def __init__(self, _base_client: Any) -> None:
pass

def get(self, _sandbox_id: str) -> Any:
return SimpleNamespace(status="RUNNING", vm=True)

def create_ssh_session(
self, _sandbox_id: str, ttl_seconds: int | None = None, public_key: str | None = None
) -> Any:
captured["public_key"] = public_key
assert ttl_seconds is None
return SimpleNamespace(
session_id="sess-123",
host="vm-host.example",
port=2222,
gateway_url="https://gateway.example",
user_ns="ns",
job_id="job",
token="token",
ttl_seconds=300,
)

def close_ssh_session(self, _sandbox_id: str, _session_id: str) -> None:
return None

fake_python_exec = "/tmp/my python/bin/python3"

monkeypatch.setattr("prime_cli.commands.sandbox.APIClient", lambda: object())
monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient", FakeSandboxClient)
monkeypatch.setattr("prime_cli.commands.sandbox.tempfile.mkdtemp", lambda prefix: str(temp_dir))
monkeypatch.setattr("prime_cli.commands.sandbox.time.sleep", lambda _seconds: None)
monkeypatch.setattr("prime_cli.commands.sandbox.sys.executable", fake_python_exec)
monkeypatch.setattr("prime_cli.commands.sandbox.shutil.which", lambda _name: "/usr/bin/fake")

def _mock_subprocess_run(cmd: list[str], **kwargs: Any) -> subprocess.CompletedProcess[str]:
if cmd and cmd[0] == "ssh-keygen":
return subprocess.CompletedProcess(cmd, 0)
if cmd and cmd[0] == "ssh":
captured["ssh_cmd"] = cmd
return subprocess.CompletedProcess(cmd, 0)
raise AssertionError(f"Unexpected subprocess call: {cmd!r}, kwargs={kwargs!r}")

monkeypatch.setattr("prime_cli.commands.sandbox.subprocess.run", _mock_subprocess_run)

result = runner.invoke(app, ["sandbox", "ssh", "sbx-vm-1"])

output = strip_ansi(result.output)
assert result.exit_code == 0, f"Failed: {output}"
assert captured["public_key"] == "ssh-ed25519 AAAA-test-key user@test"

ssh_cmd = captured["ssh_cmd"]
assert isinstance(ssh_cmd, list)
proxy_opt = next((arg for arg in ssh_cmd if arg.startswith("ProxyCommand=")), None)
assert proxy_opt is not None

proxy_cmd = proxy_opt[len("ProxyCommand=") :]
assert proxy_cmd.startswith(f"{shlex.quote(fake_python_exec)} -c ")
parsed_proxy = shlex.split(proxy_cmd)
assert parsed_proxy[0] == fake_python_exec
assert parsed_proxy[1] == "-c"
proxy_script = parsed_proxy[2]

assert "while True:" in proxy_script
assert "for b in iter(" not in proxy_script
Loading