-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcoordinator.py
More file actions
141 lines (121 loc) · 5.18 KB
/
coordinator.py
File metadata and controls
141 lines (121 loc) · 5.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Coordinator for the SSH Command integration.
The SshCommandCoordinator is the single owner of all SSH I/O for the
integration. It encapsulates connection management and command execution
so that the service handler in __init__.py is a pure dispatcher.
Preferred HA pattern: coordinator (or "client") owns I/O; the service
handler validates input and delegates to this class.
"""
from __future__ import annotations
import logging
import socket
from pathlib import Path
from typing import Any
from asyncssh import HostKeyNotVerifiable, KeyImportError, PermissionDenied, connect, read_known_hosts
from homeassistant.const import CONF_USERNAME, CONF_PASSWORD, CONF_HOST, CONF_COMMAND, CONF_TIMEOUT
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ServiceValidationError
from .const import (
DOMAIN,
CONF_KEY_FILE,
CONF_INPUT,
CONF_CHECK_KNOWN_HOSTS,
CONF_KNOWN_HOSTS,
CONF_CLIENT_KEYS,
CONF_CHECK,
CONF_OUTPUT,
CONF_ERROR,
CONF_EXIT_STATUS,
CONST_DEFAULT_TIMEOUT,
)
# Module-level logger, named after this module's import path (__name__).
_LOGGER = logging.getLogger(__name__)
class SshCommandCoordinator:
    """Single owner of all SSH I/O for the SSH Command integration.

    Preferred HA pattern: the coordinator owns I/O; the service handler in
    __init__.py validates input and delegates execution to this class.
    The blocking filesystem checks made here (input file, known_hosts file)
    run in the Home Assistant executor via ``hass.async_add_executor_job``.
    """

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize the coordinator.

        Args:
            hass: Home Assistant instance; its executor is used for the
                blocking filesystem work done by this class.
        """
        self.hass = hass

    async def async_execute(self, data: dict[str, Any]) -> dict[str, Any]:
        """Execute an SSH command and return stdout, stderr and exit status.

        Args:
            data: Service call data. Reads the host, username, password,
                key file, command, input, check-known-hosts flag (default
                True), known_hosts path and timeout (default
                CONST_DEFAULT_TIMEOUT) entries.

        Returns:
            A dict mapping CONF_OUTPUT to the command's stdout, CONF_ERROR
            to its stderr and CONF_EXIT_STATUS to its exit status.

        Raises:
            ServiceValidationError: The host key could not be verified, the
                key file is not a valid private key, login was refused, the
                connection or command timed out, or the host name could not
                be resolved (socket.gaierror).
            Other exceptions (including other OSError subclasses) propagate
            unchanged.
        """
        host = data.get(CONF_HOST)
        username = data.get(CONF_USERNAME)
        password = data.get(CONF_PASSWORD)
        key_file = data.get(CONF_KEY_FILE)
        command = data.get(CONF_COMMAND)
        input_data = data.get(CONF_INPUT)
        check_known_hosts = data.get(CONF_CHECK_KNOWN_HOSTS, True)
        known_hosts = data.get(CONF_KNOWN_HOSTS)
        timeout = data.get(CONF_TIMEOUT, CONST_DEFAULT_TIMEOUT)

        if input_data:
            # ``input`` is either literal stdin text or a path to a file whose
            # contents should be sent. Checking and reading both block, so
            # they run together in one executor job.
            input_data = await self.hass.async_add_executor_job(
                self._read_input_source, input_data
            )

        conn_kwargs: dict[str, Any] = {
            CONF_HOST: host,
            CONF_USERNAME: username,
            CONF_PASSWORD: password,
            CONF_CLIENT_KEYS: key_file,
            CONF_KNOWN_HOSTS: await self._resolve_known_hosts(
                check_known_hosts, known_hosts
            ),
            "connect_timeout": timeout,
        }
        # The exit status is reported in the result rather than treated as a
        # failure (assumes CONF_CHECK is asyncssh's ``check`` option).
        run_kwargs: dict[str, Any] = {
            CONF_COMMAND: command,
            CONF_CHECK: False,
            CONF_TIMEOUT: timeout,
        }
        if input_data:
            run_kwargs[CONF_INPUT] = input_data

        try:
            async with connect(**conn_kwargs) as conn:
                result = await conn.run(**run_kwargs)
        except HostKeyNotVerifiable as exc:
            _LOGGER.warning("Host key not verifiable for %s: %s", host, exc)
            raise ServiceValidationError(
                "The host key could not be verified.",
                translation_domain=DOMAIN,
                translation_key="host_key_not_verifiable",
            ) from exc
        except KeyImportError as exc:
            _LOGGER.warning("Invalid key file for %s@%s: %s", username, host, exc)
            raise ServiceValidationError(
                "The key file is not a valid private key.",
                translation_domain=DOMAIN,
                translation_key="invalid_key_file",
            ) from exc
        except PermissionDenied as exc:
            _LOGGER.warning("SSH login failed for %s@%s: %s", username, host, exc)
            raise ServiceValidationError(
                "SSH login failed.",
                translation_domain=DOMAIN,
                translation_key="login_failed",
            ) from exc
        except TimeoutError as exc:
            # TimeoutError is an OSError subclass, so this branch must stay
            # ahead of the OSError-based socket.gaierror branch below.
            _LOGGER.warning("SSH connection to %s timed out: %s", host, exc)
            raise ServiceValidationError(
                "Connection timed out.",
                translation_domain=DOMAIN,
                translation_key="connection_timed_out",
            ) from exc
        except socket.gaierror as exc:
            # Name resolution failure; any other OSError propagates unchanged.
            _LOGGER.warning("Host %s is not reachable: %s", host, exc)
            raise ServiceValidationError(
                "Host is not reachable.",
                translation_domain=DOMAIN,
                translation_key="host_not_reachable",
            ) from exc

        return {
            CONF_OUTPUT: result.stdout,
            CONF_ERROR: result.stderr,
            CONF_EXIT_STATUS: result.exit_status,
        }

    async def _resolve_known_hosts(
        self, check_known_hosts: bool, known_hosts: str | None
    ) -> Any:
        """Resolve the known_hosts value for the SSH connection.

        Args:
            check_known_hosts: When False, host-key checking is disabled and
                None is returned.
            known_hosts: Path to a known_hosts file; empty/None falls back
                to ``~/.ssh/known_hosts``.

        Returns:
            None when checking is disabled, otherwise whatever
            ``_load_known_hosts`` produces (parsed known hosts, or the path
            string if the file does not exist).
        """
        if not check_known_hosts:
            return None
        return await self.hass.async_add_executor_job(
            self._load_known_hosts, known_hosts
        )

    @staticmethod
    def _load_known_hosts(known_hosts: str | None) -> Any:
        """Locate and parse a known_hosts file (blocking; run in executor).

        Args:
            known_hosts: Path to a known_hosts file; empty/None falls back
                to ``~/.ssh/known_hosts``.

        Returns:
            The result of ``read_known_hosts`` when the file exists,
            otherwise the (possibly defaulted) path string unchanged.
        """
        if not known_hosts:
            known_hosts = str(Path("~", ".ssh", "known_hosts").expanduser())
        if Path(known_hosts).exists():
            return read_known_hosts(known_hosts)
        # NOTE(review): a missing file is passed through as a path; presumably
        # asyncssh then tries to load it itself and fails — confirm this is
        # the intended failure mode.
        return known_hosts

    @staticmethod
    def _read_input_source(value: str) -> str:
        """Return the text to send on stdin for a service ``input`` value.

        Blocking; run in the executor.

        Args:
            value: Literal stdin text, or a path to a regular file.

        Returns:
            The file's contents if *value* names an existing regular file,
            otherwise *value* itself.
        """
        try:
            is_file = Path(value).is_file()
        except (OSError, ValueError):
            # Literal input is often not a usable path at all. For example,
            # a line longer than the OS name limit can raise ENAMETOOLONG,
            # and an embedded NUL raises ValueError. Treat it as literal text.
            is_file = False
        if not is_file:
            return value
        # Errors reading a real file (e.g. permission denied) still propagate.
        return Path(value).read_text()