Spruce up docs for emulate_precision_casts #145579
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145579
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 77abf9b with merge base d6bea39.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Signed-off-by: Edward Z. Yang <[email protected]>
ghstack-source-id: 28df7aa
Pull Request resolved: #145579
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# For multiple, fused pointwise nodes, inductor will elide the intermediary upcasts and downcasts
# Typically this should be closer to fp64 ref numerics. However, it can be useful for debugging
# to emulate the eager numerics.
# Mode to emulate PyTorch eager numerics when doing lower precision compute
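A rough sketch of what this flag toggles (not part of the diff; it assumes a CUDA device, bf16 inputs, and a chain of pointwise ops that Inductor fuses):

import torch

def f(x, y):
    # Two pointwise ops that Inductor fuses into a single kernel.
    return torch.sin(x) * y

x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
y = torch.randn(1024, device="cuda", dtype=torch.bfloat16)

eager = f(x, y)

torch._inductor.config.emulate_precision_casts = False
torch._dynamo.reset()
fused = torch.compile(f)(x, y)  # intermediate sin(x) kept in fp32 inside the fused kernel

torch._inductor.config.emulate_precision_casts = True
torch._dynamo.reset()  # recompile so the new config value takes effect
emulated = torch.compile(f)(x, y)  # intermediate rounded to bf16, as eager does

print((eager - fused).abs().max())     # may be a small nonzero difference
print((eager - emulated).abs().max())  # expected to be zero, or much closer to it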
To what extent do you think we should have a flag (this one or another) for Inductor to emulate all eager numerics, even if there is a potential perf cost?
One that came up recently was var_mean, which Inductor appears to lower in a way that gives slightly different (but possibly more accurate) numerics:
import torch
from torch._dynamo.testing import rand_strided

torch._inductor.config.emulate_precision_casts = True

class GraphModule(torch.nn.Module):
    def forward(self, inp):
        return torch.ops.aten.var_mean.correction(inp, [3], correction=0, keepdim=True)

inp = rand_strided((1, 256, 256, 144), (9437184, 36864, 144, 1), device='cuda:0', dtype=torch.float32)
m = GraphModule()
out_eager = m(inp)
out_ref = m(inp.to(dtype=torch.float64))
out_compile = torch.compile(m)(inp)
print(torch.allclose(out_eager[0], out_compile[0]))
print(torch.allclose(out_ref[0], out_compile[0].to(dtype=torch.float64)))
print(torch.allclose(out_ref[0], out_eager[0].to(dtype=torch.float64)))
print(torch.max(torch.abs(out_eager[0] - out_compile[0])))
print(torch.max(torch.abs(out_ref[0] - out_compile[0])))
print(torch.max(torch.abs(out_ref[0] - out_eager[0])))

# prints:
# True
# True
# True
# tensor(3.5763e-07, device='cuda:0')
# tensor(2.3738e-07, device='cuda:0', dtype=torch.float64)
# tensor(2.7881e-07, device='cuda:0', dtype=torch.float64)
I think this would be very useful lol
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#145579
Approved by: https://github.com/gchanan
# and downcasting after. When two low precision operators are fused together,
# Inductor will elide the downcast-upcast pairs (effectively a precision
# truncation) that would occur between these two operators. Typically,
# Inductor's behavior should be closer to fp64 ref numerics. However, with
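A quick standalone illustration (not from the diff) of the "precision truncation" a single downcast-upcast pair performs: bf16 keeps only 8 mantissa bits, so the round trip discards low-order bits of an fp32 intermediate that a fused kernel would otherwise have preserved.

import torch

x = torch.tensor([1.0009765625])                    # fp32: exactly 1 + 2**-10
truncated = x.to(torch.bfloat16).to(torch.float32)  # bf16 cannot represent 2**-10 near 1.0

print(x.item(), truncated.item())                   # 1.0009765625 1.0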
If we're going to expand this, maybe we should actually have a doc somewhere instead of just a longer config string.
Stack from ghstack (oldest at bottom):
Signed-off-by: Edward Z. Yang [email protected]
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov