Adds functionality to populate torch generator using torch.thread_safe_generator #9371
divyanshk wants to merge 4 commits into pytorch:main
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9371.
As of commit 6277f11 with merge base 48956e0: 2 new failures, 1 pending, 2 unrelated failures. The broken-trunk jobs failed but were already failing on the merge base; rebase onto the `viable/strict` branch to avoid them.
(This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed 1e0225e to b15da1c.
NicolasHug left a comment:
Thanks for the PR @divyanshk. I think the changes look reasonable.
One thing I'm wondering is how this affects the multiprocess-based dataloaders. Currently, since TV uses the global torch RNG, that global generator is seeded by torch with a different seed for each process/worker. This is the correct behavior, since we want each worker to have a different RNG.
Is that behavior preserved now that we're using torch.thread_safe_generator()?
It'd be good to have tests ensuring that's the case (for both the multiprocess and multithreaded cases).
Force-pushed b15da1c to 18bdef3.
The multiprocessing case remains unchanged, because torch.thread_safe_generator returns None in the multiprocessing case. So for MP there is no change: previously the torch.rand functions received None for the generator arg, and they still do. I also added a test case confirming the expected behavior for multiprocessing.
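The fallback described above can be sketched as follows. This is a hedged illustration, not torchvision's actual code: `thread_safe_generator`, `set_thread_generator`, and `random_uniform` are hypothetical stand-ins, and Python's `random.Random` stands in for `torch.Generator`.

```python
import random
import threading

# Hypothetical stand-in for torch.thread_safe_generator(): returns a
# thread-local generator if one was installed for this worker thread,
# and None otherwise (e.g. in multiprocessing workers).
_tls = threading.local()

def thread_safe_generator():
    return getattr(_tls, "generator", None)

def set_thread_generator(seed):
    # A thread-based dataloader worker would install its own generator here.
    _tls.generator = random.Random(seed)

def random_uniform():
    # Mirrors how a transform calls torch.rand(..., generator=gen):
    # a None generator falls back to the global RNG, so the
    # multiprocessing path is unchanged.
    gen = thread_safe_generator()
    return (gen or random).random()
```

The key design point is that transforms never branch on the execution mode themselves; passing `generator=None` already means "use the global RNG", so multiprocessing workers keep their existing per-process seeding.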
Force-pushed c416fad to e7da958, then e7da958 to 6277f11.
    transforms.RandomPerspective(p=1.0),
    transforms.RandomErasing(p=1.0),
    transforms.ScaleJitter(target_size=(24, 24)),
]
We have a few more random transforms in TV that we'll also want to update and test. I think the list you'll find in https://github.com/pytorch/vision/pull/7848/changes should have the proper coverage (but claude should be able to find all the relevant ones).
    assert not torch.equal(batch0, batch1)

@pytest.mark.parametrize("transform", TRANSFORMS, ids=lambda t: type(t).__name__)
def test_thread_worker_uses_thread_local_generator(self, transform):
For this multi-threading test, is there a way to test the actual multi-threaded behavior, without the mocking? That is, ideally I'd like to test the public-facing APIs when a user requests multi-threading from the DataLoader. I'm not sure what the public entry point is, though.
Added thread-safe random number generation to all V2 torchvision random transforms to prevent race conditions when using DataLoader with thread-based workers (worker_method='thread').
This is based on torch.thread_safe_generator, which returns the dataloader thread-worker-specific RNG, or None otherwise.