
Conversation

@apaszke (Contributor) commented Dec 12, 2018

We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and then fusing everything else
with ReLU and the following pointwise ops, provides nice speedups.

Note that this is limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.

This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment, due to the lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).
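As an illustration of the decomposition described above, here is a minimal NumPy sketch (not the actual fuser code) of splitting inference-mode batch norm into a stats reduction plus a purely pointwise tail that can fuse with ReLU:

```python
import numpy as np

def batch_norm_stats(x):
    # the reduction "kernel": per-channel mean and variance over
    # the batch and all feature dimensions of an N x C x F... tensor
    axes = (0,) + tuple(range(2, x.ndim))
    return x.mean(axis=axes), x.var(axis=axes)

def bn_relu_pointwise(x, mean, var, gamma, beta, eps=1e-5):
    # everything left after the reduction is pointwise, so a fuser can
    # combine it with ReLU and any following pointwise ops
    shape = (1, -1) + (1,) * (x.ndim - 2)  # broadcast per-channel vectors
    invstd = 1.0 / np.sqrt(var + eps)
    y = (x - mean.reshape(shape)) * (gamma * invstd).reshape(shape) + beta.reshape(shape)
    return np.maximum(y, 0.0)  # fused ReLU
```

The function names here are illustrative; the point is only that the final normalization, scale/shift, and ReLU are all elementwise once the per-channel statistics exist.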

cc @zou3519 @zdevito @mruberry @ngimel

@facebook-github-bot added the `oncall: jit` label (add this issue/PR to the JIT oncall triage queue) Dec 12, 2018
@zou3519 (Contributor) commented Dec 14, 2018

@apaszke there are some AutodiffSubgraphSlicing tests failing. I'm not sure whether they're related to your PR.

@fmassa fmassa mentioned this pull request Dec 17, 2018
@apaszke apaszke closed this Dec 19, 2018
@apaszke apaszke reopened this Dec 19, 2018
@zou3519 (Contributor) left a comment:

(still reading through, not a full review yet)

Contributor:

If we chunk the output of batchnorm, then the chunk wouldn't get moved past the pointwise ops of the batchnorm because there isn't an opportunity to decompose the batchnorm, right? I'm not sure if people do this in practice though.

Contributor (Author):

It wouldn't, but I really don't expect this to be the common case, and I don't want to make the code for chunk more complicated for no good reason.

Contributor:

nit: For some definitions of Fusable, cat and chunk nodes are also fusible, so the naming of this function (isFusable) bothers me a little. Maybe call it something like "isDecomposibleIntoFusibleMap"? (although not all of batchnorm is decomposible into pointwise ops, only the last piece of it is...)

Contributor (Author):

I agree our definition of "fusable" is completely messed up, but clearing that up is material for another PR. I simply followed whatever we were already using.

Contributor:

Does this work? Doesn't `Tensor?` mean that the tensor is either defined or an undefined tensor? It's strange that we can write `Optional[Tensor]` in TorchScript for that.

Contributor (Author):

Yep, `Optional[Tensor]` means that it can be undefined on the C++ side 😕
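For context, a small TorchScript sketch of the `Optional[Tensor]` annotation being discussed (illustrative, not code from this PR): passing `None` from Python corresponds to an undefined `at::Tensor` on the C++ side.

```python
import torch
from typing import Optional

@torch.jit.script
def scale(x: torch.Tensor, weight: Optional[torch.Tensor]) -> torch.Tensor:
    # a None weight here maps to an undefined at::Tensor in C++
    if weight is not None:
        return x * weight
    return x
```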

Contributor:

Yikes, thanks for the clarification

Contributor:

What does `ncf` stand for?

Contributor:

nit: `_ncf_reshape` might be a better name, because we are reshaping instead of expanding

Contributor (Author):

It's as in NCHW, but it works with any number of dimensions, hence F for features instead of HW. I can rename it to reshape if you really want, but I wouldn't like to block this PR on it if that's the only problem.

test/test_jit.py Outdated
Contributor:

Add to TestFuser maybe?

Contributor (Author):

Yep, will move.

@zou3519 (Contributor) commented Dec 19, 2018:

How does threshold play a role in the batchnorm fusion?

Should this be ${0} <= ${1}? (not sure if this makes a difference)
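For reference, the pointwise semantics of threshold can be sketched as follows (NumPy, illustrative); whether the comparison is strict makes no difference for ReLU, since both branches yield 0 at the boundary:

```python
import numpy as np

def threshold(x, threshold_value, value):
    # pointwise: keep x where it exceeds the threshold, else substitute value;
    # relu(x) is threshold(x, 0, 0), which is what lets it fuse with the
    # pointwise tail of the decomposed batch norm
    return np.where(x > threshold_value, x, value)
```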

Contributor:

We should also probably add a correctness test for this to TestFuser

Contributor (Author):

It's an orthogonal improvement to the fuser; it lets us fuse whole blocks between convs in ResNets.

test/test_jit.py Outdated
Contributor:

Can we also add a correctness test that runs with the JIT and checks the output?

Contributor (Author):

Yup, will do.

@zou3519 (Contributor) left a comment:

BN changes look fine. I had some minor questions and comments; please read.

@facebook-github-bot left a comment:

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519 (Contributor) left a comment:

lgtm!

@zou3519 (Contributor) commented Dec 26, 2018

@apaszke looks like you might have to rebase this

@facebook-github-bot left a comment:

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519 (Contributor) commented Dec 27, 2018

@apaszke the test tolerance or the magnitude of the inputs might need to be updated:

Dec 27 18:20:34 ======================================================================
Dec 27 18:20:34 FAIL: test_fuse_batch_norm (__main__.TestFuser)
Dec 27 18:20:34 ----------------------------------------------------------------------
Dec 27 18:20:34 Traceback (most recent call last):
Dec 27 18:20:34   File "test_jit.py", line 10221, in test_fuse_batch_norm
Dec 27 18:20:34     self.assertEqual(out, out_noopt)
Dec 27 18:20:34   File "/var/lib/jenkins/workspace/test/common_utils.py", line 418, in assertEqual
Dec 27 18:20:34     assertTensorsEqual(x, y)
Dec 27 18:20:34   File "/var/lib/jenkins/workspace/test/common_utils.py", line 410, in assertTensorsEqual
Dec 27 18:20:34     self.assertLessEqual(max_err, prec, message)
Dec 27 18:20:34 AssertionError: tensor(1.2243e-05, device='cuda:0') not less than or equal to 1e-05 : 
Dec 27 18:20:34 
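A max error of 1.2e-5 is the kind of difference you get when a float32 reduction is reassociated, as happens when batch norm is decomposed. A sketch of the max-abs-error criterion with a loosened tolerance (mirroring, not copying, the `assertTensorsEqual` check from the traceback):

```python
import numpy as np

def assert_tensors_close(x, y, prec=2e-5):
    # same max-absolute-error criterion the test harness uses, with a
    # tolerance loose enough for reordered float32 accumulation
    max_err = np.abs(np.asarray(x) - np.asarray(y)).max()
    assert max_err <= prec, f"{max_err} not less than or equal to {prec}"
```

The `prec=2e-5` value here is only an assumption; shrinking the input magnitudes is the other fix the comment above suggests.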

@apaszke (Contributor, Author) commented Dec 30, 2018

ROCm failures look unrelated

@apaszke (Contributor, Author) commented Dec 30, 2018

I forgot to add a CPU implementation for `batch_norm_update_stats`, which is fixed in the latest commit. Please review the ATen changes again!

@apaszke (Contributor, Author) commented Jan 2, 2019

Red jobs are CI failures.

@facebook-github-bot left a comment:

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor:

`momentum` is unused in this function; it's only necessary for updating the running stats
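For context, the role `momentum` plays in the running-stat update can be sketched as follows (the usual batch-norm semantics, not the actual ATen code):

```python
def update_running_stats(mean, var, running_mean, running_var, momentum, n):
    # batch statistics use the biased variance, but the running estimate
    # is updated with the unbiased one (Bessel's correction over n samples)
    unbiased_var = var * n / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * mean
    new_var = (1 - momentum) * running_var + momentum * unbiased_var
    return new_mean, new_var
```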

Contributor:

This is unused now

Contributor:

Nit: `epsilon` is not used; maybe remove the variable name?

@zou3519 (Contributor) left a comment:

LGTM. There are some unused variables; let me know if you want to land this as-is or clean them up first.

@zou3519 (Contributor) commented Jan 3, 2019

@apaszke this needs a rebase now

@facebook-github-bot left a comment:

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot left a comment:

@zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jan 8, 2019
Summary:
We don't support reductions yet, but simply decomposing batch_norm
into a kernel that computes the stats, and then fusing everything else
with ReLU and the following pointwise ops, provides nice speedups.

Note that this is limited to inference mode for now, because we
don't support convolutions and batch norm in AD, so the fuser isn't
applied to those parts.

This commit gives us a 7% end-to-end speedup for ResNet50 with batch size 32. Note that this only applies to inference mode at the moment due to lack of AD support for CNN operations (I'll be adding that soon), and not to the standard `torchvision` models, because they use in-place ops which aren't supported by the fuser (we need a way of proving that de-inplacing them is safe).

cc zou3519 zdevito mruberry ngimel
Pull Request resolved: pytorch/pytorch#15146

Differential Revision: D13548303

Pulled By: zou3519

fbshipit-source-id: a2e2e5abc383f637fae19bd1b423f20c2cbc056a
@apaszke apaszke mentioned this pull request Jan 9, 2019
facebook-github-bot pushed a commit that referenced this pull request Jan 10, 2019
Summary:
Resubmit of #15146, which has been accidentally reverted.
Pull Request resolved: #15897

Differential Revision: D13616093

Pulled By: zou3519

fbshipit-source-id: 0c3a3bec8f9fed57274da9f6c7cf40cbc05cf91a
@chanil1218 commented:
@apaszke
Thank you for releasing this nice code!
My understanding is that JIT-fused batch_norm should only be used at inference time for speedup purposes, because batch norm is not supported in AD.

Is there any update on AD support for batch norm, so that fused batch_norm can be used at training time? Or is there a related issue where I could track the progress of that support?

Labels: oncall: jit, open source

5 participants