[torchlib] Fix aten__native_batch_norm_legit_functional #2753
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main    #2753   +/- ##
=======================================
  Coverage   70.09%   70.09%
=======================================
  Files         228      228
  Lines       27382    27382
  Branches     2783     2783
=======================================
  Hits        19194    19194
  Misses       7229     7229
  Partials      959      959
```

☔ View full report in Codecov by Sentry.
```diff
     running_mean_fp32 = op.Cast(running_mean, to=FLOAT.dtype)
     invstd = op.Cast(invstd, to=FLOAT.dtype)
-    return norm, running_mean_fp32, invstd, running_mean, running_var
+    return norm, running_mean_fp32, invstd, op.Identity(running_mean), op.Identity(running_var)
```
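For context, the structural problem can be reproduced with a minimal standalone graph built with the onnx Python API. This is a sketch with illustrative names, not torchlib's test code: in the first graph the initializer is listed directly as a graph output with no producing node (the shape this PR fixes), while in the second an Identity node produces a distinct output value, mirroring the diff above.

```python
import numpy as np
from onnx import TensorProto, helper, numpy_helper

mean = numpy_helper.from_array(
    np.zeros(3, dtype=np.float32), name="running_mean"
)

# Invalid shape: the initializer "running_mean" is also a graph output,
# so no node in the graph produces the output value.
bad = helper.make_graph(
    nodes=[],
    name="bad",
    inputs=[],
    outputs=[helper.make_tensor_value_info("running_mean", TensorProto.FLOAT, [3])],
    initializer=[mean],
)

# Fixed shape: an Identity node creates a new value for the graph
# output to refer to, so the output is no longer an input/initializer.
good = helper.make_graph(
    nodes=[helper.make_node("Identity", ["running_mean"], ["running_mean_out"])],
    name="good",
    inputs=[],
    outputs=[helper.make_tensor_value_info("running_mean_out", TensorProto.FLOAT, [3])],
    initializer=[mean],
)
```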
Do you know what was happening? I assume running_mean/var were valid values already, but presumably some other requirement was violated (e.g., an input value cannot also be an output value)?
In other words, what is the requirement for torchlib functions from a developer's perspective? Are they required to never return an input value as an output value without wrapping it in an Identity? That seems like something the underlying infrastructure could take care of without burdening the torchlib developer.
Here running_mean and running_var are treated as mutable buffers in the PyTorch model: each buffer is an initializer that appears directly as a graph output.
This only happens with the training graph, so we did not hit the case in our testing.
From torchlib's perspective, yes, an input should not be returned directly as an output. It is probably true that we could detect that externally and wrap the output in an Identity.
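Such a check could indeed live in shared infrastructure. A minimal sketch of what that post-processing pass might look like, using the onnx Python API; the helper name is hypothetical, and torchlib may implement this differently (e.g., on its IR rather than on protos):

```python
import onnx
from onnx import helper


def wrap_passthrough_outputs(model: onnx.ModelProto) -> onnx.ModelProto:
    """Hypothetical pass: give every graph output that is also a graph
    input or initializer its own producing Identity node, so that no
    value is simultaneously an input and an output of the graph."""
    graph = model.graph
    passthrough = {i.name for i in graph.input} | {
        init.name for init in graph.initializer
    }
    for output in graph.output:
        if output.name in passthrough:
            new_name = output.name + "_output"
            # The Identity node only reads a graph input/initializer, so
            # appending it at the end keeps the nodes topologically sorted.
            graph.node.append(
                helper.make_node("Identity", [output.name], [new_name])
            )
            output.name = new_name
    return model
```

With a pass like this in the export pipeline, individual torchlib functions would not need to remember to insert op.Identity themselves.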
Fix aten__native_batch_norm_legit_functional where the running mean/var were returned without creating a new value, making the graph invalid.
Fixes pytorch/pytorch#171471