diff --git a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py index 2b2c6b76..317cacf2 100644 --- a/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py +++ b/packages/prime-sandboxes/src/prime_sandboxes/sandbox.py @@ -1311,12 +1311,18 @@ def create_ssh_session( self, sandbox_id: str, ttl_seconds: Optional[int] = None, + public_key: Optional[str] = None, ) -> SSHSession: - """Create an SSH session""" - self._guard_vm_unsupported(sandbox_id, "SSH") + """Create an SSH session. + + Pass ``public_key`` for VM sandboxes; container sandboxes authorize + keys separately after session creation. + """ payload: Dict[str, Any] = {} if ttl_seconds is not None: payload["ttl_seconds"] = ttl_seconds + if public_key is not None: + payload["public_key"] = public_key response = self.client.request( "POST", f"/sandbox/{sandbox_id}/ssh-session", @@ -1325,8 +1331,7 @@ def create_ssh_session( return SSHSession.model_validate(response) def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: - """Close an SSH session and remove its exposure""" - self._guard_vm_unsupported(sandbox_id, "SSH") + """Close an SSH session.""" self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") @@ -2233,12 +2238,18 @@ async def create_ssh_session( self, sandbox_id: str, ttl_seconds: Optional[int] = None, + public_key: Optional[str] = None, ) -> SSHSession: - """Create an SSH session""" - await self._guard_vm_unsupported(sandbox_id, "SSH") + """Create an SSH session. + + Pass ``public_key`` for VM sandboxes; container sandboxes authorize + keys separately after session creation. + """ payload: Dict[str, Any] = {} if ttl_seconds is not None: payload["ttl_seconds"] = ttl_seconds + if public_key is not None: + payload["public_key"] = public_key response = await self.client.request( "POST", f"/sandbox/{sandbox_id}/ssh-session", @@ -2247,8 +2258,7 @@ async def create_ssh_session( return SSHSession.model_validate(response) async def close_ssh_session(self, sandbox_id: str, session_id: str) -> None: - """Close an SSH session and remove its exposure""" - await self._guard_vm_unsupported(sandbox_id, "SSH") + """Close an SSH session.""" await self.client.request("DELETE", f"/sandbox/{sandbox_id}/ssh-session/{session_id}") diff --git a/packages/prime-sandboxes/tests/test_vm_guards.py b/packages/prime-sandboxes/tests/test_vm_guards.py index 8bb751e6..d183468b 100644 --- a/packages/prime-sandboxes/tests/test_vm_guards.py +++ b/packages/prime-sandboxes/tests/test_vm_guards.py @@ -119,20 +119,35 @@ def test_sync_list_exposed_ports_blocked_for_vm(): assert recording.calls == [] -def test_sync_create_ssh_session_blocked_for_vm(): +def test_sync_create_ssh_session_allowed_for_vm(): client, recording = _make_sync_client(is_vm=True) - with pytest.raises(APIError) as exc_info: - client.create_ssh_session("sbx-vm") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + cast(Any, recording)._response = { + "session_id": "s", + "exposure_id": "s", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + "gateway_url": "", + "user_ns": "", + "job_id": "sbx-vm", + "token": "", + } + client.create_ssh_session("sbx-vm", public_key="ssh-ed25519 AAAA...") + assert any( + method == "POST" and path.endswith("/ssh-session") for method, path, _ in recording.calls + ) -def test_sync_close_ssh_session_blocked_for_vm(): +def test_sync_close_ssh_session_allowed_for_vm(): client, recording = _make_sync_client(is_vm=True) - with pytest.raises(APIError) as exc_info: - client.close_ssh_session("sbx-vm", "sess-1") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + client.close_ssh_session("sbx-vm", "sess-1") + assert any( + method == "DELETE" and path.endswith("/ssh-session/sess-1") + for method, path, _ in recording.calls + ) # --------------------------------------------------------------------------- @@ -327,25 +342,41 @@ async def test_async_list_exposed_ports_blocked_for_vm(): @pytest.mark.asyncio -async def test_async_create_ssh_session_blocked_for_vm(): +async def test_async_create_ssh_session_allowed_for_vm(): client, recording = _make_async_client(is_vm=True) + cast(Any, recording)._response = { + "session_id": "s", + "exposure_id": "s", + "sandbox_id": "sbx-vm", + "host": "h", + "port": 2222, + "external_endpoint": "h:2222", + "expires_at": datetime.now(timezone.utc).isoformat(), + "ttl_seconds": 300, + "gateway_url": "", + "user_ns": "", + "job_id": "sbx-vm", + "token": "", + } try: - with pytest.raises(APIError) as exc_info: - await client.create_ssh_session("sbx-vm") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + await client.create_ssh_session("sbx-vm", public_key="ssh-ed25519 AAAA...") + assert any( + method == "POST" and path.endswith("/ssh-session") + for method, path, _ in recording.calls + ) finally: await client.aclose() @pytest.mark.asyncio -async def test_async_close_ssh_session_blocked_for_vm(): +async def test_async_close_ssh_session_allowed_for_vm(): client, recording = _make_async_client(is_vm=True) try: - with pytest.raises(APIError) as exc_info: - await client.close_ssh_session("sbx-vm", "sess-1") - assert "SSH" in str(exc_info.value) - assert recording.calls == [] + await client.close_ssh_session("sbx-vm", "sess-1") + assert any( + method == "DELETE" and path.endswith("/ssh-session/sess-1") + for method, path, _ in recording.calls + ) finally: await client.aclose() diff --git a/packages/prime/src/prime_cli/commands/sandbox.py b/packages/prime/src/prime_cli/commands/sandbox.py index b766119f..28b122da 100644 --- a/packages/prime/src/prime_cli/commands/sandbox.py +++ b/packages/prime/src/prime_cli/commands/sandbox.py @@ -5,6 +5,7 @@ import shutil import string import subprocess +import sys import tempfile import time from typing import Any, Dict, List, Optional @@ -1367,7 +1368,7 @@ def cleanup() -> None: with console.status("[bold blue]Checking sandbox status...", spinner="dots"): sandbox = sandbox_client.get(sandbox_id) - _guard_vm_unsupported(sandbox, "SSH") + is_vm_sandbox = bool(getattr(sandbox, "vm", False)) if sandbox.status != "RUNNING": console.print(f"[red]Error:[/red] Sandbox is not running (status: {sandbox.status})") @@ -1387,29 +1388,33 @@ def cleanup() -> None: with open(f"{key_path}.pub", "r") as f: public_key = f.read().strip() - # Create SSH session + # VM sandboxes take the public key here; containers authorize it below. console.print("[bold blue]Creating SSH session...[/bold blue]") with console.status("[bold blue]Setting up SSH session...", spinner="dots"): - session = sandbox_client.create_ssh_session(sandbox_id) + if is_vm_sandbox: + session = sandbox_client.create_ssh_session(sandbox_id, public_key=public_key) + else: + session = sandbox_client.create_ssh_session(sandbox_id) session_id = session.session_id - # Authorize the key - authorize_url = ( - f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize" - ) - headers = {"Authorization": f"Bearer {session.token}"} - payload = { - "session_id": session.session_id, - "public_key": public_key, - "ttl_seconds": session.ttl_seconds, - } - try: - with httpx.Client(timeout=30) as client: - client.post(authorize_url, json=payload, headers=headers).raise_for_status() - except Exception as e: - console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}") - cleanup() - raise typer.Exit(1) + if not is_vm_sandbox: + # Containers authorize the key through the sidecar endpoint. + authorize_url = ( + f"{session.gateway_url.rstrip('/')}/{session.user_ns}/{session.job_id}/authorize" + ) + headers = {"Authorization": f"Bearer {session.token}"} + payload = { + "session_id": session.session_id, + "public_key": public_key, + "ttl_seconds": session.ttl_seconds, + } + try: + with httpx.Client(timeout=30) as client: + client.post(authorize_url, json=payload, headers=headers).raise_for_status() + except Exception as e: + console.print(f"[red]Error:[/red] Failed to authorize SSH key: {e}") + cleanup() + raise typer.Exit(1) ssh_host = session.host ssh_port = session.port @@ -1434,6 +1439,49 @@ def cleanup() -> None: ssh_cmd.extend(["-o", "UserKnownHostsFile=/dev/null"]) ssh_cmd.extend(["-o", "LogLevel=ERROR"]) + # VM SSH connections need a session prefix before the handshake, so + # route them through a small ProxyCommand that writes it first. + if is_vm_sandbox: + prefix = f"PRIME-SSH-SESSION {session.session_id}\n" + python_exec = sys.executable or "python3" + # Use `read1` for responsive stdin forwarding and flush each + # chunk received from the remote side. Relay loops are plain + # while-loops (not list comprehensions) so forwarded bytes are + # not accumulated in memory for the lifetime of the session. + proxy_script = ( + "import socket, sys, threading\n" + f"s = socket.create_connection(({ssh_host!r}, {int(ssh_port)}))\n" + f"s.sendall({prefix!r}.encode())\n" + "def _reader():\n" + " try:\n" + " while True:\n" + " b = s.recv(4096)\n" + " if not b:\n" + " break\n" + " sys.stdout.buffer.write(b)\n" + " sys.stdout.buffer.flush()\n" + " except OSError:\n" + " pass\n" + "t = threading.Thread(target=_reader, daemon=True)\n" + "t.start()\n" + "try:\n" + " while True:\n" + " b = sys.stdin.buffer.read1(4096)\n" + " if not b:\n" + " break\n" + " s.sendall(b)\n" + "except OSError:\n" + " pass\n" + "finally:\n" + " try:\n" + " s.shutdown(socket.SHUT_WR)\n" + " except OSError:\n" + " pass\n" + " s.close()\n" + ) + proxy_cmd = f"{shlex.quote(python_exec)} -c {shlex.quote(proxy_script)}" + ssh_cmd.extend(["-o", f"ProxyCommand={proxy_cmd}"]) + # Add identity file if specified if key_path: ssh_cmd.extend(["-i", key_path]) diff --git a/packages/prime/tests/test_sandbox_cli.py b/packages/prime/tests/test_sandbox_cli.py index 11b94845..76215598 100644 --- a/packages/prime/tests/test_sandbox_cli.py +++ b/packages/prime/tests/test_sandbox_cli.py @@ -1,3 +1,6 @@ +import shlex +import subprocess +from pathlib import Path from types import SimpleNamespace from typing import Any @@ -347,3 +350,83 @@ def mock_bulk_delete(self: Any, **kwargs: Any) -> Any: assert bulk_kwargs["user_id"] is None assert "Processed 1 sandbox(es)" in output + + +def test_sandbox_ssh_vm_proxy_command_quotes_python_and_streams_without_list_comprehensions( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + monkeypatch.setenv("PRIME_API_KEY", "dummy") + monkeypatch.setenv("PRIME_DISABLE_VERSION_CHECK", "1") + + temp_dir = tmp_path / "prime-ssh" + temp_dir.mkdir() + pub_key_path = temp_dir / "id_ed25519.pub" + pub_key_path.write_text("ssh-ed25519 AAAA-test-key user@test\n") + + captured: dict[str, Any] = {"ssh_cmd": None, "public_key": None} + + class FakeSandboxClient: + def __init__(self, _base_client: Any) -> None: + pass + + def get(self, _sandbox_id: str) -> Any: + return SimpleNamespace(status="RUNNING", vm=True) + + def create_ssh_session( + self, _sandbox_id: str, ttl_seconds: int | None = None, public_key: str | None = None + ) -> Any: + captured["public_key"] = public_key + assert ttl_seconds is None + return SimpleNamespace( + session_id="sess-123", + host="vm-host.example", + port=2222, + gateway_url="https://gateway.example", + user_ns="ns", + job_id="job", + token="token", + ttl_seconds=300, + ) + + def close_ssh_session(self, _sandbox_id: str, _session_id: str) -> None: + return None + + fake_python_exec = "/tmp/my python/bin/python3" + + monkeypatch.setattr("prime_cli.commands.sandbox.APIClient", lambda: object()) + monkeypatch.setattr("prime_cli.commands.sandbox.SandboxClient", FakeSandboxClient) + monkeypatch.setattr("prime_cli.commands.sandbox.tempfile.mkdtemp", lambda prefix: str(temp_dir)) + monkeypatch.setattr("prime_cli.commands.sandbox.time.sleep", lambda _seconds: None) + monkeypatch.setattr("prime_cli.commands.sandbox.sys.executable", fake_python_exec) + monkeypatch.setattr("prime_cli.commands.sandbox.shutil.which", lambda _name: "/usr/bin/fake") + + def _mock_subprocess_run(cmd: list[str], **kwargs: Any) -> subprocess.CompletedProcess[str]: + if cmd and cmd[0] == "ssh-keygen": + return subprocess.CompletedProcess(cmd, 0) + if cmd and cmd[0] == "ssh": + captured["ssh_cmd"] = cmd + return subprocess.CompletedProcess(cmd, 0) + raise AssertionError(f"Unexpected subprocess call: {cmd!r}, kwargs={kwargs!r}") + + monkeypatch.setattr("prime_cli.commands.sandbox.subprocess.run", _mock_subprocess_run) + + result = runner.invoke(app, ["sandbox", "ssh", "sbx-vm-1"]) + + output = strip_ansi(result.output) + assert result.exit_code == 0, f"Failed: {output}" + assert captured["public_key"] == "ssh-ed25519 AAAA-test-key user@test" + + ssh_cmd = captured["ssh_cmd"] + assert isinstance(ssh_cmd, list) + proxy_opt = next((arg for arg in ssh_cmd if arg.startswith("ProxyCommand=")), None) + assert proxy_opt is not None + + proxy_cmd = proxy_opt[len("ProxyCommand=") :] + assert proxy_cmd.startswith(f"{shlex.quote(fake_python_exec)} -c ") + parsed_proxy = shlex.split(proxy_cmd) + assert parsed_proxy[0] == fake_python_exec + assert parsed_proxy[1] == "-c" + proxy_script = parsed_proxy[2] + + assert "while True:" in proxy_script + assert "for b in iter(" not in proxy_script