[jit] Add Custom graph fusion #18588
Conversation
zdevito left a comment
I commented on the other WIP TVM IR with my fuser comments. I think they still apply here.
torch/csrc/jit/ir.h
Outdated
                    TypePtr typ); // value of None with type Optional[typ]
    TORCH_API Node* createAutogradZero();
    TORCH_API Node* createFusionGroup();
    TORCH_API Node* createFusionGroup(Symbol kind = prim::FusionGroup);
Since this is no longer specific to FusionGroup, maybe it would be better to rename it to createWithSubgraph(Symbol kind)?
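The suggested rename could look like the following standalone sketch. `Symbol`, `Node`, and `kFusionGroup` below are toy stand-ins for the real torch::jit types, used only to illustrate why a neutral name fits once the node kind is a parameter:

```cpp
#include <cassert>
#include <string>

// Toy stand-ins for torch::jit types; names are illustrative only.
using Symbol = std::string;
const Symbol kFusionGroup = "prim::FusionGroup";

struct Node {
  Symbol kind;
  bool has_subgraph;
};

// The suggested rename: because the node kind is now a parameter, the
// factory is no longer FusionGroup-specific, so a neutral name reads
// better than createFusionGroup.
Node createWithSubgraph(Symbol kind = kFusionGroup) {
  return Node{kind, /*has_subgraph=*/true};
}
```

Callers that want the old behavior rely on the default argument; custom backends pass their own kind.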
    Block* block_;
    std::unique_ptr<AliasDb> aliasDb_;
    std::shared_ptr<Graph> graph_;
    using FusionCallback = std::function<bool(Node*)>;
nit: can we move this so it doesn't appear inside a sequence of member declarations? It is quite confusing at the moment.
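One way to address the nit, sketched with placeholder types (only the declaration ordering matters here, not the names):

```cpp
#include <functional>
#include <memory>

// Placeholder types standing in for the real torch::jit ones.
struct Graph {};
struct AliasDb {};
struct Node {};

class GraphFuserPass {
 public:
  // Type aliases hoisted to the top of the class...
  using FusionCallback = std::function<bool(Node*)>;

  // ...so the data members read as one uninterrupted sequence.
  std::unique_ptr<AliasDb> aliasDb_;
  std::shared_ptr<Graph> graph_;
  FusionCallback callback_;
};
```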
    std::unique_ptr<AliasDb> aliasDb_;
    std::shared_ptr<Graph> graph_;
    using FusionCallback = std::function<bool(Node*)>;
    FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); };
Hmm, can you really capture this in a lambda that's in a default initializer of a member? It's surprising syntax.
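For what it's worth, the answer seems to be yes: a default member initializer is evaluated in the context of each constructor, so a lambda written there may capture `this` (here implicitly through `[&]`). A minimal standalone demonstration, with simplified stand-in types rather than the real fuser:

```cpp
#include <functional>

struct Fuser {
  bool isFusableDefault(int n) const { return n % 2 == 0; }

  // Legal but surprising: [&] in a default member initializer captures
  // `this`, so the callback can invoke a member function. Pitfall:
  // copying a Fuser copies callback_ still bound to the original
  // object's `this`, which is part of why the syntax is risky.
  std::function<bool(int)> callback_ = [&](int n) {
    return isFusableDefault(n);
  };
};
```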
    }

    void CustomFuseGraph(std::shared_ptr<Graph>& graph, std::function<bool(Node*)> fn, Symbol kind) {
      if (canFuseOnCPU() || canFuseOnGPU()) {
My previous comment still applies: why do we check the capabilities of the PyTorch fuser when this pass is really applied to obtain a fused node that will be handed to a completely different backend?
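The point can be illustrated with a toy pass. `Node` and the grouping logic below are simplified stand-ins for the real fuser; what matters is that fusability is decided entirely by the caller-supplied predicate, with no `canFuseOnCPU()`/`canFuseOnGPU()` capability check:

```cpp
#include <functional>
#include <string>
#include <vector>

// Simplified stand-in for a graph node.
struct Node { std::string kind; };

// Group maximal runs of adjacent fusable nodes, returning how many
// fusion groups were formed. Note: no backend capability check --
// whether a node is fusable is entirely the callback's decision,
// since the resulting group targets the caller's own backend.
int customFuseGraph(const std::vector<Node>& nodes,
                    const std::function<bool(const Node&)>& is_fusable) {
  int groups = 0;
  bool in_group = false;
  for (const auto& n : nodes) {
    if (is_fusable(n)) {
      if (!in_group) { ++groups; in_group = true; }
    } else {
      in_group = false;
    }
  }
  return groups;
}
```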
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
clang-tidy is complaining
torch/csrc/jit/passes/graph_fuser.h
Outdated
    TORCH_API void CustomFuseGraph(
        std::shared_ptr<Graph>& graph,
        std::function<bool(Node*)> fn,
It would be better to call this is_fusable rather than fn, because fn carries no information.
torch/csrc/jit/passes/graph_fuser.h
Outdated
| TORCH_API void CustomFuseGraph( | ||
| std::shared_ptr<Graph>& graph, | ||
| std::function<bool(Node*)> fn, | ||
| Symbol tag); |
nit: kind
    bool isFusableDefault(Node* node) {
      bool fusableDevice = true;
      for (const auto& output : node->outputs()) {
        fusableDevice &= isFusableDevice(output);
Please check this only for outputs which have uses.
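A sketch of the requested change, using simplified stand-ins for Node/Value (the real code would consult each output's use list, e.g. something like `output->uses().empty()`; treat that call as an assumption):

```cpp
#include <vector>

// Simplified stand-ins: a Value records whether the (hypothetical)
// device check passed and how many consumers read it.
struct Value {
  bool fusable_device;
  int num_uses;
};
struct Node { std::vector<Value> outputs; };

bool isFusableDefault(const Node& node) {
  bool fusable_device = true;
  for (const auto& output : node.outputs) {
    if (output.num_uses == 0) continue;  // skip outputs with no uses
    fusable_device &= output.fusable_device;
  }
  return fusable_device;
}
```

A dead output on a bad device no longer blocks fusion, but any output that is actually consumed still must pass the device check.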
    producer->node()->matches(
        "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {

    if (kind_ == prim::FusionGroup &&
This placement is a bit meh, but I guess it's ok. It's a sign that we might want to rethink the API you're adding, because right now there's no way to decompose ops lazily with custom fusions.
@pytorchbot retest this please
[jit] Add Custom graph fusion gh-metadata: pytorch pytorch 18588 gh/bwasti/2/head
Stack from ghstack:
Differential Revision: D14901297