diff --git a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py index 6a4819f..a8dd5a7 100755 --- a/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py +++ b/batchgenerators/dataloading/nondet_multi_threaded_augmenter.py @@ -34,11 +34,14 @@ except ImportError: torch = None +from threading import Lock +global_mutex = Lock() def producer(queue: Queue, data_loader, transform, thread_id: int, seed, abort_event: Event, wait_time: float = 0.02): # the producer will set the abort event if something happens - with threadpool_limits(1, None): + # with threadpool_limits(1, None): + global_mutex.acquire() # using instead of threadpool_limits for thread safety np.random.seed(seed) data_loader.set_thread_id(thread_id) item = None @@ -72,7 +75,9 @@ def producer(queue: Queue, data_loader, transform, thread_id: int, seed, traceback.print_exc() abort_event.set() return - + + #End of mutex section + global_mutex.release() def pin_memory_of_all_eligible_items_in_dict(result_dict): for k in result_dict.keys(): @@ -276,4 +281,4 @@ def __del__(self): end = time() print(end - st) - mt._finish() \ No newline at end of file + mt._finish()