2 changes: 1 addition & 1 deletion dspy/streaming/streamify.py
@@ -161,7 +161,7 @@ async def use_streaming():
     elif not iscoroutinefunction(program):
         program = asyncify(program)
 
-    callbacks = settings.callbacks
+    callbacks = list(settings.callbacks)
     status_streaming_callback = StatusStreamingCallback(status_message_provider)
     if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
         callbacks.append(status_streaming_callback)
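This one-line fix copies the global callback list before `streamify` appends its per-call `StatusStreamingCallback`, so one call's status callback can no longer leak into `dspy.settings.callbacks` and affect concurrent callers. A minimal sketch of the aliasing issue, using a simplified stand-in for the settings object rather than the real dspy API:

```python
class FakeSettings:
    """Simplified stand-in for a global settings object (not the real dspy.settings)."""
    def __init__(self):
        self.callbacks = []

settings = FakeSettings()

def stream_aliased(extra_callback):
    callbacks = settings.callbacks        # aliases the shared list
    callbacks.append(extra_callback)      # mutates global state as a side effect
    return callbacks

def stream_copied(extra_callback):
    callbacks = list(settings.callbacks)  # independent per-call copy
    callbacks.append(extra_callback)      # the global list stays untouched
    return callbacks

stream_aliased("provider1")
assert settings.callbacks == ["provider1"]  # leaked into the shared settings

settings.callbacks.clear()
stream_copied("provider2")
assert settings.callbacks == []             # shared settings unchanged
```

With the copy, each `streamify` call builds its own callback list, which is the behavior the new concurrency test below relies on.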
84 changes: 84 additions & 0 deletions tests/streaming/test_streaming.py
@@ -134,6 +134,90 @@ def module_start_status_message(self, instance, inputs):
assert status_messages[2].message == "Predict starting!"


@pytest.mark.anyio
async def test_concurrent_status_message_providers():
    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider1(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Provider1: Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Provider1: Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Provider1: Predict starting!"

    class MyStatusMessageProvider2(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Provider2: Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Provider2: Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Provider2: Predict starting!"

    # Store the original callbacks to verify they're not modified
    original_callbacks = list(dspy.settings.callbacks)

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}, {"answer": "green"}, {"answer": "yellow"}])

    # Results storage for each concurrent task
    results = {}

    async def run_with_provider1():
        with dspy.context(lm=lm):
            program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider1())
            output = program("sky")

            status_messages = []
            async for value in output:
                if isinstance(value, StatusMessage):
                    status_messages.append(value.message)

            results["provider1"] = status_messages

    async def run_with_provider2():
        with dspy.context(lm=lm):
            program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider2())
            output = program("ocean")

            status_messages = []
            async for value in output:
                if isinstance(value, StatusMessage):
                    status_messages.append(value.message)

            results["provider2"] = status_messages

    # Run both tasks concurrently
    await asyncio.gather(run_with_provider1(), run_with_provider2())

    # Verify provider1 got its expected messages
    assert len(results["provider1"]) == 3
    assert results["provider1"][0] == "Provider1: Tool starting!"
    assert results["provider1"][1] == "Provider1: Tool finished!"
    assert results["provider1"][2] == "Provider1: Predict starting!"

    # Verify provider2 got its expected messages
    assert len(results["provider2"]) == 3
    assert results["provider2"][0] == "Provider2: Tool starting!"
    assert results["provider2"][1] == "Provider2: Tool finished!"
    assert results["provider2"][2] == "Provider2: Predict starting!"

    # Verify that the global callbacks were not modified
    assert dspy.settings.callbacks == original_callbacks


@pytest.mark.llm_call
@pytest.mark.anyio
async def test_stream_listener_chat_adapter(lm_for_test):