Add support for batch_norm fusion to the JIT #15146
Conversation
@apaszke there are some AutodiffSubgraphSlicing tests failing. I'm not sure if those are related to your PR
zou3519 left a comment
(still reading through, not a full review yet)
If we chunk the output of batchnorm, then the chunk wouldn't get moved past the pointwise ops of the batchnorm because there isn't an opportunity to decompose the batchnorm, right? I'm not sure if people do this in practice though.
It wouldn't, but I really don't expect this to be the common case, and I don't want to make the code for chunk more complicated for no good reason.
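For context, a minimal sketch of the pattern being discussed (the function and tensor names are illustrative, not code from this PR): the batch_norm output is chunked before the pointwise tail, so the chunk sits between the not-yet-decomposed batch norm and the ops the fuser could otherwise rearrange.

```python
import torch
import torch.nn.functional as F

# Illustrative pattern only: chunking the batch_norm output before the pointwise tail.
def bn_then_chunk(x, running_mean, running_var, weight, bias):
    y = F.batch_norm(x, running_mean, running_var, weight, bias, training=False)
    a, b = y.chunk(2, dim=1)   # the chunk cannot move past batch_norm's pointwise part
    return torch.relu(a) * b
```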
nit: For some definitions of fusable, cat and chunk nodes are also fusable, so the naming of this function (isFusable) bothers me a little. Maybe call it something like "isDecomposibleIntoFusibleMap"? (Although not all of batchnorm is decomposable into pointwise ops, only the last piece of it is...)
I agree our definition of "fusable" is completely messed up, but clearing that up is material for another PR. I simply followed whatever we were using.
Does this work? Doesn't Tensor? mean that the tensor is either defined or undefined? It's strange that we can write Optional[Tensor] in TorchScript for that.
Yep, Optional[Tensor] means that it can be undefined on the C++ side 😕
Yikes, thanks for the clarification
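For reference, a minimal TorchScript sketch (a hypothetical function, not code from this PR) where an Optional[Tensor] argument may be None, which corresponds to an undefined tensor on the C++ side:

```python
from typing import Optional

import torch

@torch.jit.script
def maybe_scale(x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:
    # A None weight here shows up as an undefined at::Tensor in C++.
    if weight is not None:
        return x * weight
    return x
```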
What does ncf stand for?
nit: _ncf_reshape might be a better name because we are reshaping instead of expanding
It's as in NCHW, but it works with any number of dimensions, hence F for features instead of HW. I can rename it to reshape if you really want, but I wouldn't want to block this PR on it if that's the only problem.
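As an illustration of the N/C/F naming (a hedged sketch, not the PR's actual helper), a per-channel statistic can be viewed as (1, C, 1, ..., 1) so it broadcasts over the batch and any number of trailing feature dimensions:

```python
import torch

def _ncf_view(stat: torch.Tensor, input_dim: int) -> torch.Tensor:
    # Hypothetical helper: reshape a length-C vector to (1, C, 1, ..., 1).
    return stat.view(1, -1, *([1] * (input_dim - 2)))

x = torch.randn(8, 3, 16, 16)            # NCHW input
mean = torch.randn(3)                    # per-channel statistic
centered = x - _ncf_view(mean, x.dim())  # broadcasts over N and the feature dims
```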
test/test_jit.py
Add to TestFuser maybe?
Yep, will move.
torch/csrc/jit/fuser/codegen.cpp
How does threshold play a role in the batchnorm fusion?
Should this be ${0} <= ${1}? (not sure if this makes a difference)
We should also probably add a correctness test for this to TestFuser
It's an orthogonal improvement to the fuser; it lets us fuse whole blocks between convs in ResNets.
test/test_jit.py
Can we also add a correctness test that runs with the JIT and checks the output?
Yup, will do.
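A hedged sketch of what such a test could look like (illustrative only; the real test would live in TestFuser and use the suite's helpers):

```python
import torch
import torch.nn.functional as F

def bn_relu(x, running_mean, running_var, weight, bias):
    return F.relu(F.batch_norm(x, running_mean, running_var, weight, bias, training=False))

def test_batch_norm_fusion_matches_eager():
    x = torch.randn(4, 8, 16, 16, device='cuda')
    weight, bias = torch.randn(8, device='cuda'), torch.randn(8, device='cuda')
    mean, var = torch.randn(8, device='cuda'), torch.rand(8, device='cuda') + 0.1
    scripted = torch.jit.script(bn_relu)
    expected = bn_relu(x, mean, var, weight, bias)
    for _ in range(3):   # run a few times so the fuser has a chance to kick in
        actual = scripted(x, mean, var, weight, bias)
    torch.testing.assert_close(actual, expected)
```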
zou3519 left a comment
BN changes look fine. I had some minor questions and comments; please read.
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zou3519 left a comment
lgtm!
@apaszke looks like you might have to rebase this
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@apaszke the test tolerance or the magnitude of the inputs might need to be updated:
ROCm failures look unrelated.
I forgot to add CPU implementation for
Red jobs are CI failures.
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
momentum is unused in this function; it's only necessary for updating the running stats
This is unused now
Nit: epsilon not used, maybe remove the variable name?
zou3519 left a comment
LGTM. There are some unused variables; let me know if you want to land this as-is or clean them up first.
@apaszke this needs a rebase now
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot left a comment
@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: We don't support reductions yet, but simply decomposing batch_norm into a kernel that computes the stats, and then fusing everything else with ReLU and following pointwise ops, provides nice speedups. Note that this is limited to inference mode for now, because we don't support convolutions and batch norm in AD, so the fuser isn't applied to those parts.
This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment due to lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).
cc zou3519 zdevito mruberry ngimel
Pull Request resolved: pytorch/pytorch#15146
Differential Revision: D13548303
Pulled By: zou3519
fbshipit-source-id: a2e2e5abc383f637fae19bd1b423f20c2cbc056a
@apaszke Is there any update on AD support for batch norm, so that fused batch_norm can be used at training time?
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and then fusing everything else
with ReLU and following pointwise ops provides nice speedups.
Note that this is only limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.
This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment due to lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).

cc @zou3519 @zdevito @mruberry @ngimel
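For intuition, a hedged sketch of the decomposition described above (not the PR's exact code): once the per-channel stats come out of a separate kernel (or the running stats in inference mode), the remainder of batch_norm is purely pointwise and can be fused with a following ReLU and any trailing pointwise ops.

```python
import torch

def bn_relu_pointwise(x, mean, var, weight, bias, eps=1e-5):
    # Broadcast per-channel parameters over the batch and feature dimensions.
    shape = (1, -1) + (1,) * (x.dim() - 2)
    inv_std = torch.rsqrt(var + eps)
    y = (x - mean.view(shape)) * (inv_std * weight).view(shape) + bias.view(shape)
    return torch.relu(y)   # everything above is pointwise and hence fusable
```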