Description
🚀 The feature, motivation and pitch
Currently, NAdam and Adagrad each compile to more than 5 kernels. See the linked test file (its embedded preview begins with
| import unittest |
) for the current kernel counts. Ideally these optimizers should fuse fully, but mutation is likely preventing that. A good place to start is
pytorch/torch/_inductor/scheduler.py (line 4 at commit 4df84c3)
| import itertools |
Specifically
pytorch/torch/_inductor/scheduler.py (line 644 at commit 4df84c3)
| def can_fuse(cls, producer, consumer): |
Ideally, the fusion rules can be modified so that these optimizers fuse fully while the result remains sound.
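To make the soundness concern concrete, here is a minimal toy sketch of how mutation constrains fusion. It is not Inductor's actual `can_fuse` logic: the `Node` class, the `conflicts` helper, the three-argument `can_fuse(order, i, j)` signature, and the buffer names are all hypothetical, and the three-node schedule is only loosely modeled on an Adagrad step. The check refuses to fuse two nodes when an intermediate node is ordered between them by a read/write hazard.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical scheduler node: buffer names it reads and writes."""
    name: str
    reads: frozenset
    writes: frozenset  # includes buffers mutated in place


def conflicts(a: Node, b: Node) -> bool:
    """True if b must stay ordered after a (RAW, WAR, or WAW hazard)."""
    return bool(
        (a.writes & b.reads) or (a.reads & b.writes) or (a.writes & b.writes)
    )


def can_fuse(order: list, i: int, j: int) -> bool:
    """Toy legality check for fusing order[i] (producer) with order[j] (consumer).

    Fusion is rejected when some intermediate node depends on the producer
    (directly or through other intermediates) and the consumer in turn
    depends on that intermediate node.
    """
    producer, consumer = order[i], order[j]
    dependent = [producer]
    for mid in order[i + 1 : j]:
        if any(conflicts(d, mid) for d in dependent):
            dependent.append(mid)
    # Skip the producer itself; only intermediates can block the fusion.
    return not any(conflicts(mid, consumer) for mid in dependent[1:])


# Assumed schedule, loosely following an Adagrad step.
order = [
    # state_sum is updated in place from grad
    Node("accumulate", frozenset({"grad", "state_sum"}), frozenset({"state_sum"})),
    # std is computed from the mutated state_sum
    Node("std", frozenset({"state_sum"}), frozenset({"std"})),
    # param is updated in place from grad and std
    Node("update", frozenset({"param", "grad", "std"}), frozenset({"param"})),
]

print(can_fuse(order, 0, 1))
print(can_fuse(order, 1, 2))
print(can_fuse(order, 0, 2))
```

In this toy model, adjacent pairs can fuse, but the first and last nodes cannot be fused directly because the intermediate node sits between them in the dependency chain. A rule change of the kind described above would need to preserve orderings like this one while still letting the whole chain collapse into a single kernel.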
Contact @mlazos for more details
Alternatives
No response
Additional context
No response
cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @wconstab @Xia-Weiwen @ngimel