|
4 | 4 | import copy |
5 | 5 | import datetime |
6 | 6 | from collections.abc import Callable, Generator |
| 7 | +from typing import Any |
7 | 8 | from unittest.mock import AsyncMock, Mock, patch |
8 | 9 |
|
9 | 10 | import aiomqtt |
@@ -31,17 +32,62 @@ def auto_fast_backoff(fast_backoff_fixture: None) -> None: |
31 | 32 | """Automatically use the fast backoff fixture.""" |
32 | 33 |
|
33 | 34 |
|
34 | | -@pytest.fixture |
35 | | -def mock_mqtt_client() -> Generator[AsyncMock, None, None]: |
36 | | - """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" |
| 35 | +class FakeAsyncIterator: |
| 36 | + """Fake async iterator that waits for messages to arrive, but they never do. |
| 37 | +
|
| 38 | + This is used for testing exceptions in other client functions. |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__(self) -> None: |
| 42 | + self.loop = True |
| 43 | + |
| 44 | + def __aiter__(self): |
| 45 | + return self |
| 46 | + |
| 47 | + async def __anext__(self) -> None: |
| 48 | + """Iterator that does not generate any messages.""" |
| 49 | + while self.loop: |
| 50 | + await asyncio.sleep(0.01) |
| 51 | + |
| 52 | + |
| 53 | +@pytest.fixture(name="message_iterator") |
| 54 | +def message_iterator_fixture() -> FakeAsyncIterator: |
| 55 | + """Fixture to provide a side effect for creating the MQTT client.""" |
| 56 | + return FakeAsyncIterator() |
| 57 | + |
| 58 | + |
| 59 | +@pytest.fixture(name="mock_client") |
| 60 | +def mock_client_fixture(message_iterator: FakeAsyncIterator) -> AsyncMock: |
| 61 | + """Fixture to provide a side effect for creating the MQTT client.""" |
37 | 62 | mock_client = AsyncMock() |
38 | | - mock_client.messages = FakeAsyncIterator() |
| 63 | + mock_client.messages = message_iterator |
| 64 | + return mock_client |
| 65 | + |
| 66 | + |
| 67 | +@pytest.fixture(name="create_client_side_effect") |
| 68 | +def create_client_side_effect_fixture() -> Exception | None: |
| 69 | + """Fixture to provide a side effect for creating the MQTT client.""" |
| 70 | + return None |
| 71 | + |
39 | 72 |
|
| 73 | +@pytest.fixture(name="mock_aenter_client") |
| 74 | +def mock_aenter_client_fixture(mock_client: AsyncMock, create_client_side_effect: Exception | None) -> AsyncMock: |
| 75 | + """Fixture to provide a side effect for creating the MQTT client.""" |
40 | 76 | mock_aenter = AsyncMock() |
41 | 77 | mock_aenter.return_value = mock_client |
| 78 | + mock_aenter.side_effect = create_client_side_effect |
| 79 | + return mock_aenter |
| 80 | + |
| 81 | + |
| 82 | +@pytest.fixture |
| 83 | +def mock_mqtt_client( |
| 84 | + mock_client: AsyncMock, |
| 85 | + mock_aenter_client: AsyncMock, |
| 86 | +) -> Generator[AsyncMock, None, None]: |
| 87 | + """Fixture to create a mock MQTT client with patched aiomqtt.Client.""" |
42 | 88 |
|
43 | 89 | mock_shim = Mock() |
44 | | - mock_shim.return_value.__aenter__ = mock_aenter |
| 90 | + mock_shim.return_value.__aenter__ = mock_aenter_client |
45 | 91 | mock_shim.return_value.__aexit__ = AsyncMock() |
46 | 92 |
|
47 | 93 | with patch("roborock.mqtt.roborock_session.aiomqtt.Client", mock_shim): |
@@ -114,21 +160,6 @@ async def test_publish_command(push_response: Callable[[bytes], None]) -> None: |
114 | 160 | assert not session.connected |
115 | 161 |
|
116 | 162 |
|
117 | | -class FakeAsyncIterator: |
118 | | - """Fake async iterator that waits for messages to arrive, but they never do. |
119 | | -
|
120 | | - This is used for testing exceptions in other client functions. |
121 | | - """ |
122 | | - |
123 | | - def __aiter__(self): |
124 | | - return self |
125 | | - |
126 | | - async def __anext__(self) -> None: |
127 | | - """Iterator that does not generate any messages.""" |
128 | | - while True: |
129 | | - await asyncio.sleep(1) |
130 | | - |
131 | | - |
132 | 163 | async def test_publish_failure(mock_mqtt_client: AsyncMock) -> None: |
133 | 164 | """Test an MQTT error is received when publishing a message.""" |
134 | 165 |
|
@@ -432,3 +463,64 @@ async def test_diagnostics_data(push_response: Callable[[bytes], None]) -> None: |
432 | 463 | assert data.get("subscribe_count") == 2 |
433 | 464 | assert data.get("dispatch_message_count") == 3 |
434 | 465 | assert data.get("close") == 1 |
| 466 | + |
| 467 | + |
| 468 | +@pytest.mark.parametrize( |
| 469 | + ("create_client_side_effect"), |
| 470 | + [ |
| 471 | + # Unauthorized |
| 472 | + aiomqtt.MqttCodeError(rc=135), |
| 473 | + ], |
| 474 | +) |
| 475 | +async def test_session_unauthorized_hook(mock_mqtt_client: AsyncMock, push_response: Callable[[bytes], None]) -> None: |
| 476 | + """Test the MQTT session.""" |
| 477 | + |
| 478 | + unauthorized = asyncio.Event() |
| 479 | + |
| 480 | + params = copy.deepcopy(FAKE_PARAMS) |
| 481 | + params.unauthorized_hook = unauthorized.set |
| 482 | + |
| 483 | + with pytest.raises(MqttSessionUnauthorized): |
| 484 | + await create_mqtt_session(params) |
| 485 | + |
| 486 | + assert unauthorized.is_set() |
| 487 | + |
| 488 | + |
| 489 | +async def test_session_unauthorized_after_start( |
| 490 | + mock_aenter_client: AsyncMock, |
| 491 | + message_iterator: FakeAsyncIterator, |
| 492 | + mock_mqtt_client: AsyncMock, |
| 493 | + push_response: Callable[[bytes], None], |
| 494 | +) -> None: |
| 495 | + """Test the MQTT session.""" |
| 496 | + |
| 497 | + # Configure a hook that is notified of unauthorized errors |
| 498 | + unauthorized = asyncio.Event() |
| 499 | + params = copy.deepcopy(FAKE_PARAMS) |
| 500 | + params.unauthorized_hook = unauthorized.set |
| 501 | + |
| 502 | + # The client will succeed on first connection attempt, then fail with |
| 503 | + # unauthorized messages on all future attempts. |
| 504 | + request_count = 0 |
| 505 | + |
| 506 | + def succeed_then_fail_unauthorized() -> Any: |
| 507 | + nonlocal request_count |
| 508 | + request_count += 1 |
| 509 | + if request_count == 1: |
| 510 | + return mock_mqtt_client |
| 511 | + raise aiomqtt.MqttCodeError(rc=135) |
| 512 | + |
| 513 | + mock_aenter_client.side_effect = succeed_then_fail_unauthorized |
| 514 | + # Don't produce messages, just exit and restart to reconnect |
| 515 | + message_iterator.loop = False |
| 516 | + |
| 517 | + push_response(mqtt_packet.gen_connack(rc=0, flags=2)) |
| 518 | + |
| 519 | + session = await create_mqtt_session(params) |
| 520 | + assert session.connected |
| 521 | + |
| 522 | + try: |
| 523 | + async with asyncio.timeout(10): |
| 524 | + assert await unauthorized.wait() |
| 525 | + finally: |
| 526 | + await session.close() |
0 commit comments