Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ data/
site/
secrets/
config.toml
mitm_logs
mitm_logs
dist/
Original file line number Diff line number Diff line change
Expand Up @@ -563,15 +563,17 @@ def _routine_runner_for_context(ctx: ServerContext) -> RoutineRunner:


def list_scenes_for_device(ctx: ServerContext, device_id: str) -> list[dict[str, Any]]:
scenes = _scene_state(ctx)["scenes"]
state = _scene_state(ctx)
scenes = state["scenes"]
home_id = state["home_id"]
filtered: list[dict[str, Any]] = []
for scene in scenes:
if not isinstance(scene, dict):
continue
scene_device = get_value(scene, "device_id", "deviceId", "duid")
if scene_device and str(scene_device) != str(device_id):
continue
filtered.append(build_scene_payload(scene, home_id=None, include_device_context=False))
filtered.append(build_scene_payload(scene, home_id=home_id, include_device_context=True))
return filtered


Expand Down Expand Up @@ -693,4 +695,3 @@ def apply_update(updated_scene: dict[str, Any], inventory: dict[str, Any]) -> No

updated_scene, home_id = _replace_inventory_scene(ctx, scene_id=scene_id, scene_updater=apply_update)
return build_scene_payload(updated_scene, home_id=home_id, include_device_context=True)

Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
self._counter = 0
self._lock = threading.Lock()
self._conn_protocol_levels: dict[str, int] = {}
self._conn_endpoints: dict[str, tuple[socket.socket, socket.socket]] = {}
self._pending_onboarding_auth: dict[str, dict[str, str]] = {}
self._trace_queue: queue.Queue[tuple[str, str, bytes] | None] = queue.Queue()
self._trace_thread: threading.Thread | None = None
self._protocol_auth = (
Expand All @@ -83,6 +85,38 @@ def _next_conn(self) -> str:
self._counter += 1
return str(self._counter)

def _register_conn_endpoints(self, conn_id: str, client_conn: socket.socket, backend_conn: socket.socket) -> None:
with self._lock:
self._conn_endpoints[conn_id] = (client_conn, backend_conn)

def _pop_conn_endpoints(self, conn_id: str) -> tuple[socket.socket, socket.socket] | None:
with self._lock:
return self._conn_endpoints.pop(conn_id, None)

def _close_conn_endpoints(self, conn_id: str) -> None:
endpoints = self._pop_conn_endpoints(conn_id)
if endpoints is None:
return
for endpoint in endpoints:
try:
endpoint.close()
except OSError:
pass

def _set_pending_onboarding_auth(self, conn_id: str, candidate: dict[str, str]) -> None:
with self._lock:
self._pending_onboarding_auth[conn_id] = dict(candidate)

def _get_pending_onboarding_auth(self, conn_id: str) -> dict[str, str] | None:
with self._lock:
candidate = self._pending_onboarding_auth.get(conn_id)
return dict(candidate) if candidate is not None else None

def _pop_pending_onboarding_auth(self, conn_id: str) -> dict[str, str] | None:
with self._lock:
candidate = self._pending_onboarding_auth.pop(conn_id, None)
return dict(candidate) if candidate is not None else None

@staticmethod
def _decode_remaining_length(data: bytes, start: int) -> tuple[int | None, int]:
multiplier = 1
Expand Down Expand Up @@ -187,45 +221,135 @@ def _expected_bootstrap_credentials(self) -> tuple[str, str, str] | None:
return username, password, client_id

def _authorize_connect_packet(self, packet: bytes) -> tuple[bool, str, dict[str, Any] | None]:
authorized, reason, info, _candidate = self._authorize_connect_packet_for_client(packet, client_ip="")
return authorized, reason, info

def _authorize_connect_packet_for_client(
self,
packet: bytes,
*,
client_ip: str,
) -> tuple[bool, str, dict[str, Any] | None, dict[str, str] | None]:
info = parse_mqtt_connect_packet(packet)
if info is None:
return False, "invalid_connect_packet", None
return False, "invalid_connect_packet", None, None

username = str(info.get("username") or "").strip()
password = str(info.get("password") or "").strip()
client_id = str(info.get("client_id") or "").strip()
if not username or not password:
return False, "missing_mqtt_credentials", info
return False, "missing_mqtt_credentials", info, None

if self._protocol_auth is not None and self._protocol_auth_enabled():
authorized, auth_reason, _matched_user = self._protocol_auth.verify_user_mqtt_credentials(username, password)
if authorized:
return True, auth_reason, info
return True, auth_reason, info, None

bootstrap_credentials = self._expected_bootstrap_credentials()
if bootstrap_credentials is not None:
expected_username, expected_password, expected_client_id = bootstrap_credentials
if username == expected_username and password == expected_password:
if expected_client_id and client_id and client_id != expected_client_id:
return False, "invalid_bootstrap_client_id", info
return True, "bootstrap", info
return False, "invalid_bootstrap_client_id", info, None
return True, "bootstrap", info, None

if self.runtime_credentials is not None:
authorized, auth_reason, _matched_device = self.runtime_credentials.verify_device_mqtt_credentials(
username=username,
password=password,
)
if authorized:
return True, auth_reason, info
return True, auth_reason, info, None
if auth_reason == "device_mqtt_password_missing":
recovered_device = self.runtime_credentials.recover_device_mqtt_password(
username=username,
password=password,
)
if recovered_device is not None:
return True, "device_mqtt_recovered", info
return True, "device_mqtt_recovered", info, None
if auth_reason == "unknown_device_mqtt_username":
candidate = self._resolve_onboarding_device_mqtt_candidate(
client_ip=client_ip,
username=username,
password=password,
)
if candidate is not None:
return True, "device_mqtt_onboarding_pending", info, candidate

return False, "invalid_mqtt_credentials", info, None

return False, "invalid_mqtt_credentials", info
def _resolve_onboarding_device_mqtt_candidate(
self,
*,
client_ip: str,
username: str,
password: str,
) -> dict[str, str] | None:
if self.runtime_state is None or self.runtime_credentials is None:
return None
candidate = self.runtime_state.onboarding_device_mqtt_candidate(client_ip=client_ip)
if candidate is None:
return None
device = self.runtime_credentials.resolve_device(
did=str(candidate.get("did") or ""),
duid=str(candidate.get("duid") or ""),
)
if device is None:
return None
existing_username = str(device.get("device_mqtt_usr") or "").strip()
existing_password = str(device.get("device_mqtt_pass") or "").strip()
if existing_username or existing_password:
return None
return {
"did": str(device.get("did") or "").strip(),
"duid": str(device.get("duid") or "").strip(),
"name": str(device.get("name") or candidate.get("name") or "").strip(),
"username": username.strip(),
"password": password.strip(),
"client_ip": client_ip.strip(),
}

def _confirm_pending_onboarding_auth(self, conn_id: str, *, direction: str, topic: str) -> bool:
if direction != "c2b" or self.runtime_credentials is None:
return True
candidate = self._get_pending_onboarding_auth(conn_id)
if candidate is None:
return True
expected_topic = f"rr/d/i/{candidate['did']}/{candidate['username']}"
if topic != expected_topic:
self.logger.warning(
"[conn %s] rejected provisional onboarding MQTT session expected_topic=%s got=%s",
conn_id,
expected_topic,
topic,
)
self._pop_pending_onboarding_auth(conn_id)
self._close_conn_endpoints(conn_id)
return False
learned = self.runtime_credentials.confirm_device_mqtt_credentials(
did=candidate.get("did", ""),
duid=candidate.get("duid", ""),
username=candidate["username"],
password=candidate["password"],
)
self._pop_pending_onboarding_auth(conn_id)
if learned is None:
self.logger.warning(
"[conn %s] failed to persist confirmed onboarding MQTT credentials did=%s duid=%s",
conn_id,
candidate.get("did", ""),
candidate.get("duid", ""),
)
self._close_conn_endpoints(conn_id)
return False
self.logger.info(
"[conn %s] learned onboarding MQTT credentials did=%s duid=%s username=%s",
conn_id,
learned.get("did", ""),
learned.get("duid", ""),
candidate["username"],
)
return True

@classmethod
def _extract_publish(cls, packet: bytes, protocol_level: int | None = None) -> tuple[str | None, bytes | None]:
Expand Down Expand Up @@ -408,6 +532,8 @@ def _trace_packet(self, conn_id: str, direction: str, packet: bytes) -> None:
topic, payload = self._extract_publish(packet, self._get_conn_protocol_level(conn_id))
if topic is None or payload is None:
return
if not self._confirm_pending_onboarding_auth(conn_id, direction=direction, topic=topic):
return
if self.runtime_state is not None:
self.runtime_state.record_mqtt_message(
conn_id=conn_id,
Expand Down Expand Up @@ -609,7 +735,10 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None
self.logger.warning("[conn %s] client closed before MQTT CONNECT", conn_id)
return
connect_packet, initial_remainder = first_packet
authorized, auth_reason, connect_info = self._authorize_connect_packet(connect_packet)
authorized, auth_reason, connect_info, onboarding_candidate = self._authorize_connect_packet_for_client(
connect_packet,
client_ip=addr[0],
)
if connect_info is not None:
protocol_level = connect_info.get("protocol_level")
if isinstance(protocol_level, int):
Expand Down Expand Up @@ -637,6 +766,9 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None

backend = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
backend.connect((self.backend_host, self.backend_port))
self._register_conn_endpoints(conn_id, tls_conn, backend)
if onboarding_candidate is not None:
self._set_pending_onboarding_auth(conn_id, onboarding_candidate)
c2b_frame_buf = bytearray(initial_remainder)
for packet in self._extract_packets(c2b_frame_buf):
self._queue_trace_packet(conn_id, "c2b", packet)
Expand All @@ -655,6 +787,8 @@ def _handle_client(self, tls_conn: ssl.SSLSocket, addr: tuple[str, int]) -> None
except Exception as exc:
self.logger.error("[conn %s] connection error: %s", conn_id, exc)
finally:
self._pop_pending_onboarding_auth(conn_id)
self._pop_conn_endpoints(conn_id)
if not relay_started:
for endpoint in (tls_conn, backend):
if endpoint is None:
Expand Down
33 changes: 30 additions & 3 deletions src/roborock_local_server/bundled_backend/shared/protocol_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def _md5hex(value: str) -> str:
return hashlib.md5(value.encode("utf-8")).hexdigest()


def _md5hex_bytes(value: bytes) -> str:
return hashlib.md5(value).hexdigest()


def _parse_json_body_param_map(body_params: dict[str, list[str]]) -> dict[str, Any]:
for raw in body_params.get("__json") or []:
try:
Expand Down Expand Up @@ -63,18 +67,26 @@ def _build_hawk_mac(
path: str,
query_values: dict[str, Any] | None,
form_values: dict[str, Any] | None,
json_body: bytes | str | None = None,
timestamp: int,
nonce: str,
) -> str:
query_hash = _process_extra_hawk_values(query_values)
if json_body is None:
body_hash = _process_extra_hawk_values(form_values)
else:
if isinstance(json_body, str):
json_body = json_body.encode("utf-8")
body_hash = _md5hex_bytes(json_body)
prestr = ":".join(
[
hawk_id,
hawk_session,
nonce,
str(timestamp),
_md5hex(path),
_process_extra_hawk_values(query_values),
_process_extra_hawk_values(form_values),
query_hash,
body_hash,
]
)
return base64.b64encode(hmac.new(hawk_key.encode(), prestr.encode(), hashlib.sha256).digest()).decode()
Expand Down Expand Up @@ -162,18 +174,24 @@ def build_hawk_authorization(
path: str,
query_values: dict[str, Any] | None = None,
form_values: dict[str, Any] | None = None,
json_body: bytes | str | None = None,
timestamp: int | None = None,
nonce: str | None = None,
) -> str:
ts = int(time.time() if timestamp is None else timestamp)
normalized_nonce = _clean_str(nonce) or secrets.token_urlsafe(6)
if json_body is None and isinstance(form_values, Mapping):
raw_json = form_values.get("__json")
if isinstance(raw_json, (str, bytes)):
json_body = raw_json
mac = _build_hawk_mac(
hawk_id=user.hawk_id,
hawk_session=user.hawk_session,
hawk_key=user.hawk_key,
path=path,
query_values=query_values,
form_values=form_values,
json_body=json_body,
timestamp=ts,
nonce=normalized_nonce,
)
Expand Down Expand Up @@ -533,6 +551,7 @@ def verify_hawk(
query_params: dict[str, list[str]],
body_params: dict[str, list[str]],
headers: Mapping[str, str],
raw_body: bytes | None = None,
now_ts: float | None = None,
) -> tuple[bool, str]:
availability = self.availability()
Expand Down Expand Up @@ -565,13 +584,21 @@ def verify_hawk(
if not nonce:
return False, "missing_hawk_nonce"

json_body: bytes | None = None
if body_params.get("__json"):
if raw_body is not None:
json_body = raw_body
else:
json_raw = next((value for value in body_params.get("__json", []) if isinstance(value, str)), "")
json_body = json_raw.encode("utf-8")
expected_mac = _build_hawk_mac(
hawk_id=user.hawk_id,
hawk_session=user.hawk_session,
hawk_key=user.hawk_key,
path=path,
query_values=_normalize_param_values(query_params),
form_values=_normalize_param_values(body_params, include_json=True),
form_values=None if json_body is not None else _normalize_param_values(body_params, include_json=True),
json_body=json_body,
timestamp=timestamp,
nonce=nonce,
)
Expand Down
Loading