diff --git a/intermediate_source/torch_compile_full_example.py b/intermediate_source/torch_compile_full_example.py index 7804d5b155..84e8f74e76 100644 --- a/intermediate_source/torch_compile_full_example.py +++ b/intermediate_source/torch_compile_full_example.py @@ -79,7 +79,7 @@ def timed(fn): # batch size. def generate_data(b): return ( - torch.randn(b, 3, 128, 128).to().cuda(), + torch.randn(b, 3, 128, 128).cuda(), torch.randint(1000, (b,)).cuda(), )