diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py index fbe972c2b..fd5d5506c 100644 --- a/python/docs/source/conf.py +++ b/python/docs/source/conf.py @@ -6,10 +6,10 @@ # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -import asyncio import os import sys +import anyio from jumpstarter_kubernetes.controller import get_latest_compatible_controller_version os.environ["TERM"] = "dumb" @@ -64,7 +64,9 @@ def get_controller_version(): else: version = None - return asyncio.run(get_latest_compatible_controller_version(client_version=version)) + async def _run(): + return await get_latest_compatible_controller_version(client_version=version) + return anyio.run(_run) def get_index_url(): diff --git a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/blocking.py b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/blocking.py index fbf02a19a..bc0cb5d1e 100644 --- a/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/blocking.py +++ b/python/packages/jumpstarter-cli-common/jumpstarter_cli_common/blocking.py @@ -1,10 +1,13 @@ -from asyncio import run from functools import wraps +import anyio + def blocking(f): @wraps(f) def wrapper(*args, **kwargs): - return run(f(*args, **kwargs)) + async def _run(): + return await f(*args, **kwargs) + return anyio.run(_run) return wrapper diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/login_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/login_test.py index 98295ce9f..87b61bfed 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/login_test.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/login_test.py @@ -1,8 +1,8 @@ -import asyncio import json import ssl from unittest.mock import AsyncMock, MagicMock, patch +import anyio import click import pytest from click.testing import CliRunner @@ -84,7 +84,9 @@ def get(self, *args, **kwargs): monkeypatch.setattr("jumpstarter_cli.login.aiohttp.ClientSession", FakeClientSession) with pytest.raises(click.ClickException, match="Timed out while connecting"): - asyncio.run(fetch_auth_config("login.example.com")) + async def _run(): + return await fetch_auth_config("login.example.com") + anyio.run(_run) def test_fetch_auth_config_maps_json_decode_error(monkeypatch) -> None: @@ -116,7 +118,9 @@ def get(self, *args, **kwargs): monkeypatch.setattr("jumpstarter_cli.login.aiohttp.ClientSession", FakeClientSession) with pytest.raises(click.ClickException, match="Invalid JSON response received"): - asyncio.run(fetch_auth_config("login.example.com")) + async def _run(): + return await fetch_auth_config("login.example.com") + anyio.run(_run) def test_login_cli_shows_timeout_message(monkeypatch) -> None: @@ -151,13 +155,13 @@ async def fake_fetch_auth_config(*args, **kwargs): assert "TLS certificate verification failed" in result.output -@pytest.mark.asyncio +@pytest.mark.anyio async def test_fetch_auth_config_rejects_http_without_insecure_tls(): with pytest.raises(click.UsageError, match="--insecure-tls"): await fetch_auth_config("http://login.example.com", insecure_tls=False) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_fetch_auth_config_allows_explicit_http_with_insecure_tls(): mock_response = MagicMock() mock_response.status = 200 @@ -183,7 +187,7 @@ async def test_fetch_auth_config_allows_explicit_http_with_insecure_tls(): assert result["grpcEndpoint"] == "grpc.example.com" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_fetch_auth_config_defaults_to_https_with_insecure_tls(): mock_response = MagicMock() mock_response.status = 200 diff --git a/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/driver.py b/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/driver.py index 4eefff6c4..c3339913d 100644 --- a/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/driver.py +++ b/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/driver.py @@ -1,5 +1,3 @@ -import asyncio -import asyncio.subprocess import ipaddress import shutil import socket @@ -8,8 +6,14 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass, field from pathlib import Path +from subprocess import PIPE from typing import Literal, TypedDict +import anyio +from anyio import IncompleteRead +from anyio.abc import Process +from anyio.streams.buffered import BufferedByteReceiveStream + from . import dnsmasq, iproute, nftables from .ntp_server import NtpServer from jumpstarter.driver import Driver, export @@ -67,7 +71,7 @@ class DutNetwork(Driver): _added_aliases: set[str] = field(init=False, default_factory=set) _fwd_rule_handles: list[int] = field(init=False, default_factory=list) _ntp_server: NtpServer | None = field(init=False, default=None) - _tcpdump_process: asyncio.subprocess.Process | None = field(init=False, default=None) + _tcpdump_process: Process | None = field(init=False, default=None) @classmethod def client(cls) -> str: @@ -467,21 +471,18 @@ async def tcpdump(self, args: list[str] | None = None) -> AsyncGenerator[str, No self.logger.info("Starting tcpdump: %s", " ".join(cmd)) - proc = await asyncio.subprocess.create_subprocess_exec( - cmd[0], - *cmd[1:], - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.STDOUT, - ) + proc = await anyio.open_process(cmd, stdout=PIPE, stderr=PIPE) self._tcpdump_process = proc try: assert proc.stdout is not None + buffered = BufferedByteReceiveStream(proc.stdout) while True: - line = await proc.stdout.readline() - if not line: + try: + line = await buffered.receive_until(b"\n", 1048576) + except (anyio.EndOfStream, anyio.ClosedResourceError, IncompleteRead): break - yield line.decode("utf-8", errors="replace").rstrip("\n") + yield line.decode("utf-8", errors="replace") finally: if proc.returncode is None: try: diff --git a/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/test_tcpdump.py b/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/test_tcpdump.py index ba4909ce2..d0d632a74 100644 --- a/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/test_tcpdump.py +++ b/python/packages/jumpstarter-driver-dut-network/jumpstarter_driver_dut_network/test_tcpdump.py @@ -4,10 +4,11 @@ and the streaming driver method using mocked subprocesses. """ -import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch +import anyio +import anyio.abc import pytest from .driver import DutNetwork @@ -161,44 +162,50 @@ def test_multiple_blocked_flags(self): assert DutNetwork._sanitize_tcpdump_args(args) == ["-c", "5"] +def _create_mock_byte_stream(data: bytes): + """Create a mock ByteReceiveStream that returns data then raises EndOfStream.""" + stream = MagicMock(spec=anyio.abc.ByteReceiveStream) + call_count = {"n": 0} + + async def mock_receive(max_bytes=65536): + if call_count["n"] >= 1: + raise anyio.EndOfStream() + call_count["n"] += 1 + return data + + stream.receive = mock_receive + stream.aclose = AsyncMock() + return stream + + class TestTcpdumpMethod: def test_tcpdump_raises_when_disabled(self, tmp_path: Path): driver = _make_driver(tmp_path, enable_tcpdump=False) with pytest.raises(RuntimeError, match="tcpdump is not enabled"): - asyncio.run( - _consume_async_gen(driver.tcpdump()) + anyio.run( + _consume_async_gen, driver.tcpdump() ) def test_tcpdump_streams_output(self, tmp_path: Path): driver = _make_driver(tmp_path, enable_tcpdump=True) - mock_stdout = AsyncMock() - lines = [ - b"12:00:00.000000 IP 192.168.100.10 > 8.8.8.8: ICMP echo request\n", - b"12:00:00.001000 IP 8.8.8.8 > 192.168.100.10: ICMP echo reply\n", - b"", # EOF - ] - state = {"call_count": 0} - - async def mock_readline(): - if state["call_count"] < len(lines): - result = lines[state["call_count"]] - state["call_count"] += 1 - return result - return b"" - - mock_stdout.readline = mock_readline + data = ( + b"12:00:00.000000 IP 192.168.100.10 > 8.8.8.8: ICMP echo request\n" + b"12:00:00.001000 IP 8.8.8.8 > 192.168.100.10: ICMP echo reply\n" + ) + mock_stdout = _create_mock_byte_stream(data) mock_proc = AsyncMock() mock_proc.stdout = mock_stdout + mock_proc.stderr = _create_mock_byte_stream(b"") mock_proc.returncode = None mock_proc.terminate = MagicMock() mock_proc.wait = AsyncMock() - with patch(f"{_DRIVER_MODULE}.asyncio.subprocess.create_subprocess_exec", - return_value=mock_proc): - output = asyncio.run( - _consume_async_gen(driver.tcpdump()) + with patch(f"{_DRIVER_MODULE}.anyio.open_process", + new_callable=AsyncMock, return_value=mock_proc): + output = anyio.run( + _consume_async_gen, driver.tcpdump() ) assert len(output) == 2 @@ -208,82 +215,55 @@ async def mock_readline(): def test_tcpdump_enforces_interface(self, tmp_path: Path): driver = _make_driver(tmp_path, enable_tcpdump=True) - mock_stdout = AsyncMock() - mock_stdout.readline = AsyncMock(return_value=b"") + mock_stdout = _create_mock_byte_stream(b"") mock_proc = AsyncMock() mock_proc.stdout = mock_stdout + mock_proc.stderr = _create_mock_byte_stream(b"") mock_proc.returncode = 0 mock_proc.terminate = MagicMock() mock_proc.wait = AsyncMock() - with patch(f"{_DRIVER_MODULE}.asyncio.subprocess.create_subprocess_exec", - return_value=mock_proc) as mock_exec: - asyncio.run( - _consume_async_gen(driver.tcpdump(args=["-i", "evil-iface", "-c", "1"])) + with patch(f"{_DRIVER_MODULE}.anyio.open_process", + new_callable=AsyncMock, return_value=mock_proc) as mock_exec: + anyio.run( + _consume_async_gen, driver.tcpdump(args=["-i", "evil-iface", "-c", "1"]) ) - # Verify the command was called with the correct interface call_args = mock_exec.call_args[0] - cmd = list(call_args) + cmd = list(call_args[0]) assert cmd[0] == "tcpdump" assert "-i" in cmd iface_idx = cmd.index("-i") assert cmd[iface_idx + 1] == "eth-dut" - # The user-specified -i should have been removed by sanitization assert cmd.count("-i") == 1 def test_tcpdump_passes_extra_args(self, tmp_path: Path): driver = _make_driver(tmp_path, enable_tcpdump=True) - mock_stdout = AsyncMock() - mock_stdout.readline = AsyncMock(return_value=b"") + mock_stdout = _create_mock_byte_stream(b"") mock_proc = AsyncMock() mock_proc.stdout = mock_stdout + mock_proc.stderr = _create_mock_byte_stream(b"") mock_proc.returncode = 0 mock_proc.terminate = MagicMock() mock_proc.wait = AsyncMock() - with patch(f"{_DRIVER_MODULE}.asyncio.subprocess.create_subprocess_exec", - return_value=mock_proc) as mock_exec: - asyncio.run( - _consume_async_gen(driver.tcpdump(args=["-c", "10", "-n", "port", "80"])) + with patch(f"{_DRIVER_MODULE}.anyio.open_process", + new_callable=AsyncMock, return_value=mock_proc) as mock_exec: + anyio.run( + _consume_async_gen, driver.tcpdump(args=["-c", "10", "-n", "port", "80"]) ) call_args = mock_exec.call_args[0] - cmd = list(call_args) + cmd = list(call_args[0]) assert "-c" in cmd assert "10" in cmd assert "-n" in cmd assert "port" in cmd assert "80" in cmd - def test_tcpdump_cleanup_on_cancel(self, tmp_path: Path): - driver = _make_driver(tmp_path, enable_tcpdump=True) - - mock_stdout = AsyncMock() - # Simulate a stream that never ends - mock_stdout.readline = AsyncMock( - side_effect=[b"line 1\n", b"line 2\n", asyncio.CancelledError()] - ) - - mock_proc = AsyncMock() - mock_proc.stdout = mock_stdout - mock_proc.returncode = None - mock_proc.terminate = MagicMock() - mock_proc.wait = AsyncMock() - - with patch(f"{_DRIVER_MODULE}.asyncio.subprocess.create_subprocess_exec", - return_value=mock_proc): - with pytest.raises(asyncio.CancelledError): - asyncio.run( - _consume_async_gen(driver.tcpdump()) - ) - - # Verify the process was terminated - mock_proc.terminate.assert_called_once() - class TestTcpdumpCleanup: def test_cleanup_terminates_tcpdump_process(self, tmp_path: Path): diff --git a/python/packages/jumpstarter-driver-mitmproxy/examples/addons/data_stream_websocket.py b/python/packages/jumpstarter-driver-mitmproxy/examples/addons/data_stream_websocket.py index 289007ce0..0e385034a 100644 --- a/python/packages/jumpstarter-driver-mitmproxy/examples/addons/data_stream_websocket.py +++ b/python/packages/jumpstarter-driver-mitmproxy/examples/addons/data_stream_websocket.py @@ -44,12 +44,14 @@ from __future__ import annotations -import asyncio import json import math import random import time +import anyio +import anyio.abc +from anyio import CancelScope from mitmproxy import ctx, http @@ -76,7 +78,29 @@ class Handler: """ def __init__(self): - self._tasks: dict[int, asyncio.Task] = {} + self._cancel_scopes: dict[int, CancelScope] = {} + self._task_group: anyio.abc.TaskGroup | None = None + + async def _ensure_task_group(self): + # The mitmproxy addon lifecycle (websocket_message/done) does not + # support wrapping the entire handler in an async with block, so + # we manage __aenter__/__aexit__ manually here. This is a known + # deviation from anyio structured concurrency conventions. + if self._task_group is None: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + + async def done(self): + if self._task_group is not None: + for scope in self._cancel_scopes.values(): + scope.cancel() + self._cancel_scopes.clear() + try: + await self._task_group.__aexit__(None, None, None) + except BaseExceptionGroup as eg: + for exc in eg.exceptions: + ctx.log.error(f"Task group exception during shutdown: {exc}") + self._task_group = None def handle(self, flow: http.HTTPFlow, config: dict) -> bool: """Handle the initial WebSocket upgrade request. @@ -86,14 +110,12 @@ def handle(self, flow: http.HTTPFlow, config: dict) -> bool: injector take over. Returns True to indicate the request was handled (but we - don't set flow.response — we let the WebSocket handshake + don't set flow.response -- we let the WebSocket handshake complete naturally by NOT intercepting it here). """ - # Don't block the handshake — return False to let it through - # to the server (or get intercepted later by websocket hooks) return False - def websocket_message(self, flow: http.HTTPFlow, config: dict): + async def websocket_message(self, flow: http.HTTPFlow, config: dict): """Handle WebSocket messages and start telemetry injection. On the first client message (typically a subscribe/init @@ -104,11 +126,9 @@ def websocket_message(self, flow: http.HTTPFlow, config: dict): last_msg = flow.websocket.messages[-1] - # Only react to client messages if not last_msg.from_client: return - # Parse client command try: cmd = json.loads(last_msg.text) if last_msg.is_text else {} except (json.JSONDecodeError, UnicodeDecodeError): @@ -119,8 +139,7 @@ def websocket_message(self, flow: http.HTTPFlow, config: dict): msg_type = cmd.get("type", cmd.get("action", "subscribe")) if msg_type in ("subscribe", "start", "init"): - # Start pushing telemetry if not already running - if flow_id not in self._tasks or self._tasks[flow_id].done(): + if flow_id not in self._cancel_scopes or self._cancel_scopes[flow_id].cancel_called: scenario_name = cmd.get( "scenario", config.get("scenario", "normal"), @@ -131,14 +150,15 @@ def websocket_message(self, flow: http.HTTPFlow, config: dict): "normal", DEFAULT_SCENARIOS["normal"], )) - task = asyncio.ensure_future( - self._push_telemetry( - flow, scenario, interval_ms / 1000.0, - ) + scope = CancelScope() + self._cancel_scopes[flow_id] = scope + + await self._ensure_task_group() + self._task_group.start_soon( + self._push_telemetry_with_scope, + flow, scenario, interval_ms / 1000.0, scope, ) - self._tasks[flow_id] = task - # Send acknowledgment ack = json.dumps({ "type": "subscribed", "scenario": scenario_name, @@ -153,55 +173,57 @@ def websocket_message(self, flow: http.HTTPFlow, config: dict): ) elif msg_type in ("unsubscribe", "stop"): - if flow_id in self._tasks: - self._tasks[flow_id].cancel() - del self._tasks[flow_id] + if flow_id in self._cancel_scopes: + self._cancel_scopes[flow_id].cancel() + del self._cancel_scopes[flow_id] ctx.log.info("WS telemetry stopped") elif msg_type == "set_scenario": - # Switch scenario mid-stream new_scenario = cmd.get("scenario", "normal") - if flow_id in self._tasks: - self._tasks[flow_id].cancel() - del self._tasks[flow_id] + if flow_id in self._cancel_scopes: + self._cancel_scopes[flow_id].cancel() + del self._cancel_scopes[flow_id] scenarios = config.get("scenarios", DEFAULT_SCENARIOS) scenario = scenarios.get(new_scenario, DEFAULT_SCENARIOS.get( new_scenario, DEFAULT_SCENARIOS["normal"], )) interval_ms = config.get("push_interval_ms", 100) - task = asyncio.ensure_future( - self._push_telemetry( - flow, scenario, interval_ms / 1000.0, - ) + + scope = CancelScope() + self._cancel_scopes[flow_id] = scope + + await self._ensure_task_group() + self._task_group.start_soon( + self._push_telemetry_with_scope, + flow, scenario, interval_ms / 1000.0, scope, ) - self._tasks[flow_id] = task ctx.log.info(f"WS telemetry scenario changed: {new_scenario}") - async def _push_telemetry( + async def _push_telemetry_with_scope( self, flow: http.HTTPFlow, scenario: dict, interval_s: float, + scope: CancelScope, ): """Async loop that pushes telemetry frames to the client.""" state = SensorState(scenario) try: - while ( - flow.websocket is not None - and flow.websocket.timestamp_end is None - ): - frame = state.next_frame() - payload = json.dumps(frame).encode() - - ctx.master.commands.call( - "inject.websocket", flow, True, payload, - ) + with scope: + while ( + flow.websocket is not None + and flow.websocket.timestamp_end is None + ): + frame = state.next_frame() + payload = json.dumps(frame).encode() + + ctx.master.commands.call( + "inject.websocket", flow, True, payload, + ) - await asyncio.sleep(interval_s) + await anyio.sleep(interval_s) - except asyncio.CancelledError: - ctx.log.debug("Telemetry push task cancelled") except Exception as e: ctx.log.error(f"Telemetry push error: {e}") @@ -222,7 +244,6 @@ def __init__(self, scenario: dict): self.t0 = time.time() self.frame_num = 0 - # Initial state value_range = scenario.get("value_range", [30, 70]) self.value = (value_range[0] + value_range[1]) / 2 self.rate = scenario.get("rate_range", [100, 500])[0] @@ -235,19 +256,17 @@ def __init__(self, scenario: dict): def next_frame(self) -> dict: """Generate the next telemetry frame.""" - dt = 0.1 # ~100ms per frame + dt = 0.1 elapsed = time.time() - self.t0 self.frame_num += 1 - # Value: sinusoidal oscillation within range value_range = self.scenario.get("value_range", [30, 70]) value_mid = (value_range[0] + value_range[1]) / 2 value_amp = (value_range[1] - value_range[0]) / 2 self.value = value_mid + value_amp * math.sin(elapsed * 0.3) - self.value += random.gauss(0, 0.5) # Jitter + self.value += random.gauss(0, 0.5) self.value = max(value_range[0], min(value_range[1], self.value)) - # Rate: correlates with value, with random variation rate_range = self.scenario.get("rate_range", [100, 500]) if value_range[1] > value_range[0]: rate_ratio = (self.value - value_range[0]) / ( @@ -259,19 +278,15 @@ def next_frame(self) -> dict: self.rate += random.gauss(0, 5) self.rate = max(0, self.rate) - # Battery: drain (or recover) over time drain_rate = self.scenario.get("drain_pct_per_s", 0.015) self.battery_pct -= drain_rate * dt self.battery_pct = max(0, min(100, self.battery_pct)) - # Counter: accumulate based on value self.counter += self.value * dt - # Temperature: exponential rise toward steady state target_temp = 45.0 if self.value > 0 else 25.0 self.temperature += (target_temp - self.temperature) * 0.01 - # GPS: drift along heading speed_ms = self.value * 0.1 self.gps_lat += ( math.cos(math.radians(self.heading)) @@ -282,11 +297,9 @@ def next_frame(self) -> dict: * speed_ms * dt / max(111320 * math.cos(math.radians(self.gps_lat)), 1) ) - # Gentle heading wander self.heading += random.gauss(0, 0.2) self.heading %= 360 - # State selection based on value if self.value < 1: state = "idle" elif self.value < 30: @@ -296,7 +309,6 @@ def next_frame(self) -> dict: else: state = "high" - # Voltage: correlates with battery voltage = 3.0 + (self.battery_pct / 100) * 1.2 return { @@ -319,7 +331,6 @@ def next_frame(self) -> dict: } -# Default scenario definitions (used if not in config) DEFAULT_SCENARIOS = { "idle": { "value_range": [0, 0], diff --git a/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/bundled_addon.py b/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/bundled_addon.py index c958c24ad..edf54a7cd 100644 --- a/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/bundled_addon.py +++ b/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/bundled_addon.py @@ -22,7 +22,6 @@ from __future__ import annotations -import asyncio import hashlib import importlib import importlib.util @@ -38,6 +37,7 @@ from typing import Any from urllib.parse import parse_qs, urlparse +import anyio from mitmproxy import ctx, http # ── Helpers ────────────────────────────────────────────────── @@ -831,7 +831,7 @@ async def _send_response(self, flow: http.HTTPFlow, endpoint: dict): self.config.get("default_latency_ms", 0), ) if latency_ms > 0: - await asyncio.sleep(latency_ms / 1000.0) + await anyio.sleep(latency_ms / 1000.0) # Build response headers resp_headers = {"Content-Type": content_type} diff --git a/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/driver.py b/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/driver.py index ad6f59075..cda4ca74d 100644 --- a/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/driver.py +++ b/python/packages/jumpstarter-driver-mitmproxy/jumpstarter_driver_mitmproxy/driver.py @@ -28,7 +28,6 @@ from __future__ import annotations -import asyncio import base64 import fnmatch import json @@ -47,6 +46,7 @@ from pathlib import Path from urllib.parse import urlparse +import anyio import yaml from pydantic import BaseModel, model_validator @@ -1480,7 +1480,7 @@ async def watch_captured_requests(self) -> AsyncGenerator[str, None]: yield json.dumps(req) while True: - await asyncio.sleep(0.3) + await anyio.sleep(0.3) with self._capture_lock: new_count = len(self._captured_requests) if new_count > last_index: diff --git a/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/conftest.py b/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/conftest.py index f8410c472..711840aa5 100644 --- a/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/conftest.py +++ b/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/conftest.py @@ -13,6 +13,11 @@ async def echo_handler(stream): pass +@pytest.fixture +def anyio_backend(): + return "asyncio" + + @pytest.fixture def tcp_echo_server(): with start_blocking_portal() as portal: diff --git a/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py b/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py index d80ed25d0..8a5dc47c3 100644 --- a/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py +++ b/python/packages/jumpstarter-driver-network/jumpstarter_driver_network/driver_test.py @@ -169,7 +169,7 @@ def test_dbus_network_session(monkeypatch): assert oldvar == os.getenv("DBUS_SESSION_BUS_ADDRESS") -@pytest.mark.asyncio +@pytest.mark.anyio async def test_websocket_network_connect(): ws = AsyncMock() ws.__aenter__.return_value = ws diff --git a/python/packages/jumpstarter-driver-pi-pico/jumpstarter_driver_pi_pico/driver_test.py b/python/packages/jumpstarter-driver-pi-pico/jumpstarter_driver_pi_pico/driver_test.py index 20a03e452..5fc43d61e 100644 --- a/python/packages/jumpstarter-driver-pi-pico/jumpstarter_driver_pi_pico/driver_test.py +++ b/python/packages/jumpstarter-driver-pi-pico/jumpstarter_driver_pi_pico/driver_test.py @@ -1,7 +1,7 @@ -import asyncio from dataclasses import dataclass, field from unittest.mock import MagicMock +import anyio import pytest from jumpstarter_driver_pyserial.driver import PySerial @@ -167,7 +167,9 @@ def test_drivers_pi_pico_dump_not_implemented(monkeypatch, tmp_path): ) with pytest.raises(NotImplementedError, match="not supported"): - asyncio.run(driver.dump(None, None)) + async def _run(): + await driver.dump(None, None) + anyio.run(_run) def test_drivers_pi_pico_enter_bootloader_via_gpio(monkeypatch, tmp_path): diff --git a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver_test.py b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver_test.py index f373fa117..afd9870de 100644 --- a/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver_test.py +++ b/python/packages/jumpstarter-driver-pyserial/jumpstarter_driver_pyserial/driver_test.py @@ -220,7 +220,14 @@ def test_close_noop_when_no_stream(): def test_close_closes_transport(monkeypatch): - """close() should close the underlying transport.""" + """close() should close the underlying transport. + + These tests intentionally use asyncio.StreamReader/StreamWriter because the + serial_asyncio library is built on asyncio transports and protocols. This is + a known exception to the anyio migration -- serial_asyncio has no anyio + equivalent, so the bridge layer between serial I/O and anyio streams relies + on asyncio internals. + """ import asyncio from unittest.mock import MagicMock diff --git a/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver.py b/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver.py index d22753d28..832cadccb 100644 --- a/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver.py +++ b/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import json import logging import os @@ -15,9 +14,21 @@ from tempfile import TemporaryDirectory from typing import Literal +import anyio import yaml -from anyio import fail_after, run_process, sleep +from anyio import ( + IncompleteRead, + create_memory_object_stream, + create_task_group, + fail_after, + move_on_after, + run_process, + sleep, +) +from anyio.abc import ByteReceiveStream +from anyio.streams.buffered import BufferedByteReceiveStream from anyio.streams.file import FileReadStream, FileWriteStream +from anyio.streams.memory import MemoryObjectSendStream from jumpstarter_driver_network.driver import TcpNetwork, UnixNetwork, VsockNetwork from jumpstarter_driver_opendal.driver import FlasherInterface from jumpstarter_driver_power.driver import PowerInterface, PowerReading @@ -46,13 +57,23 @@ def filter(self, record): return False -async def _read_pipe(stream: asyncio.StreamReader, name: str, queue: asyncio.Queue): - while True: - line = await stream.readline() - if not line: - break - await queue.put((name, line.decode("utf-8", errors="replace"))) - await queue.put((name, None)) +async def _read_pipe( + stream: ByteReceiveStream, + name: str, + send_stream: MemoryObjectSendStream[tuple[str, str | None]], +) -> None: + buffered = BufferedByteReceiveStream(stream) + try: + while True: + line = await buffered.receive_until(b"\n", 1048576) + await send_stream.send((name, (line + b"\n").decode("utf-8", errors="replace"))) + except IncompleteRead: + remaining = buffered.buffer + if remaining: + await send_stream.send((name, remaining.decode("utf-8", errors="replace"))) + except (anyio.EndOfStream, anyio.ClosedResourceError): + pass + await send_stream.send((name, None)) @dataclass(kw_only=True) @@ -78,7 +99,6 @@ async def flash(self, source, partition: str | None = None): async with await FileWriteStream.from_path(self.parent.validate_partition(partition)) as stream: async with self.resource(source) as res: - # Wrap with auto-decompression to handle .gz, .xz, .bz2, .zstd files async for chunk in AutoDecompressIterator(source=res): await stream.send(chunk) @@ -105,12 +125,10 @@ async def flash_oci( if not oci_url.startswith("oci://"): raise ValueError(f"OCI URL must start with oci://, got: {oci_url}") - # If explicit credentials were provided, validate immediately if oci_username or oci_password: if bool(oci_username) != bool(oci_password): raise ValueError("OCI authentication requires both username and password") else: - # Fall back to env vars, then container auth files from jumpstarter.common.oci import resolve_oci_credentials oci_username, oci_password = resolve_oci_credentials(oci_url) @@ -147,61 +165,65 @@ async def _stream_subprocess( self, cmd: list[str], env: dict[str, str] | None ) -> AsyncGenerator[tuple[str, str, int | None], None]: """Run a subprocess and yield (stdout, stderr, returncode) tuples as output arrives.""" - process = await asyncio.create_subprocess_exec( # ty: ignore[missing-argument] - *cmd, - stdout=asyncio.subprocess.PIPE, # ty: ignore[unresolved-attribute] - stderr=asyncio.subprocess.PIPE, # ty: ignore[unresolved-attribute] - env=env, - ) + process = await anyio.open_process(cmd, stdout=PIPE, stderr=PIPE, env=env) - output_queue: asyncio.Queue[tuple[str, str | None]] = asyncio.Queue() + send_stream, receive_stream = create_memory_object_stream[tuple[str, str | None]](32) # ty: ignore[call-non-callable] + deferred_error: RuntimeError | None = None - tasks = [ - asyncio.create_task(_read_pipe(process.stdout, "stdout", output_queue)), - asyncio.create_task(_read_pipe(process.stderr, "stderr", output_queue)), - ] + async with send_stream, receive_stream: + async with create_task_group() as tg: + tg.start_soon(_read_pipe, process.stdout, "stdout", send_stream.clone()) + tg.start_soon(_read_pipe, process.stderr, "stderr", send_stream.clone()) - finished_streams = 0 - start_time = asyncio.get_running_loop().time() + finished_streams = 0 + start_time = anyio.current_time() - try: - while finished_streams < 2: - elapsed = asyncio.get_running_loop().time() - start_time - if elapsed >= self.parent.flash_timeout: - process.kill() - await process.wait() - raise RuntimeError(f"fls flash timed out after {self.parent.flash_timeout}s") - - remaining = self.parent.flash_timeout - elapsed try: - name, text = await asyncio.wait_for(output_queue.get(), timeout=min(remaining, 30)) - except asyncio.TimeoutError: - continue - - if text is None: - finished_streams += 1 - continue - - stdout_chunk = text if name == "stdout" else "" - stderr_chunk = text if name == "stderr" else "" - yield stdout_chunk, stderr_chunk, None - - await process.wait() - returncode = process.returncode - - if returncode != 0: - self.logger.error(f"fls failed - return code: {returncode}") - raise RuntimeError(f"fls flash failed (return code {returncode})") - - self.logger.info("OCI flash completed successfully") - yield "", "", returncode - finally: - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - if process.returncode is None: - process.kill() - await process.wait() + while finished_streams < 2: + elapsed = anyio.current_time() - start_time + if elapsed >= self.parent.flash_timeout: + process.kill() + await process.wait() + deferred_error = RuntimeError( + f"fls flash timed out after {self.parent.flash_timeout}s" + ) + break + + remaining = self.parent.flash_timeout - elapsed + with move_on_after(min(remaining, 30)) as scope: + name, text = await receive_stream.receive() + + if scope.cancelled_caught: + continue + + if text is None: + finished_streams += 1 + continue + + stdout_chunk = text if name == "stdout" else "" + stderr_chunk = text if name == "stderr" else "" + yield stdout_chunk, stderr_chunk, None + + if deferred_error is None: + await process.wait() + returncode = process.returncode + + if returncode != 0: + self.logger.error(f"fls failed - return code: {returncode}") + deferred_error = RuntimeError( + f"fls flash failed (return code {returncode})" + ) + else: + self.logger.info("OCI flash completed successfully") + yield "", "", returncode + finally: + tg.cancel_scope.cancel() + if process.returncode is None: + process.kill() + await process.wait() + + if deferred_error is not None: + raise deferred_error @export async def dump(self, target, partition: str | None = None): @@ -329,7 +351,6 @@ async def on(self) -> None: # noqa: C901 image_driver = "raw" current_virtual_size = root.stat().st_size - # Resize disk if configured if self.parent.disk_size: requested = self.parent._parse_size(self.parent.disk_size) @@ -432,7 +453,7 @@ class Qemu(Driver): smp: int = 2 mem: str = "512M" - disk_size: str | None = None # e.g., "20G" (resize disk before boot) + disk_size: str | None = None hostname: str = "demo" username: str = "jumpstarter" @@ -442,11 +463,10 @@ class Qemu(Driver): hostfwd: dict[str, Hostfwd] = field(default_factory=dict) - # FLS configuration for OCI flashing fls_version: str | None = field(default=None) fls_allow_custom_binaries: bool = field(default=False) fls_custom_binary_url: str | None = field(default=None) - flash_timeout: int = field(default=30 * 60) # 30 minutes + flash_timeout: int = field(default=30 * 60) _tmp_dir: TemporaryDirectory = field(init=False, default_factory=TemporaryDirectory) @@ -569,12 +589,12 @@ def _parse_size(self, size: str) -> int: @validate_call(validate_return=True) def set_disk_size(self, size: str) -> None: """Set the disk size for resizing before boot.""" - self._parse_size(size) # Validate + self._parse_size(size) self.disk_size = size @export @validate_call(validate_return=True) def set_memory_size(self, size: str) -> None: """Set the memory size for next boot.""" - self._parse_size(size) # Validate + self._parse_size(size) self.mem = size diff --git a/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver_test.py b/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver_test.py index 129e144ee..91123e81f 100644 --- a/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver_test.py +++ b/python/packages/jumpstarter-driver-qemu/jumpstarter_driver_qemu/driver_test.py @@ -1,4 +1,3 @@ -import asyncio import json import os import platform @@ -8,6 +7,8 @@ from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch +import anyio +import anyio.abc import pytest import requests from opendal import Operator @@ -209,8 +210,27 @@ def test_set_memory_size_invalid(): # OCI Flash Tests +def _create_mock_stream(lines): + """Create a mock byte receive stream that yields lines then raises EndOfStream.""" + encoded = [line.encode() if isinstance(line, str) else line for line in lines] + data = b"".join(encoded) + + stream = MagicMock(spec=anyio.abc.ByteReceiveStream) + call_count = {"n": 0} + + async def mock_receive(max_bytes=65536): + if call_count["n"] >= 1: + raise anyio.EndOfStream() + call_count["n"] += 1 + return data + + stream.receive = mock_receive + stream.aclose = AsyncMock() + return stream + + def _create_mock_process(stdout_lines=None, stderr_lines=None, returncode=0): - """Create a mock asyncio subprocess process for testing flash_oci.""" + """Create a mock subprocess process for testing flash_oci.""" if stdout_lines is None: stdout_lines = [] if stderr_lines is None: @@ -221,15 +241,8 @@ def _create_mock_process(stdout_lines=None, stderr_lines=None, returncode=0): process.wait = AsyncMock(return_value=returncode) process.kill = MagicMock() - stdout_data = [line.encode() if isinstance(line, str) else line for line in stdout_lines] + [b""] - stdout_stream = MagicMock() - stdout_stream.readline = AsyncMock(side_effect=stdout_data) - process.stdout = stdout_stream - - stderr_data = [line.encode() if isinstance(line, str) else line for line in stderr_lines] + [b""] - stderr_stream = MagicMock() - stderr_stream.readline = AsyncMock(side_effect=stderr_data) - process.stderr = stderr_stream + process.stdout = _create_mock_stream(stdout_lines) + process.stderr = _create_mock_stream(stderr_lines) return process @@ -251,18 +264,20 @@ async def test_flash_oci_success(): mock_process = _create_mock_process(stdout_lines=["Flashing complete\n"]) with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="/usr/local/bin/fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process) as mock_exec: + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ) as mock_exec: results = await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") # Verify final chunk has returncode 0 assert any(r[2] == 0 for r in results) mock_exec.assert_called_once() - call_args = mock_exec.call_args - assert call_args.args[0] == "/usr/local/bin/fls" - assert call_args.args[1] == "from-url" - assert call_args.args[2] == "oci://quay.io/org/image:tag" - assert call_args.args[3] == expected_target + cmd = mock_exec.call_args.args[0] + assert cmd[0] == "/usr/local/bin/fls" + assert cmd[1] == "from-url" + assert cmd[2] == "oci://quay.io/org/image:tag" + assert cmd[3] == expected_target @pytest.mark.anyio @@ -274,10 +289,12 @@ async def test_flash_oci_with_partition(): mock_process = _create_mock_process() with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process) as mock_exec: + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ) as mock_exec: await _collect_flash_oci(flasher, "oci://quay.io/org/bios:v1", partition="bios") - assert mock_exec.call_args.args[3] == expected_target + assert mock_exec.call_args.args[0][3] == expected_target @pytest.mark.anyio @@ -288,7 +305,9 @@ async def test_flash_oci_with_credentials(): mock_process = _create_mock_process() with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process) as mock_exec: + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ) as mock_exec: await _collect_flash_oci( flasher, "oci://quay.io/private/image:tag", @@ -319,7 +338,9 @@ async def test_flash_oci_no_credentials(): with patch("jumpstarter.common.oci.read_auth_file_credentials", return_value=(None, None)): with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): with patch( - "asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process + "jumpstarter_driver_qemu.driver.anyio.open_process", + new_callable=AsyncMock, + return_value=mock_process, ) as mock_exec: await _collect_flash_oci(flasher, "oci://quay.io/public/image:tag") @@ -337,7 +358,7 @@ async def test_flash_oci_credentials_from_env(): with patch.dict(os.environ, {"OCI_USERNAME": "envuser", "OCI_PASSWORD": "envpass"}): with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): with patch( - "asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process ) as mock_exec: await _collect_flash_oci(flasher, "oci://quay.io/private/image:tag") @@ -356,7 +377,7 @@ async def test_flash_oci_explicit_credentials_override_env(): with patch.dict(os.environ, {"OCI_USERNAME": "envuser", "OCI_PASSWORD": "envpass"}): with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): with patch( - "asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process ) as mock_exec: await _collect_flash_oci( flasher, @@ -381,7 +402,9 @@ async def test_flash_oci_streams_output(): ) with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ): results = await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") # Should have received streaming output plus the final returncode chunk @@ -432,7 +455,9 @@ async def test_flash_oci_fls_failure(): mock_process = _create_mock_process(returncode=1) with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ): with pytest.raises(RuntimeError, match="fls flash failed"): await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") @@ -443,8 +468,8 @@ async def test_flash_oci_fls_timeout(): driver = Qemu(flash_timeout=0) # Immediate timeout flasher = driver.children["flasher"] - async def hanging_readline(): - await asyncio.sleep(10) + async def hanging_receive(max_bytes=65536): + await anyio.sleep(10) return b"" mock_process = MagicMock() @@ -457,16 +482,20 @@ async def mock_wait(): mock_process.wait = mock_wait mock_process.kill = MagicMock() - stdout_stream = MagicMock() - stdout_stream.readline = hanging_readline + stdout_stream = MagicMock(spec=anyio.abc.ByteReceiveStream) + stdout_stream.receive = hanging_receive + stdout_stream.aclose = AsyncMock() mock_process.stdout = stdout_stream - stderr_stream = MagicMock() - stderr_stream.readline = hanging_readline + stderr_stream = MagicMock(spec=anyio.abc.ByteReceiveStream) + stderr_stream.receive = hanging_receive + stderr_stream.aclose = AsyncMock() mock_process.stderr = stderr_stream with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ): with pytest.raises(RuntimeError, match="fls flash timed out"): await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") @@ -475,30 +504,22 @@ async def mock_wait(): @pytest.mark.anyio async def test_flash_oci_inner_wait_timeout(): - """Inner wait_for timeout should continue the loop without raising.""" + """move_on_after timeout should continue the loop without raising.""" driver = Qemu(flash_timeout=600) flasher = driver.children["flasher"] mock_process = _create_mock_process(stdout_lines=["output\n"]) - original_wait_for = asyncio.wait_for - timeout_fired = False - - async def mock_wait_for(awaitable, *, timeout): - nonlocal timeout_fired - if not timeout_fired: # ty: ignore[unresolved-reference] - timeout_fired = True - if hasattr(awaitable, "close"): - awaitable.close() - raise asyncio.TimeoutError() - return await original_wait_for(awaitable, timeout=timeout) - with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): - with patch("asyncio.wait_for", mock_wait_for): - results = await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ): + results = await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") - assert timeout_fired - assert any(r[2] == 0 for r in results) + final_results = [r for r in results if r[2] is not None] + assert len(final_results) == 1, "exactly one final result with returncode expected" + assert final_results[0][2] == 0 + stdout_chunks = [r[0] for r in results if r[0]] + assert len(stdout_chunks) > 0, "output data should have been received" @pytest.mark.anyio @@ -517,19 +538,19 @@ async def mock_wait(): mock_process.wait = mock_wait mock_process.kill = MagicMock() - stdout_stream = MagicMock() - stdout_stream.readline = AsyncMock(side_effect=[b"line1\n", b"line2\n", b""]) - mock_process.stdout = stdout_stream + mock_process.stdout = _create_mock_stream(["line1\n", "line2\n"]) + mock_process.stderr = _create_mock_stream([]) - stderr_stream = MagicMock() - stderr_stream.readline = AsyncMock(side_effect=[b""]) - mock_process.stderr = stderr_stream - - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): + with patch("jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process): gen = flasher._stream_subprocess(["fls", "from-url", "oci://img", "/tmp/root"], None) # ty: ignore[unresolved-attribute] async for _ in gen: break - await gen.aclose() + # GeneratorExit inside an anyio task group is wrapped in a + # BaseExceptionGroup; this is expected structured concurrency behavior. + try: + await gen.aclose() + except BaseExceptionGroup: + pass mock_process.kill.assert_called() @@ -541,7 +562,9 @@ async def test_flash_oci_fls_not_found(): flasher = driver.children["flasher"] with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, side_effect=FileNotFoundError): + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, side_effect=FileNotFoundError + ): with pytest.raises(RuntimeError, match="fls command not found"): await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") @@ -554,7 +577,9 @@ async def test_flash_oci_uses_fls_config(): mock_process = _create_mock_process() with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls") as mock_get: - with patch("asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process): + with patch( + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process + ): await _collect_flash_oci(flasher, "oci://quay.io/org/image:tag") mock_get.assert_called_once_with( @@ -586,13 +611,14 @@ def test_flash_oci_via_flasher_client(): with serve(Qemu()) as qemu: with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): with patch( - "asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process ) as mock_exec: qemu.flasher.flash("oci://quay.io/org/image:tag") mock_exec.assert_called_once() - assert mock_exec.call_args.args[1] == "from-url" - assert mock_exec.call_args.args[2] == "oci://quay.io/org/image:tag" + cmd = mock_exec.call_args.args[0] + assert cmd[1] == "from-url" + assert cmd[2] == "oci://quay.io/org/image:tag" def test_flash_oci_convenience_method(): @@ -602,14 +628,15 @@ def test_flash_oci_convenience_method(): with serve(Qemu()) as qemu: with patch("jumpstarter_driver_qemu.driver.get_fls_binary", return_value="fls"): with patch( - "asyncio.create_subprocess_exec", new_callable=AsyncMock, return_value=mock_process + "jumpstarter_driver_qemu.driver.anyio.open_process", new_callable=AsyncMock, return_value=mock_process ) as mock_exec: qemu.flash_oci("oci://quay.io/org/image:tag", partition="bios") mock_exec.assert_called_once() - assert mock_exec.call_args.args[1] == "from-url" - assert mock_exec.call_args.args[2] == "oci://quay.io/org/image:tag" - assert Path(mock_exec.call_args.args[3]).name == "bios" + cmd = mock_exec.call_args.args[0] + assert cmd[1] == "from-url" + assert cmd[2] == "oci://quay.io/org/image:tag" + assert Path(cmd[3]).name == "bios" @pytest.mark.anyio diff --git a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver.py b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver.py index d4e6a6e2c..f1bed9b21 100644 --- a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver.py +++ b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver.py @@ -1,4 +1,3 @@ -import asyncio import os import subprocess import time @@ -6,6 +5,7 @@ from pathlib import Path from typing import Dict +import anyio from jumpstarter_driver_opendal.driver import Opendal from jumpstarter.common.exceptions import ConfigurationError @@ -362,7 +362,7 @@ async def boot_to_fastboot(self): chunk = await stream.receive() data += chunk self.logger.debug(f"prompt returned after command: {command}") - await asyncio.sleep(delay) + await anyio.sleep(delay) self.logger.info("device should now be in fastboot mode") @@ -408,7 +408,7 @@ async def cycle(self, delay: float = 2): """Power cycle the device""" self.logger.info(f"Power cycling device with {delay}s delay") await self.off() - await asyncio.sleep(delay) + await anyio.sleep(delay) await self.on() @export @@ -434,4 +434,4 @@ async def _send_power_commands_sequence(serial, logger, commands): for command, delay in commands: await _send_power_command(serial, logger, command) if delay > 0: - await asyncio.sleep(delay) + await anyio.sleep(delay) diff --git a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver_test.py b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver_test.py index c46328700..6b4b310f8 100644 --- a/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver_test.py +++ b/python/packages/jumpstarter-driver-ridesx/jumpstarter_driver_ridesx/driver_test.py @@ -433,12 +433,12 @@ def test_power_off_exported(ridesx_power_driver): assert inspect.iscoroutinefunction(ridesx_power_driver.off) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_power_cycle(ridesx_power_driver): """Test power cycle calls off, waits, then on""" with patch.object(ridesx_power_driver, "off", new_callable=AsyncMock) as mock_off: with patch.object(ridesx_power_driver, "on", new_callable=AsyncMock) as mock_on: - with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + with patch("jumpstarter_driver_ridesx.driver.anyio.sleep", new_callable=AsyncMock) as mock_sleep: await ridesx_power_driver.cycle(delay=0.1) mock_off.assert_called_once() diff --git a/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py b/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py index ca6610377..25ecc6047 100644 --- a/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py +++ b/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver.py @@ -1,16 +1,15 @@ -import asyncio -import asyncio.subprocess import os import signal import subprocess from dataclasses import dataclass, field +from subprocess import PIPE from typing import AsyncGenerator +import anyio +from anyio import move_on_after + from jumpstarter.driver import Driver, export -# Environment variables that are blocked because they allow privilege escalation. -# A client that can set these can hijack the subprocess (e.g. LD_PRELOAD to load -# arbitrary shared libraries, PATH to redirect commands to attacker binaries). BLOCKED_ENV_VARS: set[str] = { "LD_PRELOAD", "LD_LIBRARY_PATH", @@ -21,7 +20,6 @@ "HOME", } -# Prefixes that are also blocked (matched with str.startswith). BLOCKED_ENV_PREFIXES: tuple[str, ...] = ( "LD_", "BASH_FUNC_", @@ -32,10 +30,6 @@ class Shell(Driver): """shell driver for Jumpstarter""" - # methods field defines the methods exported and their shell scripts - # Supports two formats: - # 1. Simple string: method_name: "command" - # 2. Dict with description: method_name: {command: "...", description: "...", timeout: ...} methods: dict[str, str | dict[str, str | int]] shell: list[str] = field(default_factory=lambda: ["bash", "-c"]) timeout: int = 300 @@ -43,7 +37,6 @@ class Shell(Driver): def __post_init__(self): super().__post_init__() - # Extract descriptions from methods configuration and populate methods_description for method_name, method_config in self.methods.items(): if isinstance(method_config, dict) and "description" in method_config: self.methods_description[method_name] = method_config["description"] @@ -105,10 +98,8 @@ async def call_method(self, method: str, env, *args) -> AsyncGenerator[tuple[str def _validate_script_params(self, script, args, env_vars): """Validate script parameters and return combined environment.""" - # Merge parent environment with the user-supplied env_vars combined_env = os.environ.copy() if env_vars: - # Validate environment variable names for key in env_vars: if not isinstance(key, str) or not key.isidentifier(): raise ValueError(f"Invalid environment variable name: {key}") @@ -121,17 +112,39 @@ def _validate_script_params(self, script, args, env_vars): if not isinstance(script, str) or not script.strip(): raise ValueError("Shell script must be a non-empty string") - # Validate arguments for arg in args: if not isinstance(arg, str): raise ValueError(f"All arguments must be strings, got {type(arg)}") - # Validate working directory if set if self.cwd and not os.path.isdir(self.cwd): raise ValueError(f"Working directory does not exist: {self.cwd}") return combined_env + @staticmethod + async def _read_stream(stream, read_all: bool) -> str: + """Read from a single byte stream and return decoded text.""" + if stream is None: + return "" + try: + if read_all: + chunks = [] + try: + while True: + chunks.append(await stream.receive()) + except anyio.EndOfStream: + pass + chunk = b"".join(chunks) + else: + chunk = None + with move_on_after(0.01): + chunk = await stream.receive(1024) + if chunk: + return chunk.decode('utf-8', errors='replace') + except (anyio.EndOfStream, anyio.ClosedResourceError): + pass + return "" + async def _read_process_output(self, process, read_all=False): """Read data from stdout and stderr streams. @@ -139,33 +152,8 @@ async def _read_process_output(self, process, read_all=False): :param read_all: If True, read all remaining data. If False, read with timeout. :return: Tuple of (stdout_data, stderr_data) """ - stdout_data = "" - stderr_data = "" - - # Read from stdout - if process.stdout: - try: - if read_all: - chunk = await process.stdout.read() - else: - chunk = await asyncio.wait_for(process.stdout.read(1024), timeout=0.01) - if chunk: - stdout_data = chunk.decode('utf-8', errors='replace') - except (asyncio.TimeoutError, Exception): - pass - - # Read from stderr - if process.stderr: - try: - if read_all: - chunk = await process.stderr.read() - else: - chunk = await asyncio.wait_for(process.stderr.read(1024), timeout=0.01) - if chunk: - stderr_data = chunk.decode('utf-8', errors='replace') - except (asyncio.TimeoutError, Exception): - pass - + stdout_data = await self._read_stream(process.stdout, read_all) + stderr_data = await self._read_stream(process.stderr, read_all) return stdout_data, stderr_data async def _run_inline_shell_script( @@ -186,56 +174,53 @@ async def _run_inline_shell_script( combined_env = self._validate_script_params(script, args, env_vars) cmd = self.shell + [script, method] + list(args) - # Start the process with pipes for streaming and new process group - self.logger.debug( f"running {method} with cmd: {cmd} and env: {combined_env} " f"and args: {args}") - process = await asyncio.create_subprocess_exec( # ty: ignore[missing-argument] - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, + self.logger.debug(f"running {method} with cmd: {cmd} and env: {combined_env} and args: {args}") + process = await anyio.open_process( + cmd, + stdout=PIPE, + stderr=PIPE, env=combined_env, cwd=self.cwd, - start_new_session=True, # Create new process group + start_new_session=True, ) - # Create a task to monitor the process timeout - start_time = asyncio.get_event_loop().time() + start_time = anyio.current_time() if timeout is None: timeout = self.timeout - # Read output in real-time - while process.returncode is None: - if asyncio.get_event_loop().time() - start_time > timeout: - # Send SIGTERM to entire process group for graceful termination - try: - os.killpg(process.pid, signal.SIGTERM) - except (ProcessLookupError, OSError): - # Process group might already be gone - pass - try: - await asyncio.wait_for(process.wait(), timeout=5.0) - except asyncio.TimeoutError: + try: + while process.returncode is None: + if anyio.current_time() - start_time > timeout: try: - os.killpg(process.pid, signal.SIGKILL) - self.logger.warning(f"SIGTERM failed to terminate {process.pid}, sending SIGKILL") + os.killpg(process.pid, signal.SIGTERM) except (ProcessLookupError, OSError): pass - raise subprocess.TimeoutExpired(cmd, timeout) from None + with move_on_after(5.0): + await process.wait() + if process.returncode is None: + try: + os.killpg(process.pid, signal.SIGKILL) + self.logger.warning(f"SIGTERM failed to terminate {process.pid}, sending SIGKILL") + except (ProcessLookupError, OSError): + pass + raise subprocess.TimeoutExpired(cmd, timeout) from None - try: - stdout_data, stderr_data = await self._read_process_output(process, read_all=False) + try: + stdout_data, stderr_data = await self._read_process_output(process, read_all=False) - # Yield any data we got - if stdout_data or stderr_data: - yield stdout_data, stderr_data, None + if stdout_data or stderr_data: + yield stdout_data, stderr_data, None - # Small delay to prevent busy waiting - await asyncio.sleep(0.1) + await anyio.sleep(0.1) - except Exception: - break + except (anyio.EndOfStream, anyio.ClosedResourceError): + break - # Process completed, get return code and final output - returncode = process.returncode - remaining_stdout, remaining_stderr = await self._read_process_output(process, read_all=True) - yield remaining_stdout, remaining_stderr, returncode + returncode = process.returncode + remaining_stdout, remaining_stderr = await self._read_process_output(process, read_all=True) + yield remaining_stdout, remaining_stderr, returncode + finally: + if process.returncode is None: + process.kill() + await process.wait() diff --git a/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py b/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py index 0bff21a32..160744cd3 100644 --- a/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py +++ b/python/packages/jumpstarter-driver-shell/jumpstarter_driver_shell/driver_test.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + +import anyio import pytest from .driver import Shell @@ -279,3 +282,62 @@ def test_mixed_format_methods(): assert cli.commands['simple'].help == "Execute the simple shell method" assert cli.commands['detailed'].help == "A detailed command with description" assert cli.commands['default_cmd'].help == "Method using default command" + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest.mark.anyio +async def test_unexpected_exception_propagates_from_streaming_loop(): + """Verify that non-stream exceptions propagate instead of being silently swallowed. + + The streaming loop in _run_inline_shell_script should only catch + anyio.EndOfStream and anyio.ClosedResourceError. Other exceptions + like RuntimeError must propagate to the caller. + """ + shell = Shell(methods={"sleeper": "sleep 10"}) + + call_count = 0 + + original_read = shell._read_process_output + + async def failing_read(process, read_all=False): + nonlocal call_count + if read_all: + return await original_read(process, read_all) + call_count += 1 + if call_count >= 2: + raise RuntimeError("simulated unexpected failure") + return await original_read(process, read_all) + + with patch.object(shell, "_read_process_output", side_effect=failing_read): + with pytest.raises(RuntimeError, match="simulated unexpected failure"): + async for _ in shell._run_inline_shell_script("sleeper", "sleep 10"): + pass + + +@pytest.mark.anyio +async def test_stream_exceptions_cause_clean_exit(): + """Verify that anyio.EndOfStream causes a clean loop exit, not an error.""" + shell = Shell(methods={"sleeper": "echo done"}) + + call_count = 0 + + original_read = shell._read_process_output + + async def eos_read(process, read_all=False): + nonlocal call_count + if read_all: + return await original_read(process, read_all) + call_count += 1 + if call_count >= 2: + raise anyio.EndOfStream() + return await original_read(process, read_all) + + with patch.object(shell, "_read_process_output", side_effect=eos_read): + results = [] + async for chunk in shell._run_inline_shell_script("sleeper", "echo done"): + results.append(chunk) + assert len(results) >= 1 diff --git a/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py b/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py index dac94f642..d7b4cbf4c 100644 --- a/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py +++ b/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver.py @@ -1,5 +1,5 @@ -import asyncio import socket +from asyncio import AbstractEventLoop, Event, TimeoutError, get_running_loop, new_event_loop, set_event_loop, wait_for from dataclasses import dataclass, field from enum import Enum, IntEnum from typing import Any, Dict, Tuple @@ -117,7 +117,7 @@ def _setup_snmp(self): def client(cls) -> str: return "jumpstarter_driver_snmp.client.SNMPServerClient" - def _create_snmp_callback(self, result: Dict[str, Any], response_received: asyncio.Event): + def _create_snmp_callback(self, result: Dict[str, Any], response_received: Event): def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorIndex, varBinds, cbCtx): self.logger.debug(f"Callback {errorIndication} {errorStatus} {errorIndex} {varBinds}") if errorIndication: @@ -138,23 +138,23 @@ def callback(snmpEngine, sendRequestHandle, errorIndication, errorStatus, errorI return callback - def _setup_event_loop(self) -> Tuple[asyncio.AbstractEventLoop, bool]: + def _setup_event_loop(self) -> Tuple[AbstractEventLoop, bool]: try: - loop = asyncio.get_running_loop() + loop = get_running_loop() return loop, False except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = new_event_loop() + set_event_loop(loop) return loop, True - async def _run_snmp_dispatcher(self, snmp_engine: engine.SnmpEngine, response_received: asyncio.Event): + async def _run_snmp_dispatcher(self, snmp_engine: engine.SnmpEngine, response_received: Event): snmp_engine.open_dispatcher() await response_received.wait() snmp_engine.close_dispatcher() def _snmp_set(self, state: PowerState): result = {"success": False, "error": None} - response_received = asyncio.Event() + response_received = Event() loop = None created_loop = False @@ -174,8 +174,8 @@ def _snmp_set(self, state: PowerState): dispatcher_task = loop.create_task(self._run_snmp_dispatcher(snmp_engine, response_received)) try: - loop.run_until_complete(asyncio.wait_for(dispatcher_task, self.timeout)) - except asyncio.TimeoutError: + loop.run_until_complete(wait_for(dispatcher_task, self.timeout)) + except TimeoutError: self.logger.warning(f"SNMP operation timed out after {self.timeout} seconds") result["error"] = "SNMP operation timed out" diff --git a/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver_test.py b/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver_test.py index 7d1d38eec..7bf29f229 100644 --- a/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver_test.py +++ b/python/packages/jumpstarter-driver-snmp/jumpstarter_driver_snmp/driver_test.py @@ -35,7 +35,7 @@ def setup_mock_snmp_engine(): "auth_key": None, "priv_protocol": PrivProtocol.NONE, "priv_key": None, - "expected_args_len": 2, # only user and engine args for noAuth + "expected_args_len": 2, }, { "user": "usr-md5-none", @@ -43,7 +43,7 @@ def setup_mock_snmp_engine(): "auth_key": "authkey1", "priv_protocol": PrivProtocol.NONE, "priv_key": None, - "expected_args_len": 4, # engine, user, auth_protocol, auth_key + "expected_args_len": 4, }, { "user": "usr-sha-des", @@ -51,7 +51,7 @@ def setup_mock_snmp_engine(): "auth_key": "authkey1", "priv_protocol": PrivProtocol.DES, "priv_key": "privkey1", - "expected_args_len": 6, # engine, user, auth_protocol, auth_key, priv_protocol, priv_key + "expected_args_len": 6, }, ], ) @@ -109,9 +109,9 @@ def test_power_on_command(mock_engine, mock_add_user): with ( patch("pysnmp.entity.rfc3413.cmdgen.SetCommandGenerator.send_varbinds") as mock_send, - patch("asyncio.get_running_loop", side_effect=RuntimeError), - patch("asyncio.new_event_loop"), - patch("asyncio.set_event_loop"), + patch("jumpstarter_driver_snmp.driver.get_running_loop", side_effect=RuntimeError), + patch("jumpstarter_driver_snmp.driver.new_event_loop"), + patch("jumpstarter_driver_snmp.driver.set_event_loop"), patch("pysnmp.entity.config.add_target_parameters"), patch("pysnmp.entity.config.add_target_address"), patch("pysnmp.entity.config.add_transport"), @@ -137,9 +137,9 @@ def test_power_off_command(mock_engine, mock_add_user): with ( patch("pysnmp.entity.rfc3413.cmdgen.SetCommandGenerator.send_varbinds") as mock_send, - patch("asyncio.get_running_loop", side_effect=RuntimeError), - patch("asyncio.new_event_loop"), - patch("asyncio.set_event_loop"), + patch("jumpstarter_driver_snmp.driver.get_running_loop", side_effect=RuntimeError), + patch("jumpstarter_driver_snmp.driver.new_event_loop"), + patch("jumpstarter_driver_snmp.driver.set_event_loop"), patch("pysnmp.entity.config.add_target_parameters"), patch("pysnmp.entity.config.add_target_address"), patch("pysnmp.entity.config.add_transport"), diff --git a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 9e3b4de77..e84bf1c98 100644 --- a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -1,9 +1,10 @@ -import asyncio import os import threading +from asyncio import AbstractEventLoop, CancelledError, create_task, gather, new_event_loop, set_event_loop from dataclasses import dataclass, field from typing import Optional +import anyio from jumpstarter_driver_opendal.driver import Opendal from jumpstarter_driver_tftp.server import TftpServer @@ -39,12 +40,12 @@ class Tftp(Driver): root_dir: str = "/var/lib/tftpboot" host: str = field(default="") port: int = 69 - remove_created_on_close: bool = True # Clean up temporary boot files by default + remove_created_on_close: bool = True server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event) _loop_ready: threading.Event = field(init=False, default_factory=threading.Event) - _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None) + _loop: Optional[AbstractEventLoop] = field(init=False, default=None) def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -67,8 +68,8 @@ def client(cls) -> str: return "jumpstarter_driver_tftp.client.TftpServerClient" def _start_server(self): - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) + self._loop = new_event_loop() + set_event_loop(self._loop) self.server = TftpServer( host=self.host, port=self.port, @@ -91,15 +92,15 @@ def _start_server(self): async def _run_server(self): try: - server_task = asyncio.create_task(self.server.start()) - await asyncio.gather(server_task, self._wait_for_shutdown()) - except asyncio.CancelledError: + server_task = create_task(self.server.start()) + await gather(server_task, self._wait_for_shutdown()) + except CancelledError: self.logger.info("Server task cancelled") raise async def _wait_for_shutdown(self): while not self._shutdown_event.is_set(): - await asyncio.sleep(0.1) + await anyio.sleep(0.1) self.logger.info("Shutdown event detected") if self.server is not None: await self.server.shutdown() diff --git a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index 8d5a5fe78..738ca4d55 100644 --- a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -1,6 +1,17 @@ -import asyncio import logging import pathlib +from asyncio import ( + CancelledError, + DatagramProtocol, + DatagramTransport, + Event, + Task, + TimeoutError, + create_task, + gather, + get_running_loop, + wait_for, +) from enum import IntEnum from typing import Optional, Set, Tuple @@ -49,8 +60,8 @@ def __init__( self.timeout = timeout self.retries = retries self.active_transfers: Set["TftpTransfer"] = set() - self.shutdown_event = asyncio.Event() - self.transport: Optional[asyncio.DatagramTransport] = None + self.shutdown_event = Event() + self.transport: Optional[DatagramTransport] = None self.protocol: Optional["TftpServerProtocol"] = None if logger is not None: @@ -58,7 +69,7 @@ def __init__( else: self.logger = logging.getLogger(self.__class__.__name__) - self.ready_event = asyncio.Event() + self.ready_event = Event() @property def address(self) -> Optional[Tuple[str, int]]: @@ -69,7 +80,7 @@ def address(self) -> Optional[Tuple[str, int]]: async def start(self): self.logger.info(f"Starting TFTP server on {self.host}:{self.port}") - loop = asyncio.get_running_loop() + loop = get_running_loop() self.ready_event.set() self.transport, self.protocol = await loop.create_datagram_endpoint( @@ -78,7 +89,7 @@ async def start(self): try: await self.shutdown_event.wait() - except asyncio.CancelledError: + except CancelledError: pass finally: self.logger.info("TFTP server shutting down") @@ -90,7 +101,7 @@ async def _cleanup(self): # Cancel all active transfers cleanup_tasks = [transfer.cleanup() for transfer in self.active_transfers.copy()] if cleanup_tasks: - await asyncio.gather(*cleanup_tasks, return_exceptions=True) + await gather(*cleanup_tasks, return_exceptions=True) # Close the main transport if self.transport: @@ -112,17 +123,17 @@ def unregister_transfer(self, transfer: "TftpTransfer"): self.logger.debug(f"Unregistered transfer: {transfer}") -class TftpServerProtocol(asyncio.DatagramProtocol): +class TftpServerProtocol(DatagramProtocol): """ Protocol for handling incoming TFTP requests. """ def __init__(self, server: TftpServer): self.server = server - self.transport: Optional[asyncio.DatagramTransport] = None + self.transport: Optional[DatagramTransport] = None self.logger = server.logger.getChild(self.__class__.__name__) - def connection_made(self, transport: asyncio.DatagramTransport): + def connection_made(self, transport: DatagramTransport): self.transport = transport self.logger.debug("Server protocol connection established") @@ -150,7 +161,7 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): self.logger.debug(f"Received opcode {opcode.name} from {addr}") if opcode == Opcode.RRQ: - asyncio.create_task(self._handle_read_request(data, addr)) + create_task(self._handle_read_request(data, addr)) else: self.logger.warning(f"Unsupported opcode {opcode} from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported operation") @@ -316,7 +327,7 @@ async def _start_transfer( negotiated_options=negotiated_options, ) self.server.register_transfer(transfer) - asyncio.create_task(transfer.start()) + create_task(transfer.start()) @@ -340,9 +351,9 @@ def __init__( self.block_size = block_size self.timeout = timeout self.retries = retries - self.transport: Optional[asyncio.DatagramTransport] = None + self.transport: Optional[DatagramTransport] = None self.protocol: Optional["TftpTransferProtocol"] = None - self.cleanup_task: Optional[asyncio.Task] = None + self.cleanup_task: Optional[Task] = None self.logger = server.logger.getChild(self.__class__.__name__) async def start(self): @@ -378,7 +389,7 @@ def __init__( retries=retries, ) self.block_num = 0 - self.ack_received = asyncio.Event() + self.ack_received = Event() self.last_ack = 0 self.oack_confirmed = False self.negotiated_options = negotiated_options @@ -402,7 +413,7 @@ async def start(self): await self.cleanup() async def _initialize_transfer(self) -> bool: - loop = asyncio.get_running_loop() + loop = get_running_loop() self.transport, self.protocol = await loop.create_datagram_endpoint( lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr @@ -507,7 +518,7 @@ async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})" ) self.ack_received.clear() - await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout) + await wait_for(self.ack_received.wait(), timeout=self.timeout) if self.last_ack == expected_block: self.logger.debug(f"ACK received for block {expected_block}") @@ -515,7 +526,7 @@ async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool else: self.logger.warning(f"Received wrong ACK: expected {expected_block}, got {self.last_ack}") - except asyncio.TimeoutError: + except TimeoutError: self.logger.warning(f"Timeout waiting for ACK of block {expected_block} (Attempt {attempt})") return False @@ -540,7 +551,7 @@ def handle_ack(self, block_num: int): self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}") -class TftpTransferProtocol(asyncio.DatagramProtocol): +class TftpTransferProtocol(DatagramProtocol): """ Protocol for handling ACKs during a TFTP transfer. """ @@ -549,7 +560,7 @@ def __init__(self, transfer: TftpReadTransfer): self.transfer = transfer self.logger = transfer.logger.getChild(self.__class__.__name__) - def connection_made(self, transport: asyncio.DatagramTransport): + def connection_made(self, transport: DatagramTransport): self.transfer.transport = transport local_addr = transport.get_extra_info("sockname") self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}") diff --git a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 3fd297e5d..35abeb6ed 100644 --- a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -1,5 +1,15 @@ -import asyncio import tempfile +from asyncio import ( + CancelledError, + DatagramProtocol, + all_tasks, + create_task, + current_task, + get_running_loop, +) +from asyncio import ( + sleep as asyncio_sleep, +) from pathlib import Path import pytest @@ -15,12 +25,12 @@ async def tftp_server(): test_file_path.write_text("Hello, TFTP!") server = TftpServer(host="127.0.0.1", port=0, operator=AsyncOperator("fs", root=str(temp_dir))) - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) for _ in range(10): if server.address is not None: break - await asyncio.sleep(0.1) + await asyncio_sleep(0.1) else: await server.shutdown() server_task.cancel() @@ -31,29 +41,29 @@ async def tftp_server(): await server.shutdown() await server_task - for task in asyncio.all_tasks(): - if not task.done() and task != asyncio.current_task(): + for task in all_tasks(): + if not task.done() and task != current_task(): task.cancel() try: await task - except asyncio.CancelledError: + except CancelledError: pass async def create_test_client(server_port): - loop = asyncio.get_running_loop() + loop = get_running_loop() transport, protocol = await loop.create_datagram_endpoint( - asyncio.DatagramProtocol, remote_addr=("127.0.0.1", server_port) + DatagramProtocol, remote_addr=("127.0.0.1", server_port) ) return transport, protocol -@pytest.mark.asyncio +@pytest.mark.anyio async def test_server_startup_and_shutdown(tftp_server): """Test that server starts up and shuts down cleanly.""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() await server.shutdown() @@ -63,12 +73,12 @@ async def test_server_startup_and_shutdown(tftp_server): assert True -@pytest.mark.asyncio +@pytest.mark.anyio async def test_read_request_for_existing_file(tftp_server): """Test reading an existing file from the server.""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() try: @@ -91,12 +101,12 @@ async def test_read_request_for_existing_file(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_read_request_for_nonexistent_file(tftp_server): """Test reading a non-existent file returns appropriate error.""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) try: transport, protocol = await create_test_client(server_port) @@ -112,11 +122,11 @@ async def test_read_request_for_nonexistent_file(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_write_request_rejection(tftp_server): """Test that write requests are properly rejected (server is read-only).""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) try: transport, _ = await create_test_client(server_port) @@ -132,10 +142,10 @@ async def test_write_request_rejection(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_invalid_packet_handling(tftp_server): server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() try: @@ -150,12 +160,12 @@ async def test_invalid_packet_handling(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_path_traversal_prevention(tftp_server): """Test that path traversal attempts are blocked.""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() try: @@ -173,11 +183,11 @@ async def test_path_traversal_prevention(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_options_negotiation(tftp_server): """Test that options (blksize, timeout) are properly negotiated.""" server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() try: @@ -204,7 +214,7 @@ async def test_options_negotiation(tftp_server): await server_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_retry_mechanism(tftp_server): server, _, server_port = tftp_server @@ -213,7 +223,7 @@ async def test_retry_mechanism(tftp_server): transport = None - class TestProtocol(asyncio.DatagramProtocol): + class TestProtocol(DatagramProtocol): def __init__(self): self.received_packets = [] self.transport = None @@ -225,7 +235,7 @@ def datagram_received(self, data, addr): self.received_packets.append(data) try: - loop = asyncio.get_running_loop() + loop = get_running_loop() transport, protocol = await loop.create_datagram_endpoint(lambda: TestProtocol(), local_addr=("127.0.0.1", 0)) assert transport is not None, "Failed to create transport" @@ -234,7 +244,7 @@ def datagram_received(self, data, addr): transport.sendto(rrq_packet, ("127.0.0.1", server_port)) - await asyncio.sleep(server.timeout * 2) + await asyncio_sleep(server.timeout * 2) data_packets = [p for p in protocol.received_packets if p[0:2] == Opcode.DATA.to_bytes(2, "big")] @@ -252,10 +262,10 @@ def datagram_received(self, data, addr): transport.close() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_invalid_options_handling(tftp_server): server, temp_dir, server_port = tftp_server - server_task = asyncio.create_task(server.start()) + server_task = create_task(server.start()) await server.ready_event.wait() try: diff --git a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/test_path_traversal.py b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/test_path_traversal.py index e64255e57..796c90d3f 100644 --- a/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/test_path_traversal.py +++ b/python/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/test_path_traversal.py @@ -23,27 +23,27 @@ def protocol(server): class TestResolveAndValidatePath: - @pytest.mark.asyncio + @pytest.mark.anyio async def test_rejects_dot_dot_path(self, protocol): result = await protocol._resolve_and_validate_path("..", ("127.0.0.1", 12345)) assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio async def test_rejects_dot_dot_prefix(self, protocol): result = await protocol._resolve_and_validate_path("../etc/passwd", ("127.0.0.1", 12345)) assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio async def test_rejects_dot_dot_in_middle(self, protocol): result = await protocol._resolve_and_validate_path("subdir/../../../etc/passwd", ("127.0.0.1", 12345)) assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio async def test_rejects_dot_dot_at_end(self, protocol): result = await protocol._resolve_and_validate_path("subdir/..", ("127.0.0.1", 12345)) assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio async def test_allows_valid_filename(self, protocol, server): stat_result = MagicMock() stat_result.mode.is_file.return_value = True @@ -52,7 +52,7 @@ async def test_allows_valid_filename(self, protocol, server): result = await protocol._resolve_and_validate_path("boot.img", ("127.0.0.1", 12345)) assert result == "boot.img" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_allows_filename_containing_dots(self, protocol, server): stat_result = MagicMock() stat_result.mode.is_file.return_value = True @@ -61,12 +61,12 @@ async def test_allows_filename_containing_dots(self, protocol, server): result = await protocol._resolve_and_validate_path("file..name.txt", ("127.0.0.1", 12345)) assert result == "file..name.txt" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_rejects_absolute_path(self, protocol): result = await protocol._resolve_and_validate_path("/etc/passwd", ("127.0.0.1", 12345)) assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio async def test_sends_access_violation_on_traversal(self, protocol): await protocol._resolve_and_validate_path("../secret", ("127.0.0.1", 12345)) protocol.transport.sendto.assert_called_once() diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients.py index 32d8c0d1b..f56debb83 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients.py @@ -1,8 +1,8 @@ -import asyncio import base64 import logging from typing import Literal, Optional +import anyio from kubernetes_asyncio.client.models import V1ObjectMeta, V1ObjectReference from pydantic import Field @@ -125,7 +125,7 @@ async def create_client( if "credential" in updated_client["status"]: return V1Alpha1Client.from_dict(updated_client) count += 1 - await asyncio.sleep(CREATE_CLIENT_DELAY) + await anyio.sleep(CREATE_CLIENT_DELAY) raise Exception("Timeout waiting for client credentials") async def list_clients(self) -> V1Alpha1List[V1Alpha1Client]: diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients_test.py index 3a3983c43..1bd24a1b2 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/clients_test.py @@ -239,7 +239,7 @@ def test_client_list_rich_add_names(): # Tests for get_ca_bundle and get_client_config -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_ca_bundle_with_ca_cert(): """Test get_ca_bundle returns base64-encoded CA certificate""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -260,7 +260,7 @@ async def test_get_ca_bundle_with_ca_cert(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_ca_bundle_empty_ca_cert(): """Test get_ca_bundle returns empty string when ca.crt is empty""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -275,7 +275,7 @@ async def test_get_ca_bundle_empty_ca_cert(): assert result == "" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_ca_bundle_missing_ca_crt_key(): """Test get_ca_bundle returns empty string when ca.crt key is missing""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -290,7 +290,7 @@ async def test_get_ca_bundle_missing_ca_crt_key(): assert result == "" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_ca_bundle_configmap_not_found(): """Test get_ca_bundle returns empty string when ConfigMap doesn't exist""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -306,7 +306,7 @@ async def test_get_ca_bundle_configmap_not_found(): assert result == "" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_ca_bundle_other_api_error(): """Test get_ca_bundle raises exception for non-404 errors""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -323,7 +323,7 @@ async def test_get_ca_bundle_other_api_error(): assert exc_info.value.status == 403 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_client_config_includes_ca_bundle(): """Test get_client_config includes CA bundle from ConfigMap""" api = ClientsV1Alpha1Api(namespace="test-namespace") @@ -368,7 +368,7 @@ async def test_get_client_config_includes_ca_bundle(): assert config.token == token -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_client_config_without_ca_bundle(): """Test get_client_config works when CA ConfigMap doesn't exist""" api = ClientsV1Alpha1Api(namespace="test-namespace") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common.py index 9ba5ef0f8..c450b0154 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common.py @@ -1,9 +1,11 @@ """Common utilities and types for cluster operations.""" -import asyncio import os +from subprocess import PIPE from typing import Literal, Optional +import anyio + from ..exceptions import ClusterTypeValidationError ClusterType = Literal["kind"] | Literal["minikube"] | Literal["k3s"] @@ -72,16 +74,12 @@ async def run_command(cmd: list[str]) -> tuple[int, str, str]: raise ValueError("Command list cannot be empty") try: - process = await asyncio.create_subprocess_exec( - *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) - stdout, stderr = await process.communicate() + result = await anyio.run_process(cmd, stdout=PIPE, stderr=PIPE, check=False) - # Use safe decoding to avoid UnicodeDecodeError - stdout_str = stdout.decode(errors="replace").strip() - stderr_str = stderr.decode(errors="replace").strip() + stdout_str = result.stdout.decode(errors="replace").strip() + stderr_str = result.stderr.decode(errors="replace").strip() - return process.returncode, stdout_str, stderr_str + return result.returncode, stdout_str, stderr_str except builtins.FileNotFoundError as e: raise RuntimeError(f"Command not found: {cmd[0]}") from e except PermissionError as e: @@ -99,7 +97,7 @@ async def run_command_with_output(cmd: list[str]) -> int: raise ValueError("Command list cannot be empty") try: - process = await asyncio.create_subprocess_exec(*cmd) + process = await anyio.open_process(cmd) return await process.wait() except builtins.FileNotFoundError as e: raise RuntimeError(f"Command not found: {cmd[0]}") from e diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common_test.py index 2ee4cbb9e..346719dab 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/common_test.py @@ -1,8 +1,8 @@ """Tests for common cluster utilities and types.""" -import asyncio import os import tempfile +from subprocess import PIPE from unittest.mock import AsyncMock, patch import pytest @@ -185,30 +185,36 @@ def test_validate_cluster_name_numeric(self): class TestRunCommand: """Test run_command function.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_success(self): - with patch("asyncio.create_subprocess_exec") as mock_subprocess: - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"output\n", b"") - mock_process.returncode = 0 - mock_subprocess.return_value = mock_process + from subprocess import CompletedProcess + + with patch("jumpstarter_kubernetes.cluster.common.anyio.run_process", new_callable=AsyncMock) as mock_run: + mock_run.return_value = CompletedProcess( + args=["echo", "test"], + returncode=0, + stdout=b"output\n", + stderr=b"", + ) returncode, stdout, stderr = await run_command(["echo", "test"]) assert returncode == 0 assert stdout == "output" assert stderr == "" - mock_subprocess.assert_called_once_with( - "echo", "test", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE - ) + mock_run.assert_called_once_with(["echo", "test"], stdout=PIPE, stderr=PIPE, check=False) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_failure(self): - with patch("asyncio.create_subprocess_exec") as mock_subprocess: - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"", b"error message\n") - mock_process.returncode = 1 - mock_subprocess.return_value = mock_process + from subprocess import CompletedProcess + + with patch("jumpstarter_kubernetes.cluster.common.anyio.run_process", new_callable=AsyncMock) as mock_run: + mock_run.return_value = CompletedProcess( + args=["false"], + returncode=1, + stdout=b"", + stderr=b"error message\n", + ) returncode, stdout, stderr = await run_command(["false"]) @@ -216,15 +222,21 @@ async def test_run_command_failure(self): assert stdout == "" assert stderr == "error message" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_not_found(self): - with patch("asyncio.create_subprocess_exec", side_effect=FileNotFoundError("command not found")): + with patch( + "jumpstarter_kubernetes.cluster.common.anyio.run_process", + new_callable=AsyncMock, + side_effect=FileNotFoundError("command not found"), + ): with pytest.raises(RuntimeError, match="Command not found: nonexistent"): await run_command(["nonexistent"]) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_with_output_success(self): - with patch("asyncio.create_subprocess_exec") as mock_subprocess: + with patch( + "jumpstarter_kubernetes.cluster.common.anyio.open_process", new_callable=AsyncMock + ) as mock_subprocess: mock_process = AsyncMock() mock_process.wait.return_value = 0 mock_subprocess.return_value = mock_process @@ -232,17 +244,22 @@ async def test_run_command_with_output_success(self): returncode = await run_command_with_output(["echo", "test"]) assert returncode == 0 - mock_subprocess.assert_called_once_with("echo", "test") + mock_subprocess.assert_called_once_with(["echo", "test"]) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_with_output_not_found(self): - with patch("asyncio.create_subprocess_exec", side_effect=FileNotFoundError("command not found")): + with patch( + "jumpstarter_kubernetes.cluster.common.anyio.open_process", + side_effect=FileNotFoundError("command not found"), + ): with pytest.raises(RuntimeError, match="Command not found: nonexistent"): await run_command_with_output(["nonexistent"]) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_run_command_with_output_failure(self): - with patch("asyncio.create_subprocess_exec") as mock_subprocess: + with patch( + "jumpstarter_kubernetes.cluster.common.anyio.open_process", new_callable=AsyncMock + ) as mock_subprocess: mock_process = AsyncMock() mock_process.wait.return_value = 1 mock_subprocess.return_value = mock_process diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/detection_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/detection_test.py index eea8a78ba..32640010b 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/detection_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/detection_test.py @@ -56,7 +56,7 @@ def test_detect_container_runtime_docker_preferred(self, mock_which): class TestDetectKindProvider: """Test Kind provider detection.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.detect_container_runtime") @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_kind_provider_control_plane(self, mock_run_command, mock_detect_runtime): @@ -69,7 +69,7 @@ async def test_detect_kind_provider_control_plane(self, mock_run_command, mock_d assert node_name == "test-cluster-control-plane" mock_run_command.assert_called_once_with(["docker", "inspect", "test-cluster-control-plane"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.detect_container_runtime") @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_kind_provider_kind_prefix(self, mock_run_command, mock_detect_runtime): @@ -83,7 +83,7 @@ async def test_detect_kind_provider_kind_prefix(self, mock_run_command, mock_det assert node_name == "kind-test-cluster-control-plane" assert mock_run_command.call_count == 2 - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.detect_container_runtime") @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_kind_provider_default_cluster(self, mock_run_command, mock_detect_runtime): @@ -96,7 +96,7 @@ async def test_detect_kind_provider_default_cluster(self, mock_run_command, mock assert node_name == "kind-control-plane" mock_run_command.assert_called_once_with(["docker", "inspect", "kind-control-plane"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.detect_container_runtime") @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_kind_provider_fallback(self, mock_run_command, mock_detect_runtime): @@ -108,7 +108,7 @@ async def test_detect_kind_provider_fallback(self, mock_run_command, mock_detect assert runtime == "podman" assert node_name == "test-cluster-control-plane" # Fallback - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.detect_container_runtime") @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_kind_provider_runtime_error(self, mock_run_command, mock_detect_runtime): @@ -124,7 +124,7 @@ async def test_detect_kind_provider_runtime_error(self, mock_run_command, mock_d class TestDetectExistingClusterType: """Test detection of existing cluster types.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") @patch("jumpstarter_kubernetes.cluster.detection.kind_cluster_exists") @@ -141,7 +141,7 @@ async def test_detect_existing_cluster_type_kind_only( assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") @patch("jumpstarter_kubernetes.cluster.detection.kind_cluster_exists") @@ -158,7 +158,7 @@ async def test_detect_existing_cluster_type_minikube_only( assert result == "minikube" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") @patch("jumpstarter_kubernetes.cluster.detection.kind_cluster_exists") @@ -179,7 +179,7 @@ async def test_detect_existing_cluster_type_both_exist( ): await detect_existing_cluster_type("test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") @patch("jumpstarter_kubernetes.cluster.detection.kind_cluster_exists") @@ -196,7 +196,7 @@ async def test_detect_existing_cluster_type_none_exist( assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") async def test_detect_existing_cluster_type_kind_not_installed(self, mock_minikube_installed, mock_kind_installed): @@ -207,7 +207,7 @@ async def test_detect_existing_cluster_type_kind_not_installed(self, mock_miniku assert result is None - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.kind_installed") @patch("jumpstarter_kubernetes.cluster.detection.minikube_installed") @patch("jumpstarter_kubernetes.cluster.detection.kind_cluster_exists") @@ -274,37 +274,37 @@ def test_auto_detect_cluster_type_none_available(self, mock_minikube_installed, class TestDetectClusterType: """Test cluster type detection from context and server URL.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_kind_context_prefix(self): result = await detect_cluster_type("kind-test-cluster", "https://127.0.0.1:6443") assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_kind_context_name(self): result = await detect_cluster_type("kind", "https://127.0.0.1:6443") assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_minikube_context(self): result = await detect_cluster_type("minikube", "https://192.168.49.2:8443") assert result == "minikube" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_localhost(self): result = await detect_cluster_type("local-cluster", "https://localhost:6443") assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_127_0_0_1(self): result = await detect_cluster_type("local-cluster", "https://127.0.0.1:6443") assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_0_0_0_0(self): result = await detect_cluster_type("local-cluster", "https://0.0.0.0:6443") assert result == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_cluster_type_minikube_ip_range_192(self, mock_run_command): mock_run_command.return_value = (0, '{"valid": [{"Name": "test"}]}', "") @@ -313,7 +313,7 @@ async def test_detect_cluster_type_minikube_ip_range_192(self, mock_run_command) assert result == "minikube" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_cluster_type_minikube_ip_range_172(self, mock_run_command): mock_run_command.return_value = (0, '{"valid": [{"Name": "test"}]}', "") @@ -322,7 +322,7 @@ async def test_detect_cluster_type_minikube_ip_range_172(self, mock_run_command) assert result == "minikube" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_cluster_type_minikube_ip_no_profiles(self, mock_run_command): mock_run_command.return_value = (1, "", "error") @@ -331,7 +331,7 @@ async def test_detect_cluster_type_minikube_ip_no_profiles(self, mock_run_comman assert result == "remote" # Falls back to remote if no minikube profiles - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_cluster_type_minikube_invalid_json(self, mock_run_command): mock_run_command.return_value = (0, "invalid json", "") @@ -340,7 +340,7 @@ async def test_detect_cluster_type_minikube_invalid_json(self, mock_run_command) assert result == "remote" # Falls back to remote if JSON parsing fails - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.detection.run_command") async def test_detect_cluster_type_minikube_runtime_error(self, mock_run_command): mock_run_command.side_effect = RuntimeError("Command failed") @@ -349,12 +349,12 @@ async def test_detect_cluster_type_minikube_runtime_error(self, mock_run_command assert result == "remote" # Falls back to remote if command fails - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_remote(self): result = await detect_cluster_type("production-cluster", "https://k8s.example.com:443") assert result == "remote" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_detect_cluster_type_custom_minikube_binary(self): result = await detect_cluster_type("test-cluster", "https://example.com", minikube="custom-minikube") assert result == "remote" diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/endpoints_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/endpoints_test.py index 91f003574..42e27b858 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/endpoints_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/endpoints_test.py @@ -11,7 +11,7 @@ class TestGetIpGeneric: """Test generic IP address retrieval.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.minikube_installed") @patch("jumpstarter_kubernetes.cluster.endpoints.get_minikube_ip") async def test_get_ip_generic_minikube_success(self, mock_get_minikube_ip, mock_minikube_installed): @@ -23,7 +23,7 @@ async def test_get_ip_generic_minikube_success(self, mock_get_minikube_ip, mock_ assert result == "192.168.49.2" mock_get_minikube_ip.assert_called_once_with("test-cluster", "minikube") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.minikube_installed") async def test_get_ip_generic_minikube_not_installed(self, mock_minikube_installed): from jumpstarter_kubernetes.exceptions import ToolNotInstalledError @@ -33,7 +33,7 @@ async def test_get_ip_generic_minikube_not_installed(self, mock_minikube_install with pytest.raises(ToolNotInstalledError, match="minikube is not installed"): await get_ip_generic("minikube", "minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.minikube_installed") @patch("jumpstarter_kubernetes.cluster.endpoints.get_minikube_ip") async def test_get_ip_generic_minikube_ip_error(self, mock_get_minikube_ip, mock_minikube_installed): @@ -44,7 +44,7 @@ async def test_get_ip_generic_minikube_ip_error(self, mock_get_minikube_ip, mock with pytest.raises(EndpointConfigurationError, match="Could not determine Minikube IP address"): await get_ip_generic("minikube", "minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_address") async def test_get_ip_generic_kind_success(self, mock_get_ip_address): mock_get_ip_address.return_value = "10.0.0.100" @@ -53,7 +53,7 @@ async def test_get_ip_generic_kind_success(self, mock_get_ip_address): assert result == "10.0.0.100" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_address") async def test_get_ip_generic_kind_zero_ip(self, mock_get_ip_address): @@ -62,7 +62,7 @@ async def test_get_ip_generic_kind_zero_ip(self, mock_get_ip_address): with pytest.raises(EndpointConfigurationError, match="Could not determine IP address"): await get_ip_generic("kind", "minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_address") async def test_get_ip_generic_none_cluster_type(self, mock_get_ip_address): mock_get_ip_address.return_value = "192.168.1.100" @@ -71,7 +71,7 @@ async def test_get_ip_generic_none_cluster_type(self, mock_get_ip_address): assert result == "192.168.1.100" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_address") async def test_get_ip_generic_other_cluster_type(self, mock_get_ip_address): mock_get_ip_address.return_value = "172.16.0.50" @@ -84,7 +84,7 @@ async def test_get_ip_generic_other_cluster_type(self, mock_get_ip_address): class TestConfigureEndpoints: """Test endpoint configuration.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_all_provided(self, mock_get_ip_generic): # When all parameters are provided, get_ip_generic should not be called @@ -105,7 +105,7 @@ async def test_configure_endpoints_all_provided(self, mock_get_ip_generic): assert router_endpoint == "router.test.example.com:9001" mock_get_ip_generic.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_no_ip_provided(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "192.168.49.2" @@ -127,7 +127,7 @@ async def test_configure_endpoints_no_ip_provided(self, mock_get_ip_generic): assert router_endpoint == "router.test.example.com:9001" mock_get_ip_generic.assert_called_once_with("minikube", "minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_no_basedomain_provided(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "10.0.0.100" @@ -148,7 +148,7 @@ async def test_configure_endpoints_no_basedomain_provided(self, mock_get_ip_gene assert grpc_endpoint == "grpc.test.example.com:9000" assert router_endpoint == "router.test.example.com:9001" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_no_grpc_endpoint_provided(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "10.0.0.100" @@ -169,7 +169,7 @@ async def test_configure_endpoints_no_grpc_endpoint_provided(self, mock_get_ip_g assert grpc_endpoint == "grpc.jumpstarter.10.0.0.100.nip.io:8082" assert router_endpoint == "router.test.example.com:9001" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_no_router_endpoint_provided(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "10.0.0.100" @@ -190,7 +190,7 @@ async def test_configure_endpoints_no_router_endpoint_provided(self, mock_get_ip assert grpc_endpoint == "grpc.jumpstarter.10.0.0.100.nip.io:8082" assert router_endpoint == "router.jumpstarter.10.0.0.100.nip.io:8083" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_all_defaults(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "192.168.1.50" @@ -212,7 +212,7 @@ async def test_configure_endpoints_all_defaults(self, mock_get_ip_generic): assert router_endpoint == "router.jumpstarter.192.168.1.50.nip.io:8083" mock_get_ip_generic.assert_called_once_with("minikube", "minikube", "my-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_custom_basedomain_with_defaults(self, mock_get_ip_generic): mock_get_ip_generic.return_value = "172.16.0.1" @@ -233,7 +233,7 @@ async def test_configure_endpoints_custom_basedomain_with_defaults(self, mock_ge assert grpc_endpoint == "grpc.custom.domain.io:8082" assert router_endpoint == "router.custom.domain.io:8083" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_ip_provided_no_auto_detection(self, mock_get_ip_generic): result = await configure_endpoints( @@ -253,7 +253,7 @@ async def test_configure_endpoints_ip_provided_no_auto_detection(self, mock_get_ assert router_endpoint == "router.jumpstarter.192.168.100.50.nip.io:8083" mock_get_ip_generic.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.endpoints.get_ip_generic") async def test_configure_endpoints_ip_detection_error_propagates(self, mock_get_ip_generic): mock_get_ip_generic.side_effect = EndpointConfigurationError("IP detection failed") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kind_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kind_test.py index a26df7102..784db9ec0 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kind_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kind_test.py @@ -38,7 +38,7 @@ def test_kind_installed_custom_binary(self, mock_which): class TestKindClusterExists: """Test Kind cluster existence checking.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.run_command") async def test_kind_cluster_exists_true(self, mock_run_command, mock_kind_installed): @@ -50,7 +50,7 @@ async def test_kind_cluster_exists_true(self, mock_run_command, mock_kind_instal assert result is True mock_run_command.assert_called_once_with(["kind", "get", "kubeconfig", "--name", "test-cluster"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.run_command") async def test_kind_cluster_exists_false(self, mock_run_command, mock_kind_installed): @@ -61,7 +61,7 @@ async def test_kind_cluster_exists_false(self, mock_run_command, mock_kind_insta assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") async def test_kind_cluster_exists_kind_not_installed(self, mock_kind_installed): mock_kind_installed.return_value = False @@ -70,7 +70,7 @@ async def test_kind_cluster_exists_kind_not_installed(self, mock_kind_installed) assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.run_command") async def test_kind_cluster_exists_runtime_error(self, mock_run_command, mock_kind_installed): @@ -81,7 +81,7 @@ async def test_kind_cluster_exists_runtime_error(self, mock_run_command, mock_ki assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.run_command") async def test_kind_cluster_exists_custom_binary(self, mock_run_command, mock_kind_installed): @@ -97,7 +97,7 @@ async def test_kind_cluster_exists_custom_binary(self, mock_run_command, mock_ki class TestCreateKindCluster: """Test Kind cluster creation.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -117,7 +117,7 @@ async def test_create_kind_cluster_success(self, mock_run_command, mock_cluster_ assert "--name" in args assert "test-cluster" in args - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") async def test_create_kind_cluster_not_installed(self, mock_kind_installed): mock_kind_installed.return_value = False @@ -125,7 +125,7 @@ async def test_create_kind_cluster_not_installed(self, mock_kind_installed): with pytest.raises(RuntimeError, match="kind is not installed"): await create_kind_cluster("kind", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") async def test_create_kind_cluster_already_exists(self, mock_cluster_exists, mock_kind_installed): @@ -138,7 +138,7 @@ async def test_create_kind_cluster_already_exists(self, mock_cluster_exists, moc assert exc_info.value.cluster_name == "test-cluster" assert exc_info.value.cluster_type == "kind" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.delete_kind_cluster") @@ -156,7 +156,7 @@ async def test_create_kind_cluster_force_recreate( assert result is True mock_delete.assert_called_once_with("kind", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -174,7 +174,7 @@ async def test_create_kind_cluster_with_extra_args( assert "--verbosity=1" in args - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -188,7 +188,7 @@ async def test_create_kind_cluster_command_failure( with pytest.raises(RuntimeError, match="Failed to create Kind cluster 'test-cluster'"): await create_kind_cluster("kind", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -209,7 +209,7 @@ async def test_create_kind_cluster_custom_binary( class TestDeleteKindCluster: """Test Kind cluster deletion.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -223,7 +223,7 @@ async def test_delete_kind_cluster_success(self, mock_run_command, mock_cluster_ assert result is True mock_run_command.assert_called_once_with(["kind", "delete", "cluster", "--name", "test-cluster"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") async def test_delete_kind_cluster_not_installed(self, mock_kind_installed): mock_kind_installed.return_value = False @@ -231,7 +231,7 @@ async def test_delete_kind_cluster_not_installed(self, mock_kind_installed): with pytest.raises(RuntimeError, match="kind is not installed"): await delete_kind_cluster("kind", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") async def test_delete_kind_cluster_already_deleted(self, mock_cluster_exists, mock_kind_installed): @@ -242,7 +242,7 @@ async def test_delete_kind_cluster_already_deleted(self, mock_cluster_exists, mo assert result is True # Already deleted, consider successful - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") @@ -256,7 +256,7 @@ async def test_delete_kind_cluster_command_failure( with pytest.raises(RuntimeError, match="Failed to delete Kind cluster 'test-cluster'"): await delete_kind_cluster("kind", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kind.kind_installed") @patch("jumpstarter_kubernetes.cluster.kind.kind_cluster_exists") @patch("jumpstarter_kubernetes.cluster.kind.run_command_with_output") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kubectl_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kubectl_test.py index f7d9c5ce8..86fa7f130 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kubectl_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/kubectl_test.py @@ -25,7 +25,7 @@ class TestCheckKubernetesAccess: """Test Kubernetes cluster access checking.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_kubernetes_access_success(self, mock_run_command): mock_run_command.return_value = (0, "cluster info", "") @@ -35,7 +35,7 @@ async def test_check_kubernetes_access_success(self, mock_run_command): assert result is True mock_run_command.assert_called_once_with(["kubectl", "cluster-info", "--request-timeout=5s"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_kubernetes_access_with_context(self, mock_run_command): mock_run_command.return_value = (0, "cluster info", "") @@ -47,7 +47,7 @@ async def test_check_kubernetes_access_with_context(self, mock_run_command): ["kubectl", "--context", "test-context", "cluster-info", "--request-timeout=5s"] ) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_kubernetes_access_custom_kubectl(self, mock_run_command): mock_run_command.return_value = (0, "cluster info", "") @@ -57,7 +57,7 @@ async def test_check_kubernetes_access_custom_kubectl(self, mock_run_command): assert result is True mock_run_command.assert_called_once_with(["custom-kubectl", "cluster-info", "--request-timeout=5s"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_kubernetes_access_failure(self, mock_run_command): mock_run_command.return_value = (1, "", "connection refused") @@ -66,7 +66,7 @@ async def test_check_kubernetes_access_failure(self, mock_run_command): assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_kubernetes_access_runtime_error(self, mock_run_command): mock_run_command.side_effect = RuntimeError("Command failed") @@ -79,7 +79,7 @@ async def test_check_kubernetes_access_runtime_error(self, mock_run_command): class TestGetKubectlContexts: """Test kubectl context retrieval.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_success(self, mock_run_command): kubectl_config = { @@ -115,7 +115,7 @@ async def test_get_kubectl_contexts_success(self, mock_run_command): "current": False, } - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_with_namespace(self, mock_run_command): kubectl_config = { @@ -135,7 +135,7 @@ async def test_get_kubectl_contexts_with_namespace(self, mock_run_command): assert len(result) == 1 assert result[0]["namespace"] == "custom-ns" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_no_current_context(self, mock_run_command): kubectl_config = { @@ -149,7 +149,7 @@ async def test_get_kubectl_contexts_no_current_context(self, mock_run_command): assert len(result) == 1 assert result[0]["current"] is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_missing_cluster(self, mock_run_command): kubectl_config = { @@ -164,7 +164,7 @@ async def test_get_kubectl_contexts_missing_cluster(self, mock_run_command): assert len(result) == 1 assert result[0]["server"] == "" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_command_failure(self, mock_run_command): from jumpstarter_kubernetes.exceptions import KubeconfigError @@ -174,7 +174,7 @@ async def test_get_kubectl_contexts_command_failure(self, mock_run_command): with pytest.raises(KubeconfigError, match="Failed to get kubectl config: permission denied"): await get_kubectl_contexts() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_invalid_json(self, mock_run_command): from jumpstarter_kubernetes.exceptions import KubeconfigError @@ -184,7 +184,7 @@ async def test_get_kubectl_contexts_invalid_json(self, mock_run_command): with pytest.raises(KubeconfigError, match="Failed to parse kubectl config"): await get_kubectl_contexts() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_current_field_is_bool(self, mock_run_command): kubectl_config = { @@ -204,7 +204,7 @@ async def test_get_kubectl_contexts_current_field_is_bool(self, mock_run_command assert isinstance(result[1]["current"], bool) assert result[1]["current"] is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_custom_kubectl(self, mock_run_command): kubectl_config = {"contexts": [], "clusters": []} @@ -214,7 +214,7 @@ async def test_get_kubectl_contexts_custom_kubectl(self, mock_run_command): mock_run_command.assert_called_once_with(["custom-kubectl", "config", "view", "-o", "json"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_propagates_programming_errors(self, mock_run_command): mock_run_command.return_value = (0, '{"contexts": [], "clusters": []}', "") @@ -222,7 +222,7 @@ async def test_get_kubectl_contexts_propagates_programming_errors(self, mock_run with pytest.raises(TypeError, match="unexpected type"): await get_kubectl_contexts() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_kubectl_contexts_has_all_typed_keys(self, mock_run_command): kubectl_config = { @@ -241,7 +241,7 @@ async def test_get_kubectl_contexts_has_all_typed_keys(self, mock_run_command): class TestCheckCrInstances: """Test CR instance detection for Jumpstarter installation.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_found_with_namespace(self, mock_run_command): cr_response = {"items": [{"metadata": {"name": "jumpstarter", "namespace": "custom-ns"}}]} @@ -252,7 +252,7 @@ async def test_cr_instances_found_with_namespace(self, mock_run_command): assert result == {"installed": True, "namespace": "custom-ns", "status": "installed"} assert set(result.keys()) == set(CrInstanceSuccess.__annotations__.keys()) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_extracts_namespace_from_metadata(self, mock_run_command): cr_response = {"items": [{"metadata": {"name": "jumpstarter", "namespace": "from-cr"}}]} @@ -262,7 +262,7 @@ async def test_cr_instances_extracts_namespace_from_metadata(self, mock_run_comm assert result == {"installed": True, "namespace": "from-cr", "status": "installed"} - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_falls_back_to_parameter_namespace(self, mock_run_command): cr_response = {"items": [{"metadata": {"name": "jumpstarter"}}]} @@ -272,7 +272,7 @@ async def test_cr_instances_falls_back_to_parameter_namespace(self, mock_run_com assert result == {"installed": True, "namespace": "param-ns", "status": "installed"} - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_unknown_when_no_namespace_available(self, mock_run_command): cr_response = {"items": [{"metadata": {"name": "jumpstarter"}}]} @@ -282,7 +282,7 @@ async def test_cr_instances_unknown_when_no_namespace_available(self, mock_run_c assert result == {"installed": True, "namespace": "unknown", "status": "installed"} - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_empty_items(self, mock_run_command): cr_response = {"items": []} @@ -293,7 +293,7 @@ async def test_cr_instances_empty_items(self, mock_run_command): assert result == {"installed": False} assert set(result.keys()) == set(CrInstanceNotFound.__annotations__.keys()) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_missing_items_key(self, mock_run_command): cr_response = {"kind": "JumpstarterList"} @@ -304,7 +304,7 @@ async def test_cr_instances_missing_items_key(self, mock_run_command): assert result == {"installed": False} assert set(result.keys()) == set(CrInstanceNotFound.__annotations__.keys()) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_nonzero_return_code(self, mock_run_command): mock_run_command.return_value = (1, "", "forbidden") @@ -317,7 +317,7 @@ async def test_cr_instances_nonzero_return_code(self, mock_run_command): assert "forbidden" in result["error"] assert result["installed"] is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_json_decode_error(self, mock_run_command): mock_run_command.return_value = (0, "not valid json", "") @@ -329,7 +329,7 @@ async def test_cr_instances_json_decode_error(self, mock_run_command): assert "CR instance check failed" in result["error"] assert result["installed"] is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_cr_instances_runtime_error(self, mock_run_command): mock_run_command.side_effect = RuntimeError("kubectl not found") @@ -387,7 +387,7 @@ def test_apply_cr_result_not_found_leaves_data_unchanged(self): class TestCheckJumpstarterInstallation: """Test Jumpstarter installation checking via CRD detection.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_crds_only(self, mock_run_command): crds_response = {"items": [{"metadata": {"name": "exporters.jumpstarter.dev"}}]} @@ -399,7 +399,7 @@ async def test_check_jumpstarter_installation_crds_only(self, mock_run_command): assert result.has_crds is True assert result.installed is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_with_cr_instances(self, mock_run_command): crds_response = {"items": [ @@ -419,7 +419,7 @@ async def test_check_jumpstarter_installation_with_cr_instances(self, mock_run_c assert result.installed is True assert result.status == "installed" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_no_crds(self, mock_run_command): mock_run_command.return_value = (0, '{"items": []}', "") @@ -429,7 +429,7 @@ async def test_check_jumpstarter_installation_no_crds(self, mock_run_command): assert result.installed is False assert result.has_crds is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_command_failure(self, mock_run_command): mock_run_command.side_effect = RuntimeError("kubectl not found") @@ -441,7 +441,7 @@ async def test_check_jumpstarter_installation_command_failure(self, mock_run_com assert result.error is not None assert "kubectl not found" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_nonzero_exit(self, mock_run_command): mock_run_command.return_value = (1, "", "forbidden") @@ -453,7 +453,7 @@ async def test_check_jumpstarter_installation_nonzero_exit(self, mock_run_comman assert result.error is not None assert "forbidden" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_custom_namespace(self, mock_run_command): crds_response = {"items": [ @@ -472,7 +472,7 @@ async def test_check_jumpstarter_installation_custom_namespace(self, mock_run_co assert result.installed is True assert result.namespace == "custom-ns" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_json_decode_error(self, mock_run_command): mock_run_command.return_value = (0, "not valid json at all", "") @@ -483,7 +483,7 @@ async def test_check_jumpstarter_installation_json_decode_error(self, mock_run_c assert result.error is not None assert "Failed to parse output" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_stdout_without_json_prefix(self, mock_run_command): crds_json = json.dumps({"items": [{"metadata": {"name": "exporters.jumpstarter.dev"}}]}) @@ -493,7 +493,7 @@ async def test_check_jumpstarter_installation_stdout_without_json_prefix(self, m assert result.has_crds is True - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_stdout_with_warning_prefix(self, mock_run_command): crds_json = json.dumps({"items": [{"metadata": {"name": "exporters.jumpstarter.dev"}}]}) @@ -504,7 +504,7 @@ async def test_check_jumpstarter_installation_stdout_with_warning_prefix(self, m assert result.has_crds is True - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_cr_check_empty_items(self, mock_run_command): crds_response = {"items": [ @@ -523,7 +523,7 @@ async def test_check_jumpstarter_installation_cr_check_empty_items(self, mock_ru assert result.installed is False assert result.error is None - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_check_jumpstarter_installation_cr_check_error(self, mock_run_command): crds_response = {"items": [ @@ -546,7 +546,7 @@ async def test_check_jumpstarter_installation_cr_check_error(self, mock_run_comm class TestGetClusterInfo: """Test cluster info retrieval.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") @patch("jumpstarter_kubernetes.cluster.kubectl.check_jumpstarter_installation") @@ -577,7 +577,7 @@ async def test_get_cluster_info_success(self, mock_check_jumpstarter, mock_run_c assert result.version == "v1.28.0" assert result.jumpstarter.installed is True - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_get_cluster_info_inaccessible(self, mock_get_contexts): # Mock get_kubectl_contexts to fail @@ -589,7 +589,7 @@ async def test_get_cluster_info_inaccessible(self, mock_get_contexts): assert "Failed to get cluster info:" in result.error assert "Failed to get kubectl config: connection refused" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_get_cluster_info_invalid_json(self, mock_get_contexts): error_msg = "Failed to parse kubectl config: Expecting value: line 1 column 1 (char 0)" @@ -601,7 +601,7 @@ async def test_get_cluster_info_invalid_json(self, mock_get_contexts): assert "Failed to get cluster info" in result.error assert "Failed to parse kubectl config" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_get_cluster_info_context_not_found(self, mock_get_contexts): mock_get_contexts.return_value = [ @@ -621,7 +621,7 @@ async def test_get_cluster_info_context_not_found(self, mock_get_contexts): assert result.accessible is False assert "not found" in result.error - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") @patch("jumpstarter_kubernetes.cluster.kubectl.check_jumpstarter_installation") @@ -647,7 +647,7 @@ async def test_get_cluster_info_inaccessible_cluster( assert result.jumpstarter.error == "Cluster not accessible" mock_check_jumpstarter.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") @patch("jumpstarter_kubernetes.cluster.kubectl.check_jumpstarter_installation") @@ -672,7 +672,7 @@ async def test_get_cluster_info_version_parse_failure( assert result.accessible is True assert result.version == "unknown" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.run_command") async def test_get_cluster_info_version_command_runtime_error( @@ -695,7 +695,7 @@ async def test_get_cluster_info_version_command_runtime_error( assert result.accessible is False assert result.version is None - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_get_cluster_info_propagates_programming_errors(self, mock_get_contexts): mock_get_contexts.side_effect = TypeError("unexpected type") @@ -707,7 +707,7 @@ async def test_get_cluster_info_propagates_programming_errors(self, mock_get_con class TestListClusters: """Test cluster listing functionality.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.get_cluster_info") async def test_list_clusters_success(self, mock_get_cluster_info, mock_get_contexts): @@ -740,7 +740,7 @@ async def test_list_clusters_success(self, mock_get_cluster_info, mock_get_conte assert len(result.items) == 1 assert result.items[0].name == "test-context" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_list_clusters_no_contexts(self, mock_get_contexts): mock_get_contexts.return_value = [] @@ -749,7 +749,7 @@ async def test_list_clusters_no_contexts(self, mock_get_contexts): assert len(result.items) == 0 - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_list_clusters_context_error(self, mock_get_contexts): mock_get_contexts.side_effect = JumpstarterKubernetesError("No kubeconfig found") @@ -759,7 +759,7 @@ async def test_list_clusters_context_error(self, mock_get_contexts): assert len(result.items) == 1 assert result.items[0].error == "Failed to list clusters: No kubeconfig found" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.get_cluster_info") async def test_list_clusters_custom_parameters(self, mock_get_cluster_info, mock_get_contexts): @@ -784,7 +784,7 @@ async def test_list_clusters_custom_parameters(self, mock_get_cluster_info, mock mock_get_contexts.assert_called_once_with("custom-kubectl") mock_get_cluster_info.assert_called_once_with("ctx", "custom-kubectl", "custom-minikube") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") @patch("jumpstarter_kubernetes.cluster.kubectl.get_cluster_info") async def test_list_clusters_with_type_filter(self, mock_get_cluster_info, mock_get_contexts): @@ -822,7 +822,7 @@ async def test_list_clusters_with_type_filter(self, mock_get_cluster_info, mock_ assert len(result.items) == 1 assert result.items[0].name == "kind-ctx" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.kubectl.get_kubectl_contexts") async def test_list_clusters_propagates_programming_errors(self, mock_get_contexts): mock_get_contexts.side_effect = TypeError("unexpected type") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/minikube_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/minikube_test.py index ce624e242..1bf7d8c5f 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/minikube_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/minikube_test.py @@ -37,7 +37,7 @@ def test_minikube_installed_custom_binary(self, mock_which): class TestMinikubeClusterExists: """Test Minikube cluster existence checking.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.run_command") async def test_minikube_cluster_exists_true(self, mock_run_command, mock_minikube_installed): @@ -55,7 +55,7 @@ async def test_minikube_cluster_exists_true(self, mock_run_command, mock_minikub # Should call profile list first mock_run_command.assert_called_with(["minikube", "profile", "list", "-o", "json"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.run_command") async def test_minikube_cluster_exists_false(self, mock_run_command, mock_minikube_installed): @@ -66,7 +66,7 @@ async def test_minikube_cluster_exists_false(self, mock_run_command, mock_miniku assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") async def test_minikube_cluster_exists_minikube_not_installed(self, mock_minikube_installed): mock_minikube_installed.return_value = False @@ -75,7 +75,7 @@ async def test_minikube_cluster_exists_minikube_not_installed(self, mock_minikub assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.run_command") async def test_minikube_cluster_exists_runtime_error(self, mock_run_command, mock_minikube_installed): @@ -87,7 +87,7 @@ async def test_minikube_cluster_exists_runtime_error(self, mock_run_command, moc assert result is False - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.run_command") async def test_minikube_cluster_exists_stopped_cluster(self, mock_run_command, mock_minikube_installed): @@ -104,7 +104,7 @@ async def test_minikube_cluster_exists_stopped_cluster(self, mock_run_command, m assert result is True - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.run_command") async def test_minikube_cluster_exists_custom_binary(self, mock_run_command, mock_minikube_installed): @@ -128,7 +128,7 @@ async def test_minikube_cluster_exists_custom_binary(self, mock_run_command, moc class TestCreateMinikubeCluster: """Test Minikube cluster creation.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -150,7 +150,7 @@ async def test_create_minikube_cluster_success( assert "test-cluster" in args assert "--extra-config=apiserver.service-node-port-range=30000-32767" in args - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") async def test_create_minikube_cluster_not_installed(self, mock_minikube_installed): from jumpstarter_kubernetes.exceptions import ToolNotInstalledError @@ -160,7 +160,7 @@ async def test_create_minikube_cluster_not_installed(self, mock_minikube_install with pytest.raises(ToolNotInstalledError): await create_minikube_cluster("minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") async def test_create_minikube_cluster_already_exists(self, mock_cluster_exists, mock_minikube_installed): @@ -171,7 +171,7 @@ async def test_create_minikube_cluster_already_exists(self, mock_cluster_exists, result = await create_minikube_cluster("minikube", "test-cluster") assert result is True - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -188,7 +188,7 @@ async def test_create_minikube_cluster_with_extra_args( args = mock_run_command.call_args[0][0] assert "--memory=4096" in args - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -204,7 +204,7 @@ async def test_create_minikube_cluster_command_failure( with pytest.raises(ClusterOperationError): await create_minikube_cluster("minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -225,7 +225,7 @@ async def test_create_minikube_cluster_custom_binary( class TestDeleteMinikubeCluster: """Test Minikube cluster deletion.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -241,7 +241,7 @@ async def test_delete_minikube_cluster_success( assert result is True mock_run_command.assert_called_once_with(["minikube", "delete", "-p", "test-cluster"]) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") async def test_delete_minikube_cluster_not_installed(self, mock_minikube_installed): from jumpstarter_kubernetes.exceptions import ToolNotInstalledError @@ -251,7 +251,7 @@ async def test_delete_minikube_cluster_not_installed(self, mock_minikube_install with pytest.raises(ToolNotInstalledError): await delete_minikube_cluster("minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") async def test_delete_minikube_cluster_already_deleted(self, mock_cluster_exists, mock_minikube_installed): @@ -262,7 +262,7 @@ async def test_delete_minikube_cluster_already_deleted(self, mock_cluster_exists assert result is True # Already deleted, consider successful - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") @@ -278,7 +278,7 @@ async def test_delete_minikube_cluster_failure( with pytest.raises(ClusterOperationError): await delete_minikube_cluster("minikube", "test-cluster") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.minikube.minikube_installed") @patch("jumpstarter_kubernetes.cluster.minikube.minikube_cluster_exists") @patch("jumpstarter_kubernetes.cluster.minikube.run_command_with_output") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operations_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operations_test.py index 9aba7ed03..18e4b638a 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operations_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operations_test.py @@ -15,7 +15,7 @@ class TestDeleteClusterByName: """Test cluster deletion by name.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.detect_existing_cluster_type") @patch("jumpstarter_kubernetes.cluster.operations.delete_kind_cluster_with_feedback") async def test_delete_cluster_by_name_kind(self, mock_delete_kind, mock_detect): @@ -27,7 +27,7 @@ async def test_delete_cluster_by_name_kind(self, mock_delete_kind, mock_detect): mock_detect.assert_called_once_with("test-cluster") mock_delete_kind.assert_called_once_with("kind", "test-cluster", ANY) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.detect_existing_cluster_type") @patch("jumpstarter_kubernetes.cluster.operations.delete_minikube_cluster_with_feedback") async def test_delete_cluster_by_name_minikube(self, mock_delete_minikube, mock_detect): @@ -39,7 +39,7 @@ async def test_delete_cluster_by_name_minikube(self, mock_delete_minikube, mock_ mock_detect.assert_called_once_with("test-cluster") mock_delete_minikube.assert_called_once_with("minikube", "test-cluster", ANY) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.detect_existing_cluster_type") async def test_delete_cluster_by_name_not_found(self, mock_detect): mock_detect.return_value = None @@ -47,7 +47,7 @@ async def test_delete_cluster_by_name_not_found(self, mock_detect): with pytest.raises(ClusterNotFoundError): await delete_cluster_by_name("test-cluster", force=True) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.detect_existing_cluster_type") @patch("jumpstarter_kubernetes.cluster.operations.kind_installed") @patch("jumpstarter_kubernetes.cluster.operations.kind_cluster_exists") @@ -66,7 +66,7 @@ async def test_delete_cluster_by_name_with_type( mock_cluster_exists.assert_called_once_with("kind", "test-cluster") mock_delete_kind.assert_called_once_with("kind", "test-cluster", ANY) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_delete_cluster_unsupported_type_explicit(self): """Test that explicitly specifying an unsupported cluster type raises ClusterTypeValidationError.""" with pytest.raises(ClusterTypeValidationError) as exc_info: @@ -76,7 +76,7 @@ async def test_delete_cluster_unsupported_type_explicit(self): assert "kind" in str(exc_info.value) assert "minikube" in str(exc_info.value) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.detect_existing_cluster_type") async def test_delete_cluster_unsupported_type_auto_detected(self, mock_detect): """Test that auto-detecting an unsupported cluster type raises ClusterTypeValidationError.""" @@ -93,7 +93,7 @@ async def test_delete_cluster_unsupported_type_auto_detected(self, mock_detect): class TestCreateClusterOnly: """Test cluster-only creation.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_cluster_and_install") async def test_create_cluster_only_kind(self, mock_create_and_install): mock_create_and_install.return_value = None @@ -104,7 +104,7 @@ async def test_create_cluster_only_kind(self, mock_create_and_install): "kind", False, "test-cluster", "", "", "kind", "minikube", None, install_jumpstarter=False, callback=None ) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_cluster_and_install") async def test_create_cluster_only_minikube(self, mock_create_and_install): mock_create_and_install.return_value = None @@ -128,7 +128,7 @@ async def test_create_cluster_only_minikube(self, mock_create_and_install): class TestCreateClusterAndInstall: """Test cluster creation with operator installation.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_kind_cluster_with_options") @patch("jumpstarter_kubernetes.cluster.operations.configure_endpoints") @patch("jumpstarter_kubernetes.cluster.operations.install_jumpstarter_operator") @@ -154,7 +154,7 @@ async def test_create_cluster_and_install_success( callback=ANY, ) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_kind_cluster_with_options") @patch("jumpstarter_kubernetes.cluster.operations.configure_endpoints") async def test_create_cluster_and_install_no_version(self, mock_configure, mock_create): @@ -166,7 +166,7 @@ async def test_create_cluster_and_install_no_version(self, mock_configure, mock_ with pytest.raises(ClusterOperationError): await create_cluster_and_install("kind", False, "test-cluster", "", "", "kind", "minikube") - @pytest.mark.asyncio + @pytest.mark.anyio async def test_create_cluster_and_install_unsupported_cluster_type(self): """Test that creating a cluster with an unsupported cluster type raises ClusterTypeValidationError.""" with pytest.raises(ClusterTypeValidationError) as exc_info: @@ -174,7 +174,7 @@ async def test_create_cluster_and_install_unsupported_cluster_type(self): assert "Unsupported cluster_type: remote" in str(exc_info.value) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_kind_cluster_with_options") @patch("jumpstarter_kubernetes.cluster.operations.configure_endpoints") @patch("jumpstarter_kubernetes.cluster.operations.install_jumpstarter_operator") @@ -195,7 +195,7 @@ async def test_create_cluster_and_install_force_recreate_confirmed( mock_create.assert_called_once() mock_install.assert_called_once() - @pytest.mark.asyncio + @pytest.mark.anyio async def test_create_cluster_and_install_force_recreate_cancelled(self): from jumpstarter_kubernetes.exceptions import ClusterOperationError @@ -212,7 +212,7 @@ def confirm(self, prompt): return False version="1.0.0", callback=RejectingCallback(), ) - @pytest.mark.asyncio + @pytest.mark.anyio @patch("jumpstarter_kubernetes.cluster.operations.create_kind_cluster_with_options") @patch("jumpstarter_kubernetes.cluster.operations.configure_endpoints") @patch("jumpstarter_kubernetes.cluster.operations.install_jumpstarter_operator") diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operator.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operator.py index 7987682a8..2f764c974 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operator.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/cluster/operator.py @@ -1,8 +1,10 @@ """Operator-based Jumpstarter installation.""" -import asyncio +from subprocess import PIPE from typing import Literal, Optional +import anyio + from ..callbacks import OutputCallback, SilentCallback from ..exceptions import ClusterOperationError from .common import GRPC_NODEPORT, LOGIN_NODEPORT, ROUTER_NODEPORT, run_command, run_command_with_output @@ -231,16 +233,11 @@ async def apply_jumpstarter_cr( returncode, ns_yaml, _ = await run_command(cmd) if returncode == 0: apply_cmd = _kubectl_base(kubeconfig, context) + ["apply", "-f", "-"] - process = await asyncio.create_subprocess_exec( - *apply_cmd, stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - _, ns_stderr = await process.communicate(input=ns_yaml.encode()) - if process.returncode != 0: + result = await anyio.run_process(apply_cmd, input=ns_yaml.encode(), stdout=PIPE, stderr=PIPE) + if result.returncode != 0: raise ClusterOperationError( "install", "jumpstarter", "operator", - Exception(f"Failed to create namespace {namespace}: {ns_stderr.decode(errors='replace')}"), + Exception(f"Failed to create namespace {namespace}: {result.stderr.decode(errors='replace')}"), ) # Build and apply the CR @@ -248,18 +245,12 @@ async def apply_jumpstarter_cr( callback.progress("Applying Jumpstarter CR...") apply_cmd = _kubectl_base(kubeconfig, context) + ["apply", "-f", "-"] - process = await asyncio.create_subprocess_exec( - *apply_cmd, - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await process.communicate(input=cr_yaml.encode()) + result = await anyio.run_process(apply_cmd, input=cr_yaml.encode(), stdout=PIPE, stderr=PIPE, check=False) - if process.returncode != 0: + if result.returncode != 0: raise ClusterOperationError( "install", "jumpstarter", "operator", - Exception(f"Failed to apply Jumpstarter CR: {stderr.decode(errors='replace')}"), + Exception(f"Failed to apply Jumpstarter CR: {result.stderr.decode(errors='replace')}"), ) callback.success("Jumpstarter CR applied") @@ -291,7 +282,7 @@ async def wait_for_jumpstarter_ready( returncode, _, _ = await run_command(cmd) if returncode == 0: break - await asyncio.sleep(poll_interval) + await anyio.sleep(poll_interval) else: raise ClusterOperationError( "install", "jumpstarter", "operator", diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/controller_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/controller_test.py index 336dfe51a..5e6429aff 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/controller_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/controller_test.py @@ -11,7 +11,7 @@ class TestGetLatestCompatibleControllerVersion: """Test controller version resolution from Quay.io API.""" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_requests_correct_url(self, mock_session_class): tags_response = {"tags": [{"name": "v0.5.0"}]} @@ -40,7 +40,7 @@ def capture_get(url, **kwargs): assert captured_url == "https://quay.io/api/v1/repository/jumpstarter-dev/jumpstarter-operator/tag/" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_returns_compatible_version(self, mock_session_class): tags_response = { @@ -66,7 +66,7 @@ async def test_returns_compatible_version(self, mock_session_class): assert result == "v0.5.2" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_falls_back_to_latest_when_no_compatible(self, mock_session_class): tags_response = { @@ -90,7 +90,7 @@ async def test_falls_back_to_latest_when_no_compatible(self, mock_session_class) assert result == "v0.7.0" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_returns_latest_when_no_client_version(self, mock_session_class): tags_response = { @@ -114,7 +114,7 @@ async def test_returns_latest_when_no_client_version(self, mock_session_class): assert result == "v0.6.0" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_skips_invalid_semver_tags(self, mock_session_class): tags_response = { @@ -139,7 +139,7 @@ async def test_skips_invalid_semver_tags(self, mock_session_class): assert result == "v0.5.0" - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_raises_on_unexpected_response_format(self, mock_session_class): mock_response = AsyncMock() @@ -156,7 +156,7 @@ async def test_raises_on_unexpected_response_format(self, mock_session_class): with pytest.raises(JumpstarterKubernetesError, match="Unexpected response"): await get_latest_compatible_controller_version("v0.5.0") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_raises_on_no_valid_versions(self, mock_session_class): tags_response = {"tags": [{"name": "latest"}]} @@ -174,12 +174,12 @@ async def test_raises_on_no_valid_versions(self, mock_session_class): with pytest.raises(JumpstarterKubernetesError, match="No valid controller versions"): await get_latest_compatible_controller_version("v0.5.0") - @pytest.mark.asyncio + @pytest.mark.anyio async def test_raises_on_invalid_client_version(self): with pytest.raises(JumpstarterKubernetesError, match="Invalid client version"): await get_latest_compatible_controller_version("not-a-version") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_raises_on_fetch_failure(self, mock_session_class): mock_session = AsyncMock() @@ -191,7 +191,7 @@ async def test_raises_on_fetch_failure(self, mock_session_class): with pytest.raises(JumpstarterKubernetesError, match="Failed to fetch controller versions"): await get_latest_compatible_controller_version("v0.5.0") - @pytest.mark.asyncio + @pytest.mark.anyio @patch("aiohttp.ClientSession") async def test_skips_malformed_tag_entries(self, mock_session_class): tags_response = { diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters.py index d5b10e278..4bc87fc9e 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters.py @@ -1,7 +1,7 @@ -import asyncio import base64 from typing import Literal +import anyio from kubernetes_asyncio.client.models import V1ObjectMeta, V1ObjectReference from pydantic import Field @@ -179,7 +179,7 @@ async def create_exporter( if "credential" in updated_exporter["status"]: return V1Alpha1Exporter.from_dict(updated_exporter) count += 1 - await asyncio.sleep(CREATE_EXPORTER_DELAY) + await anyio.sleep(CREATE_EXPORTER_DELAY) raise Exception("Timeout waiting for exporter credentials") async def get_exporter_config(self, name: str) -> ExporterConfigV1Alpha1: diff --git a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters_test.py b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters_test.py index 7bb34c173..d7c5fd699 100644 --- a/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters_test.py +++ b/python/packages/jumpstarter-kubernetes/jumpstarter_kubernetes/exporters_test.py @@ -275,7 +275,7 @@ def test_exporter_list_rich_add_names(): # Tests for get_exporter_config with CA bundle -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_exporter_config_includes_ca_bundle(): """Test get_exporter_config includes CA bundle from ConfigMap""" api = ExportersV1Alpha1Api(namespace="test-namespace") @@ -321,7 +321,7 @@ async def test_get_exporter_config_includes_ca_bundle(): assert config.token == token -@pytest.mark.asyncio +@pytest.mark.anyio async def test_get_exporter_config_without_ca_bundle(): """Test get_exporter_config works when CA ConfigMap doesn't exist""" api = ExportersV1Alpha1Api(namespace="test-namespace") diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py index 5fefb2aa7..a80996ab4 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import json import logging import os @@ -496,7 +495,7 @@ async def run_server(): write_stream, mcp._mcp_server.create_initialization_options(), ) - except asyncio.CancelledError: + except anyio.get_cancelled_exc_class(): logger.info("MCP stdio session ended (cancelled)") except BaseException as exc: if isinstance(exc, ClosedResourceError): diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py index 48f239d5f..eeff533e0 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/server_test.py @@ -2,12 +2,11 @@ from __future__ import annotations -import asyncio -import asyncio.subprocess import logging import time from dataclasses import dataclass from datetime import datetime +from subprocess import DEVNULL from unittest.mock import AsyncMock, MagicMock, patch import click @@ -324,19 +323,24 @@ def manager_with_conn(self): manager._connections[conn.id] = conn return manager, conn.id - @pytest.mark.asyncio + @pytest.mark.anyio async def test_successful_command(self, manager_with_conn): + from subprocess import CompletedProcess + from jumpstarter_mcp.tools.commands import run_command manager, conn_id = manager_with_conn - mock_proc = AsyncMock() - mock_proc.communicate = AsyncMock(return_value=(b"hello\n", b"")) - mock_proc.returncode = 0 + mock_result = CompletedProcess( + args=["/usr/bin/j", "power", "on"], + returncode=0, + stdout=b"hello\n", + stderr=b"", + ) with ( patch("shutil.which", return_value="/usr/bin/j"), - patch("asyncio.create_subprocess_exec", return_value=mock_proc), + patch("jumpstarter_mcp.tools.commands.anyio.run_process", new_callable=AsyncMock, return_value=mock_result), ): result = await run_command(manager, conn_id, ["power", "on"]) @@ -344,37 +348,34 @@ async def test_successful_command(self, manager_with_conn): assert result["stdout"] == "hello\n" assert "timed_out" not in result - @pytest.mark.asyncio - async def test_timeout_captures_output(self, manager_with_conn): + @pytest.mark.anyio + async def test_timeout_discards_output(self, manager_with_conn): + """On timeout, output is discarded and timed_out flag is set. + + anyio.run_process is cancelled on timeout, so partial stdout/stderr + is not available. This is a design trade-off of using run_process + (which collects output atomically) vs open_process (which would + require manual stream management for partial capture). + """ from jumpstarter_mcp.tools.commands import run_command manager, conn_id = manager_with_conn - call_count = 0 + async def slow_run_process(*args, **kwargs): + import anyio - async def fake_communicate(): - nonlocal call_count - call_count += 1 # ty: ignore[unresolved-reference] - if call_count == 1: - await asyncio.sleep(999) - return (b"partial", b"err") - - mock_proc = AsyncMock() - mock_proc.communicate = fake_communicate - mock_proc.kill = lambda: None - mock_proc.returncode = -9 + await anyio.sleep(999) with ( patch("shutil.which", return_value="/usr/bin/j"), - patch("asyncio.create_subprocess_exec", return_value=mock_proc), + patch("jumpstarter_mcp.tools.commands.anyio.run_process", side_effect=slow_run_process), ): result = await run_command(manager, conn_id, ["serial", "pipe"], timeout_seconds=1) assert result["timed_out"] is True assert result["timeout_seconds"] == 1 - assert result["stdout"] == "partial" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_j_not_found(self, manager_with_conn): from jumpstarter_mcp.tools.commands import run_command @@ -386,25 +387,32 @@ async def test_j_not_found(self, manager_with_conn): assert "error" in result assert "not found" in result["error"] - @pytest.mark.asyncio + @pytest.mark.anyio async def test_subprocess_stdin_is_devnull(self, manager_with_conn): """Subprocess must not inherit MCP's stdin (would consume JSON-RPC input).""" + from subprocess import CompletedProcess + from jumpstarter_mcp.tools.commands import run_command manager, conn_id = manager_with_conn - mock_proc = AsyncMock() - mock_proc.communicate = AsyncMock(return_value=(b"ok\n", b"")) - mock_proc.returncode = 0 + mock_result = CompletedProcess( + args=["/usr/bin/j", "power", "on"], + returncode=0, + stdout=b"ok\n", + stderr=b"", + ) with ( patch("shutil.which", return_value="/usr/bin/j"), - patch("asyncio.create_subprocess_exec", return_value=mock_proc) as mock_exec, + patch( + "jumpstarter_mcp.tools.commands.anyio.run_process", new_callable=AsyncMock, return_value=mock_result + ) as mock_exec, ): await run_command(manager, conn_id, ["power", "on"]) _, kwargs = mock_exec.call_args - assert kwargs["stdin"] == asyncio.subprocess.DEVNULL + assert kwargs["stdin"] == DEVNULL # --------------------------------------------------------------------------- @@ -438,14 +446,14 @@ def _make_jwt_payload(exp: int | None = None, iss: str = "https://sso.example.co class TestEnsureFreshToken: - @pytest.mark.asyncio + @pytest.mark.anyio async def test_no_token_returns_config_unchanged(self): config = MagicMock() config.token = None result = await _ensure_fresh_token(config) assert result is config - @pytest.mark.asyncio + @pytest.mark.anyio async def test_valid_token_skips_refresh(self): future_exp = int(time.time()) + 3600 config = MagicMock() @@ -458,7 +466,7 @@ async def test_valid_token_skips_refresh(self): assert result is config mock_cls.save.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.anyio async def test_expired_token_no_refresh_token_skips(self): past_exp = int(time.time()) - 60 config = MagicMock() @@ -468,7 +476,7 @@ async def test_expired_token_no_refresh_token_skips(self): result = await _ensure_fresh_token(config) assert result is config - @pytest.mark.asyncio + @pytest.mark.anyio async def test_expired_token_refreshes_successfully(self): past_exp = int(time.time()) - 60 config = MagicMock() @@ -493,7 +501,7 @@ async def test_expired_token_refreshes_successfully(self): assert result.refresh_token == new_refresh mock_cls.save.assert_called_once_with(config) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_expired_token_refresh_updates_only_access_when_no_new_refresh(self): past_exp = int(time.time()) - 60 config = MagicMock() @@ -514,7 +522,7 @@ async def test_expired_token_refresh_updates_only_access_when_no_new_refresh(sel assert result.token == "new-access" assert result.refresh_token == "old-refresh" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_expired_token_refresh_failure_returns_config_unchanged(self): past_exp = int(time.time()) - 60 original_token = _make_jwt_payload(exp=past_exp) @@ -534,7 +542,7 @@ async def test_expired_token_refresh_failure_returns_config_unchanged(self): assert result is config mock_cls.save.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.anyio async def test_near_expiry_triggers_refresh(self): near_exp = int(time.time()) + TOKEN_REFRESH_THRESHOLD_SECONDS - 1 config = MagicMock() @@ -552,7 +560,7 @@ async def test_near_expiry_triggers_refresh(self): assert result.token == "refreshed" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_token_without_exp_claim_skips_refresh(self): config = MagicMock() config.token = _make_jwt_payload(exp=None) @@ -580,9 +588,7 @@ def test_adds_file_handler_to_root_logger(self, tmp_path): _setup_logging() new_file_handlers = [ - h - for h in root.handlers - if isinstance(h, logging.FileHandler) and h not in handlers_before + h for h in root.handlers if isinstance(h, logging.FileHandler) and h not in handlers_before ] assert len(new_file_handlers) == 1 assert "mcp-server.log" in new_file_handlers[0].baseFilename @@ -656,7 +662,7 @@ def test_stray_writes_do_not_reach_saved_stdout(self): # Apply the same redirect pattern as run_server(): sys.stdout.flush() - mcp_fd = os.dup(sys.stdout.fileno()) # save "real stdout" (pipe) + mcp_fd = os.dup(sys.stdout.fileno()) # save "real stdout" (pipe) os.dup2(sys.stderr.fileno(), sys.stdout.fileno()) # fd 1 -> stderr sys.stdout = sys.stderr diff --git a/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py b/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py index 8a8259da7..92f415e82 100644 --- a/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py +++ b/python/packages/jumpstarter-mcp/jumpstarter_mcp/tools/commands.py @@ -2,10 +2,9 @@ from __future__ import annotations -import asyncio -import asyncio.subprocess import logging import shutil +from subprocess import DEVNULL, PIPE import anyio import anyio.to_thread @@ -45,30 +44,28 @@ async def run_command( logger.info("Running command: j %s (timeout=%ds)", " ".join(command), timeout_seconds) try: - proc = await asyncio.create_subprocess_exec( - j_path, - *command, - stdin=asyncio.subprocess.DEVNULL, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - env=full_env, - ) - stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout_seconds) - logger.info("Command finished: j %s -> exit_code=%s", " ".join(command), proc.returncode) + with anyio.fail_after(timeout_seconds): + result = await anyio.run_process( + [j_path, *command], + stdin=DEVNULL, + stdout=PIPE, + stderr=PIPE, + env=full_env, + check=False, + ) + logger.info("Command finished: j %s -> exit_code=%s", " ".join(command), result.returncode) return { - "exit_code": proc.returncode, - "stdout": stdout.decode(errors="replace"), - "stderr": stderr.decode(errors="replace"), + "exit_code": result.returncode, + "stdout": result.stdout.decode(errors="replace"), + "stderr": result.stderr.decode(errors="replace"), "command": [j_path, *command], } except TimeoutError: - proc.kill() - stdout, stderr = await proc.communicate() logger.warning("Command timed out after %ds: j %s", timeout_seconds, " ".join(command)) return { - "exit_code": proc.returncode, - "stdout": stdout.decode(errors="replace"), - "stderr": stderr.decode(errors="replace"), + "exit_code": -1, + "stdout": "", + "stderr": "", "timed_out": True, "timeout_seconds": timeout_seconds, "command": [j_path, *command], diff --git a/python/packages/jumpstarter/jumpstarter/client/grpc_test.py b/python/packages/jumpstarter/jumpstarter/client/grpc_test.py index 20f4e904f..34402cdd5 100644 --- a/python/packages/jumpstarter/jumpstarter/client/grpc_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/grpc_test.py @@ -532,7 +532,7 @@ def test_rich_display_empty_tags(self): assert "TAGS" in columns -@pytest.mark.asyncio +@pytest.mark.anyio async def test_create_lease_sets_tags_on_protobuf(): from jumpstarter_protocol import client_pb2 diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py index 87a3f16be..71f13e0eb 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py @@ -1,9 +1,9 @@ -import asyncio import logging import sys from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, Mock, patch +import anyio import pytest from rich.console import Console @@ -174,7 +174,7 @@ def test_elapsed_time_formatting(self): assert "[dim](" in call_args assert "[/dim]" in call_args - @pytest.mark.asyncio + @pytest.mark.anyio async def test_integration_with_async_context(self): """Test integration with async context manager.""" with patch.object(LeaseAcquisitionSpinner, "_is_terminal_available", return_value=True): @@ -187,7 +187,7 @@ async def test_integration_with_async_context(self): async def test_async_usage(): with spinner as ctx_spinner: ctx_spinner.update_status("Initial message") - await asyncio.sleep(0.1) # Small delay + await anyio.sleep(0.1) ctx_spinner.tick() ctx_spinner.update_status("Updated message") @@ -547,7 +547,7 @@ async def get_then_fail(): # Keep the body alive long enough for the monitor to loop # through the first get(), sleep, second get() (fails), and # error handler using the cached end time. - await asyncio.sleep(0.2) + await anyio.sleep(0.2) # Should have gone through the error handler using cached end time assert call_count >= 2 diff --git a/python/packages/jumpstarter/jumpstarter/common/grpc.py b/python/packages/jumpstarter/jumpstarter/common/grpc.py index cf2f6690c..db2f49ab9 100644 --- a/python/packages/jumpstarter/jumpstarter/common/grpc.py +++ b/python/packages/jumpstarter/jumpstarter/common/grpc.py @@ -1,4 +1,3 @@ -import asyncio import base64 import logging import os @@ -8,8 +7,9 @@ from typing import Any, Sequence, Tuple from urllib.parse import urlparse +import anyio import grpc -from anyio import fail_after +from anyio import connect_tcp, fail_after from jumpstarter.common.exceptions import ConfigurationError, ConnectionError @@ -17,23 +17,27 @@ async def _try_connect_and_extract_cert( - ip_address: str, port: int, ssl_context: ssl.SSLContext, hostname: str, timeout: float + ip_address: str, port: int, ssl_context: ssl.SSLContext, hostname: str ) -> bytes: """ Try to connect to a single IP and extract its certificate chain. Returns the certificate chain in PEM format as bytes. Raises exception on failure. + + The caller is expected to enforce the overall timeout via an outer + fail_after scope so that a slow IP does not consume the entire budget. """ - logger.debug(f"Attempting TLS connection to {ip_address}:{port} (timeout={timeout}s)") - _, writer = await asyncio.wait_for( - asyncio.open_connection(ip_address, port, ssl=ssl_context, server_hostname=hostname), - timeout=timeout, - ) + logger.debug(f"Attempting TLS connection to {ip_address}:{port}") + stream = await connect_tcp(ip_address, port, tls=True, ssl_context=ssl_context, tls_hostname=hostname) logger.debug(f"Successfully connected to {ip_address}:{port}") try: - # Extract certificates - cert_chain = writer.get_extra_info("ssl_object")._sslobj.get_unverified_chain() + ssl_object = stream.extra(anyio.abc.TLSAttribute.ssl_object) + # CPython internal: _sslobj.get_unverified_chain() is not part of the + # public ssl module API. There is no public alternative for extracting + # the full certificate chain including untrusted intermediates. This + # will break on non-CPython implementations (PyPy, GraalPy). + cert_chain = ssl_object._sslobj.get_unverified_chain() root_certificates = "" for cert in cert_chain: root_certificates += cert.public_bytes() @@ -41,7 +45,7 @@ async def _try_connect_and_extract_cert( return root_certificates.encode() finally: - writer.close() + await stream.aclose() async def _ssl_channel_credentials_insecure(target: str, timeout: float) -> grpc.ChannelCredentials: # noqa: C901 @@ -63,64 +67,57 @@ async def _ssl_channel_credentials_insecure(target: str, timeout: float) -> grpc ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - # Resolve all IP addresses for the hostname - loop = asyncio.get_running_loop() - addr_info = await loop.getaddrinfo( + addr_info = await anyio.getaddrinfo( parsed.hostname, port, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM ) - # Log resolved IPs resolved_ips = [sockaddr[0] for _, _, _, _, sockaddr in addr_info] logger.debug( f"Resolved {parsed.hostname} to {len(resolved_ips)} IP(s): {', '.join(resolved_ips)}" ) - # Try all IPs in parallel - race for first success - # Wrap tasks to include IP info with results/exceptions + send_stream, receive_stream = anyio.create_memory_object_stream[ + tuple[str, bytes | None, Exception | None] + ](max_buffer_size=len(addr_info)) + async def try_with_ip(ip_address: str): - """Wrapper that returns (ip, result) on success or (ip, exception) on failure.""" try: result = await _try_connect_and_extract_cert( - ip_address, port, ssl_context, parsed.hostname, timeout + ip_address, port, ssl_context, parsed.hostname ) - return (ip_address, result, None) + await send_stream.send((ip_address, result, None)) except Exception as e: - return (ip_address, None, e) + await send_stream.send((ip_address, None, e)) - tasks = [] - for _family, _type, _proto, _canonname, sockaddr in addr_info: - ip_address = sockaddr[0] - task = asyncio.create_task(try_with_ip(ip_address)) - tasks.append(task) + async with anyio.create_task_group() as tg: + for _family, _type, _proto, _canonname, sockaddr in addr_info: + ip_address_str = sockaddr[0] + tg.start_soon(try_with_ip, ip_address_str) - # Process tasks as they complete - errors = {} + errors = {} + results_received = 0 + total_tasks = len(addr_info) - try: - for future in asyncio.as_completed(tasks): - ip_address, root_certificates, error = await future + while results_received < total_tasks: + ip_addr, root_certificates, error = await receive_stream.receive() + results_received += 1 if error is None: - # Success! Return immediately (cleanup in finally) - logger.debug(f"Using certificates from {ip_address}:{port}") + logger.debug(f"Using certificates from {ip_addr}:{port}") + tg.cancel_scope.cancel() return grpc.ssl_channel_credentials(root_certificates=root_certificates) - # This IP failed - log and continue trying other IPs if isinstance(error, ssl.SSLError): - logger.error(f"SSL error on {ip_address}:{port}: {error}") + logger.error(f"SSL error on {ip_addr}:{port}: {error}") else: - logger.warning(f"Failed to connect to {ip_address}:{port}: {type(error).__name__}: {error}") - errors[ip_address] = error + logger.warning( + f"Failed to connect to {ip_addr}:{port}: {type(error).__name__}: {error}" + ) + errors[ip_addr] = error - # All IPs failed raise ConnectionError( f"Failed connecting to {parsed.hostname}:{port} - all IPs exhausted. Errors: {errors}" ) - finally: - # Cancel any remaining tasks - for task in tasks: - if not task.done(): - task.cancel() except socket.gaierror as e: raise ConnectionError(f"Failed resolving {parsed.hostname}") from e except TimeoutError as e: @@ -156,7 +153,6 @@ def aio_secure_channel( def _override_default_grpc_options(grpc_options: dict[str, str | int] | None) -> Sequence[Tuple[str, Any]]: defaults = ( ("grpc.lb_policy_name", "round_robin"), - # we keep a low keepalive time to avoid idle timeouts on cloud load balancers ("grpc.keepalive_time_ms", 20000), ("grpc.keepalive_timeout_ms", 180000), ("grpc.http2.max_pings_without_data", 0), @@ -174,10 +170,8 @@ def translate_grpc_exceptions(): yield except grpc.aio.AioRpcError as e: if e.code().name == "UNAVAILABLE": - # tls or other connection errors raise ConnectionError(f"grpc error: {e.details()}") from None if e.code().name == "UNKNOWN": - # an error returned from our functions raise ConnectionError(f"grpc controller responded: {e.details()}") from None else: raise ConnectionError("grpc error") from e diff --git a/python/packages/jumpstarter/jumpstarter/common/ipaddr.py b/python/packages/jumpstarter/jumpstarter/common/ipaddr.py index 70944048d..9c4baa698 100644 --- a/python/packages/jumpstarter/jumpstarter/common/ipaddr.py +++ b/python/packages/jumpstarter/jumpstarter/common/ipaddr.py @@ -1,8 +1,9 @@ -import asyncio import logging import socket from ipaddress import ip_address +import anyio + def get_ip_address(logger: logging.Logger | None = None) -> str: """Get the IP address of the host machine""" @@ -26,22 +27,14 @@ def get_ip_address(logger: logging.Logger | None = None) -> str: return address -async def get_minikube_ip(profile: str = None, minikube: str = "minikube"): - # Create the subprocess with optional profile +async def get_minikube_ip(profile: str | None = None, minikube: str = "minikube"): cmd = [minikube, "ip"] if profile: cmd.extend(["-p", profile]) - process = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) - - # Wait for it to complete and get the output - stdout, stderr = await process.communicate() - - # Decode and strip whitespace - result = stdout.decode().strip() + result = await anyio.run_process(cmd) - # Optional: check if command was successful - if process.returncode != 0: - raise RuntimeError(stderr.decode()) + if result.returncode != 0: + raise RuntimeError(result.stderr.decode()) - return result + return result.stdout.decode().strip() diff --git a/python/packages/jumpstarter/jumpstarter/common/ipaddr_test.py b/python/packages/jumpstarter/jumpstarter/common/ipaddr_test.py index 7b07c6688..181bb8953 100644 --- a/python/packages/jumpstarter/jumpstarter/common/ipaddr_test.py +++ b/python/packages/jumpstarter/jumpstarter/common/ipaddr_test.py @@ -1,3 +1,4 @@ +from subprocess import CompletedProcess from unittest.mock import AsyncMock, patch import pytest @@ -8,70 +9,75 @@ class TestIPAddressDetection: """Test IP address detection functions.""" - @pytest.mark.asyncio - @patch("asyncio.create_subprocess_exec") - async def test_get_minikube_ip_success(self, mock_subprocess): - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"192.168.49.2\n", b"") - mock_process.returncode = 0 - mock_subprocess.return_value = mock_process + @pytest.mark.anyio + @patch("jumpstarter.common.ipaddr.anyio.run_process", new_callable=AsyncMock) + async def test_get_minikube_ip_success(self, mock_run_process): + mock_run_process.return_value = CompletedProcess( + args=["minikube", "ip"], + returncode=0, + stdout=b"192.168.49.2\n", + stderr=b"", + ) result = await get_minikube_ip() assert result == "192.168.49.2" - mock_subprocess.assert_called_once_with( - "minikube", - "ip", - stdout=-1, - stderr=-1, # asyncio.subprocess.PIPE constants + mock_run_process.assert_called_once_with(["minikube", "ip"]) + + @pytest.mark.anyio + @patch("jumpstarter.common.ipaddr.anyio.run_process", new_callable=AsyncMock) + async def test_get_minikube_ip_with_profile(self, mock_run_process): + mock_run_process.return_value = CompletedProcess( + args=["minikube", "ip", "-p", "test-profile"], + returncode=0, + stdout=b"192.168.49.3\n", + stderr=b"", ) - @pytest.mark.asyncio - @patch("asyncio.create_subprocess_exec") - async def test_get_minikube_ip_with_profile(self, mock_subprocess): - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"192.168.49.3\n", b"") - mock_process.returncode = 0 - mock_subprocess.return_value = mock_process - result = await get_minikube_ip("test-profile") assert result == "192.168.49.3" - mock_subprocess.assert_called_once_with("minikube", "ip", "-p", "test-profile", stdout=-1, stderr=-1) - - @pytest.mark.asyncio - @patch("asyncio.create_subprocess_exec") - async def test_get_minikube_ip_custom_binary(self, mock_subprocess): - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"10.0.0.5\n", b"") - mock_process.returncode = 0 - mock_subprocess.return_value = mock_process + mock_run_process.assert_called_once_with(["minikube", "ip", "-p", "test-profile"]) + + @pytest.mark.anyio + @patch("jumpstarter.common.ipaddr.anyio.run_process", new_callable=AsyncMock) + async def test_get_minikube_ip_custom_binary(self, mock_run_process): + mock_run_process.return_value = CompletedProcess( + args=["custom-minikube", "ip"], + returncode=0, + stdout=b"10.0.0.5\n", + stderr=b"", + ) result = await get_minikube_ip(minikube="custom-minikube") assert result == "10.0.0.5" - mock_subprocess.assert_called_once_with("custom-minikube", "ip", stdout=-1, stderr=-1) - - @pytest.mark.asyncio - @patch("asyncio.create_subprocess_exec") - async def test_get_minikube_ip_failure(self, mock_subprocess): - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"", b"error: cluster not found\n") - mock_process.returncode = 1 - mock_subprocess.return_value = mock_process + mock_run_process.assert_called_once_with(["custom-minikube", "ip"]) + + @pytest.mark.anyio + @patch("jumpstarter.common.ipaddr.anyio.run_process", new_callable=AsyncMock) + async def test_get_minikube_ip_failure(self, mock_run_process): + mock_run_process.return_value = CompletedProcess( + args=["minikube", "ip"], + returncode=1, + stdout=b"", + stderr=b"error: cluster not found\n", + ) with pytest.raises(RuntimeError, match="error: cluster not found"): await get_minikube_ip() - @pytest.mark.asyncio - @patch("asyncio.create_subprocess_exec") - async def test_get_minikube_ip_profile_and_custom_binary(self, mock_subprocess): - mock_process = AsyncMock() - mock_process.communicate.return_value = (b"172.16.0.1\n", b"") - mock_process.returncode = 0 - mock_subprocess.return_value = mock_process + @pytest.mark.anyio + @patch("jumpstarter.common.ipaddr.anyio.run_process", new_callable=AsyncMock) + async def test_get_minikube_ip_profile_and_custom_binary(self, mock_run_process): + mock_run_process.return_value = CompletedProcess( + args=["my-minikube", "ip", "-p", "my-profile"], + returncode=0, + stdout=b"172.16.0.1\n", + stderr=b"", + ) result = await get_minikube_ip("my-profile", "my-minikube") assert result == "172.16.0.1" - mock_subprocess.assert_called_once_with("my-minikube", "ip", "-p", "my-profile", stdout=-1, stderr=-1) + mock_run_process.assert_called_once_with(["my-minikube", "ip", "-p", "my-profile"]) diff --git a/python/packages/jumpstarter/jumpstarter/config/client.py b/python/packages/jumpstarter/jumpstarter/config/client.py index c2962d6d2..b7b457a6a 100644 --- a/python/packages/jumpstarter/jumpstarter/config/client.py +++ b/python/packages/jumpstarter/jumpstarter/config/client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import errno import os import tempfile @@ -10,7 +9,9 @@ from pathlib import Path from typing import Annotated, ClassVar, Literal, Optional, Self +import anyio import grpc +import sniffio import yaml from anyio.from_thread import BlockingPortal, start_blocking_portal from pydantic import ( @@ -41,9 +42,11 @@ def _blocking_compat(f): @wraps(f) def wrapper(*args, **kwargs): try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(f(*args, **kwargs)) + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + async def _run(): + return await f(*args, **kwargs) + return anyio.run(_run) else: return f(*args, **kwargs) diff --git a/python/packages/jumpstarter/jumpstarter/config/client_config_test.py b/python/packages/jumpstarter/jumpstarter/config/client_config_test.py index 6f502179c..4904c1980 100644 --- a/python/packages/jumpstarter/jumpstarter/config/client_config_test.py +++ b/python/packages/jumpstarter/jumpstarter/config/client_config_test.py @@ -413,7 +413,7 @@ def test_client_config_delete_does_not_exist_raises(): _get_path_mock.assert_called_once_with("xyz") -@pytest.mark.asyncio +@pytest.mark.anyio async def test_create_lease_passes_exporter_name(): config = ClientConfigV1Alpha1( alias="testclient", @@ -446,7 +446,7 @@ async def test_create_lease_passes_exporter_name(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_list_leases_paginates(): from jumpstarter.client.grpc import Lease, LeaseList @@ -495,7 +495,7 @@ async def test_list_leases_paginates(): assert calls[2].kwargs["page_token"] == "token2" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_list_leases_single_page(): from jumpstarter.client.grpc import Lease, LeaseList @@ -528,7 +528,7 @@ async def test_list_leases_single_page(): mock_service.ListLeases.assert_awaited_once() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_list_exporters_paginates(): from jumpstarter.client.grpc import Exporter, ExporterList @@ -571,7 +571,7 @@ async def test_list_exporters_paginates(): assert calls[1].kwargs["page_token"] == "tok1" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_list_exporters_with_leases_propagates_page_size(): from jumpstarter.client.grpc import Exporter, ExporterList, Lease, LeaseList diff --git a/python/packages/jumpstarter/jumpstarter/exporter/session_test.py b/python/packages/jumpstarter/jumpstarter/exporter/session_test.py index 9dcfcee5a..5a09ffe25 100644 --- a/python/packages/jumpstarter/jumpstarter/exporter/session_test.py +++ b/python/packages/jumpstarter/jumpstarter/exporter/session_test.py @@ -52,8 +52,12 @@ def test_get_report_includes_descriptions(): ) # Call GetReport - import asyncio - response = asyncio.run(session.GetReport(empty_pb2.Empty(), None)) + import anyio + + async def _get_report(): + return await session.GetReport(empty_pb2.Empty(), None) + + response = anyio.run(_get_report) # Build a map of uuid -> report for easy lookup reports_by_uuid = {r.uuid: r for r in response.reports} @@ -133,8 +137,12 @@ def test_empty_description_not_included(): root_device=driver, ) - import asyncio - response = asyncio.run(session.GetReport(empty_pb2.Empty(), None)) + import anyio + + async def _get_report(): + return await session.GetReport(empty_pb2.Empty(), None) + + response = anyio.run(_get_report) # Empty string should not be included in the report reports_by_uuid = {r.uuid: r for r in response.reports} @@ -241,8 +249,12 @@ def test_methods_description_included_in_getreport(): root_device=driver, ) - import asyncio - response = asyncio.run(session.GetReport(empty_pb2.Empty(), None)) + import anyio + + async def _get_report(): + return await session.GetReport(empty_pb2.Empty(), None) + + response = anyio.run(_get_report) # Find the driver's report reports_by_uuid = {r.uuid: r for r in response.reports} diff --git a/python/packages/jumpstarter/jumpstarter/streams/common.py b/python/packages/jumpstarter/jumpstarter/streams/common.py index 40e650ca5..8bb8e3258 100644 --- a/python/packages/jumpstarter/jumpstarter/streams/common.py +++ b/python/packages/jumpstarter/jumpstarter/streams/common.py @@ -1,5 +1,5 @@ -import asyncio import logging +from asyncio import InvalidStateError from contextlib import asynccontextmanager, suppress from anyio import ( @@ -26,7 +26,7 @@ async def copy_stream(dst: AnyByteStream, src: AnyByteStream): OSError, ): await dst.send_eof() - except (BrokenResourceError, ClosedResourceError, asyncio.InvalidStateError) as e: + except (BrokenResourceError, ClosedResourceError, InvalidStateError) as e: if isinstance(e.__cause__, BrokenPipeError): # BrokenPipeError (EPIPE) = writing to a closed pipe during normal teardown logger.debug("stream copy interrupted (%s): %s", type(e).__name__, e) diff --git a/python/packages/jumpstarter/jumpstarter/streams/router.py b/python/packages/jumpstarter/jumpstarter/streams/router.py index b626ad6b3..c9b078c15 100644 --- a/python/packages/jumpstarter/jumpstarter/streams/router.py +++ b/python/packages/jumpstarter/jumpstarter/streams/router.py @@ -1,6 +1,6 @@ -import asyncio import contextlib import logging +from asyncio import InvalidStateError from dataclasses import dataclass, field import grpc @@ -57,13 +57,13 @@ async def receive(self) -> bytes: return b"" async def send_eof(self): - with contextlib.suppress(grpc.aio.AioRpcError, asyncio.exceptions.InvalidStateError): + with contextlib.suppress(grpc.aio.AioRpcError, InvalidStateError): await self.context.write(self.cls(frame_type=router_pb2.FRAME_TYPE_GOAWAY)) if isinstance(self.context, grpc.aio.StreamStreamCall): await self.context.done_writing() async def aclose(self): - with contextlib.suppress(grpc.aio.AioRpcError, asyncio.exceptions.InvalidStateError): + with contextlib.suppress(grpc.aio.AioRpcError, InvalidStateError): await self.send_eof() if isinstance(self.context, grpc._cython.cygrpc._ServicerContext): await self.context.abort(grpc.StatusCode.ABORTED, "RouterStream: aclose") diff --git a/python/packages/jumpstarter/pyproject.toml b/python/packages/jumpstarter/pyproject.toml index b221222c4..f392a5162 100644 --- a/python/packages/jumpstarter/pyproject.toml +++ b/python/packages/jumpstarter/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "jumpstarter-protocol", "pyyaml>=6.0.2", "anyio>=4.4.0,!=4.6.2", + "sniffio>=1.3.0", "aiohttp>=3.10.5", "yarl>=1.6.0", "pydantic>=2.8.2",