
Apply fusion more aggressively in NAdam and Adagrad compilation #107006

@mlazos

Description


🚀 The feature, motivation and pitch

Currently, NAdam and Adagrad each compile to more than 5 kernels. See

for the current kernel counts. Ideally, each optimizer step should fuse fully, but fusion is likely hindered by the presence of mutation (the optimizer updates parameters and optimizer state in place). A good place to start is


Specifically, this fusion check:

def can_fuse(cls, producer, consumer):
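As a toy illustration of how a pairwise predicate of this shape turns into a kernel count, the sketch below models nodes only by the buffers they read and mutate. Everything here is hypothetical (`Node`, `can_fuse_conservative`, `greedy_fuse`, and the three-node graph are not Inductor APIs), and the rule is deliberately conservative: it refuses to fuse a node into a group whenever that node reads a buffer some node in the group mutated in place. Under that assumption, an Adagrad-like chain ends up as two kernels instead of one.

```python
from dataclasses import dataclass
from typing import Callable, FrozenSet, List


@dataclass
class Node:
    # Hypothetical stand-in for a scheduler node.
    name: str
    reads: FrozenSet[str]      # buffers this node reads
    mutations: FrozenSet[str]  # buffers this node updates in place


def can_fuse_conservative(producer: Node, consumer: Node) -> bool:
    # Reject fusion whenever the consumer reads a buffer the producer mutated.
    return not (producer.mutations & consumer.reads)


def greedy_fuse(
    nodes: List[Node], can_fuse: Callable[[Node, Node], bool]
) -> List[List[Node]]:
    # Walk nodes in program order; each group stands in for one kernel.
    groups: List[List[Node]] = []
    for node in nodes:
        # Join the current group only if every node already in it may
        # act as a producer for this node.
        if groups and all(can_fuse(p, node) for p in groups[-1]):
            groups[-1].append(node)
        else:
            groups.append([node])
    return groups


# A hypothetical Adagrad-like graph: update state, derive std, update param.
adagrad_nodes = [
    Node("update_state", frozenset({"grad", "state_sum"}), frozenset({"state_sum"})),
    Node("compute_std", frozenset({"state_sum"}), frozenset()),
    Node("update_param", frozenset({"param", "grad", "std"}), frozenset({"param"})),
]
kernels = greedy_fuse(adagrad_nodes, can_fuse_conservative)
print([[n.name for n in group] for group in kernels])
```

The mutation of `state_sum` in the first node is enough to force a kernel boundary, even though every node in the chain is elementwise.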

Ideally, these rules could be modified to soundly allow both optimizers to fuse fully.
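One way such a rule could be relaxed while staying sound is sketched below, assuming each buffer access can be summarized by an index pattern. Fusing across an in-place mutation is safe when the producer writes element `i` and the consumer reads only element `i` in the same loop iteration, as in the elementwise Adagrad and NAdam updates; it is unsafe when the consumer reads other elements (for example, a reduction over the whole buffer), because some of them would not have been updated yet. `Node`, the `"i"`/`"all"` patterns, and `can_fuse_relaxed` are hypothetical, and the check covers only read-after-write dependencies through mutated buffers; a real rule would also have to handle other hazards, such as a consumer overwriting a buffer the producer still reads.

```python
from dataclasses import dataclass, field
from typing import Dict

# Hypothetical index patterns: "i" means iteration i touches only element i;
# "all" means iteration i may touch any element (e.g. a reduction).
ELEMENTWISE = "i"


@dataclass
class Node:
    # Hypothetical stand-in for a scheduler node: buffer name -> index pattern.
    name: str
    reads: Dict[str, str] = field(default_factory=dict)
    mutations: Dict[str, str] = field(default_factory=dict)


def can_fuse_relaxed(producer: Node, consumer: Node) -> bool:
    """Allow fusion through an in-place mutation only when every mutated buffer
    the consumer reads is written and read at the same element, so that in a
    fused loop each element is read only after it has been updated."""
    for buf, write_idx in producer.mutations.items():
        read_idx = consumer.reads.get(buf)
        if read_idx is None:
            continue  # the consumer does not depend on this mutation
        if write_idx != ELEMENTWISE or read_idx != ELEMENTWISE:
            return False  # the consumer could see elements not yet updated
    return True


update_state = Node("update_state", reads={"grad": "i", "state_sum": "i"},
                    mutations={"state_sum": "i"})
update_param = Node("update_param", reads={"param": "i", "grad": "i", "state_sum": "i"},
                    mutations={"param": "i"})
sum_state = Node("sum_state", reads={"state_sum": "all"})

print(can_fuse_relaxed(update_state, update_param))  # True: elementwise dependency
print(can_fuse_relaxed(update_state, sum_state))     # False: reduction over mutated buffer
```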

Contact @mlazos for more details
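To make "fully fuse" concrete, the sketch below writes one Adagrad step (assuming no weight decay and no learning-rate decay, with PyTorch's default `lr=0.01` and `eps=1e-10`) two ways in plain Python. The unfused version makes three passes over the data and materializes an intermediate `std` buffer, roughly as separate kernels would; the fused version performs both in-place updates, to `state_sum` and to `param`, in a single pass. The function names are hypothetical, and the code models the arithmetic only, not Inductor's generated kernels.

```python
import math


def adagrad_step_unfused(param, grad, state_sum, lr=0.01, eps=1e-10):
    # Pass 1: mutate the optimizer state in place: state_sum += grad * grad
    for i in range(len(param)):
        state_sum[i] += grad[i] * grad[i]
    # Pass 2: materialize an intermediate buffer: std = sqrt(state_sum) + eps
    std = [math.sqrt(s) + eps for s in state_sum]
    # Pass 3: mutate the parameter in place: param -= lr * grad / std
    for i in range(len(param)):
        param[i] -= lr * grad[i] / std[i]


def adagrad_step_fused(param, grad, state_sum, lr=0.01, eps=1e-10):
    # Single pass: each iteration reads and writes only element i, so both
    # in-place updates can share one loop (one "kernel").
    for i in range(len(param)):
        state_sum[i] += grad[i] * grad[i]
        std = math.sqrt(state_sum[i]) + eps
        param[i] -= lr * grad[i] / std


# Usage: both versions leave identical parameter and state values.
p1, s1 = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
p2, s2 = list(p1), list(s1)
g = [0.5, -1.0, 2.0]
adagrad_step_unfused(p1, g, s1)
adagrad_step_fused(p2, g, s2)
print(p1 == p2, s1 == s2)  # prints: True True
```

Because every element is read and written only at its own index, the single-pass version performs the same floating-point operations in the same order per element and gives bit-identical results; preserving that property despite the mutation is what a sound fusion rule has to guarantee.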

Alternatives

No response

Additional context

No response

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @wconstab @Xia-Weiwen @ngimel

Metadata


Assignees: No one assigned
Labels: enhancement, good first issue, module: inductor, oncall: pt2, triaged
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
