diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0bdfa6262..26c1cf035 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -17,6 +17,7 @@ import asyncio from collections.abc import Callable, Generator from contextlib import contextmanager +import threading import time from typing import Any @@ -149,10 +150,12 @@ def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) with pubsub_context() as x: # Create a list to capture received messages received_messages: list[Any] = [] + msg_event = threading.Event() # Define callback function that stores received messages def callback(message: Any, _: Any) -> None: received_messages.append(message) + msg_event.set() # Subscribe to the topic with our callback x.subscribe(topic, callback) @@ -160,10 +163,8 @@ def callback(message: Any, _: Any) -> None: # Publish the first value to the topic x.publish(topic, values[0]) - # Give Redis time to process the message if needed - time.sleep(0.1) + assert msg_event.wait(timeout=1.0), "Timed out waiting for message" - print("RECEIVED", received_messages) # Verify the callback was called with the correct value assert len(received_messages) == 1 assert received_messages[0] == values[0] @@ -178,13 +179,17 @@ def test_multiple_subscribers( # Create lists to capture received messages for each subscriber received_messages_1: list[Any] = [] received_messages_2: list[Any] = [] + event_1 = threading.Event() + event_2 = threading.Event() # Define callback functions def callback_1(message: Any, topic: Any) -> None: received_messages_1.append(message) + event_1.set() def callback_2(message: Any, topic: Any) -> None: received_messages_2.append(message) + event_2.set() # Subscribe both callbacks to the same topic x.subscribe(topic, callback_1) @@ -193,8 +198,8 @@ def callback_2(message: Any, topic: Any) -> None: # Publish the first value x.publish(topic, values[0]) - # Give Redis time to process the message if needed - time.sleep(0.1) + assert event_1.wait(timeout=1.0), "Timed out waiting for subscriber 1" + assert event_2.wait(timeout=1.0), "Timed out waiting for subscriber 2" # Verify both callbacks received the message assert len(received_messages_1) == 1 @@ -238,21 +243,24 @@ def test_multiple_messages( with pubsub_context() as x: # Create a list to capture received messages received_messages: list[Any] = [] + all_received = threading.Event() + + # Publish the rest of the values (after the first one used in basic tests) + messages_to_send = values[1:] if len(values) > 1 else values # Define callback function def callback(message: Any, topic: Any) -> None: received_messages.append(message) + if len(received_messages) >= len(messages_to_send): + all_received.set() # Subscribe to the topic x.subscribe(topic, callback) - # Publish the rest of the values (after the first one used in basic tests) - messages_to_send = values[1:] if len(values) > 1 else values for msg in messages_to_send: x.publish(topic, msg) - # Give Redis time to process the messages if needed - time.sleep(0.2) + assert all_received.wait(timeout=1.0), "Timed out waiting for all messages" # Verify all messages were received in order assert len(received_messages) == len(messages_to_send)