Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions dimos/protocol/pubsub/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import asyncio
from collections.abc import Callable, Generator
from contextlib import contextmanager
import threading
import time
from typing import Any

Expand Down Expand Up @@ -149,21 +150,21 @@ def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any])
with pubsub_context() as x:
# Create a list to capture received messages
received_messages: list[Any] = []
msg_event = threading.Event()

# Define callback function that stores received messages
def callback(message: Any, _: Any) -> None:
received_messages.append(message)
msg_event.set()

# Subscribe to the topic with our callback
x.subscribe(topic, callback)

# Publish the first value to the topic
x.publish(topic, values[0])

# Give Redis time to process the message if needed
time.sleep(0.1)
assert msg_event.wait(timeout=1.0), "Timed out waiting for message"

print("RECEIVED", received_messages)
# Verify the callback was called with the correct value
assert len(received_messages) == 1
assert received_messages[0] == values[0]
Expand All @@ -178,13 +179,17 @@ def test_multiple_subscribers(
# Create lists to capture received messages for each subscriber
received_messages_1: list[Any] = []
received_messages_2: list[Any] = []
event_1 = threading.Event()
event_2 = threading.Event()

# Define callback functions
def callback_1(message: Any, topic: Any) -> None:
received_messages_1.append(message)
event_1.set()

def callback_2(message: Any, topic: Any) -> None:
received_messages_2.append(message)
event_2.set()

# Subscribe both callbacks to the same topic
x.subscribe(topic, callback_1)
Expand All @@ -193,8 +198,8 @@ def callback_2(message: Any, topic: Any) -> None:
# Publish the first value
x.publish(topic, values[0])

# Give Redis time to process the message if needed
time.sleep(0.1)
assert event_1.wait(timeout=1.0), "Timed out waiting for subscriber 1"
assert event_2.wait(timeout=1.0), "Timed out waiting for subscriber 2"

# Verify both callbacks received the message
assert len(received_messages_1) == 1
Expand Down Expand Up @@ -238,21 +243,24 @@ def test_multiple_messages(
with pubsub_context() as x:
# Create a list to capture received messages
received_messages: list[Any] = []
all_received = threading.Event()

# Publish the rest of the values (after the first one used in basic tests)
messages_to_send = values[1:] if len(values) > 1 else values

# Define callback function
def callback(message: Any, topic: Any) -> None:
received_messages.append(message)
if len(received_messages) >= len(messages_to_send):
all_received.set()

# Subscribe to the topic
x.subscribe(topic, callback)

# Publish the rest of the values (after the first one used in basic tests)
messages_to_send = values[1:] if len(values) > 1 else values
for msg in messages_to_send:
x.publish(topic, msg)

# Give Redis time to process the messages if needed
time.sleep(0.2)
assert all_received.wait(timeout=1.0), "Timed out waiting for all messages"

# Verify all messages were received in order
assert len(received_messages) == len(messages_to_send)
Expand Down