Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dspy/utils/parallelizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def all_done():
index, outcome = f.result()
except Exception:
pass

else:
if outcome != job_cancelled and results[index] is None:
# Check if this is an exception
Expand Down
76 changes: 76 additions & 0 deletions tests/utils/test_parallelizer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import time

import pytest
import threading

from dspy.utils.parallelizer import ParallelExecutor
from dspy.dsp.utils.settings import thread_local_overrides


def test_worker_threads_independence():
Expand Down Expand Up @@ -83,3 +85,77 @@ def task(item):
assert str(executor.exceptions_map[2]) == "test error for 3"
assert isinstance(executor.exceptions_map[4], RuntimeError)
assert str(executor.exceptions_map[4]) == "test error for 5"


def test_thread_local_overrides_with_usage_tracker():
    """Check what worker tasks see when the parent thread has overrides set.

    The parent installs an overrides dict holding a plain setting and a
    ``usage_tracker`` object, then runs tasks through ``ParallelExecutor``.
    The test asserts that:

    * each task can read ``some_setting`` with the parent's value;
    * each task sees a ``usage_tracker`` that is a ``MockUsageTracker`` but
      is not the parent's instance;
    * the parent's tracker records nothing, while the trackers seen by the
      tasks together record every input item exactly once.
    """

    class MockUsageTracker:
        """Stand-in tracker that simply remembers every value it is given."""

        def __init__(self):
            self.tracked_items = []

        def track(self, value):
            self.tracked_items.append(value)

    parent_tracker = MockUsageTracker()
    parent_overrides = {"usage_tracker": parent_tracker, "some_setting": "parent_value"}

    # Keep the token so the overrides can be restored in ``finally``.
    token = thread_local_overrides.set(parent_overrides)

    try:
        # Thread ids are recorded for bookkeeping only; nothing asserts on them.
        seen_thread_ids = set()
        seen_thread_ids_guard = threading.Lock()

        # Distinct tracker objects observed by tasks; one object may serve
        # several tasks, so each is stored only once.
        seen_trackers = []
        seen_trackers_guard = threading.Lock()

        def task(item):
            ident = threading.get_ident()
            with seen_thread_ids_guard:
                seen_thread_ids.add(ident)

            overrides = thread_local_overrides.get()

            # The plain setting must be readable from inside the task.
            assert overrides.get("some_setting") == "parent_value"

            tracker = overrides.get("usage_tracker")
            assert tracker is not None
            assert isinstance(tracker, MockUsageTracker)

            # Register the tracker once, however many tasks end up sharing it.
            with seen_trackers_guard:
                already_seen = tracker in seen_trackers
                if not already_seen:
                    seen_trackers.append(tracker)

            tracker.track(item)

            return item * 2

        input_data = [1, 2, 3, 4, 5]
        executor = ParallelExecutor(num_threads=3)
        results = executor.execute(task, input_data)

        assert results == [2, 4, 6, 8, 10]

        # At least one tracker must have been observed, and none of the
        # observed trackers may be the parent's own object.
        assert len(seen_trackers) >= 1, "At least one worker usage tracker should exist"

        for tracker in seen_trackers:
            assert tracker is not parent_tracker, (
                "Worker thread usage tracker should be deep copy, not same instance as parent"
            )

        assert len(parent_tracker.tracked_items) == 0, (
            "Parent usage tracker should not be modified by worker threads"
        )

        # Every input item must show up in exactly one of the observed trackers.
        tracked_total = 0
        for tracker in seen_trackers:
            tracked_total += len(tracker.tracked_items)
        assert tracked_total == len(input_data), "All items should be tracked across worker threads"

    finally:
        thread_local_overrides.reset(token)