-
Notifications
You must be signed in to change notification settings - Fork 26.3k
batch_norm_cpu_inference for channel last #28982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
VitalyFedyunin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add tests
| /// directory. | ||
| template<typename scalar_t> | ||
| void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input, | ||
| void batch_normal_cpu_inference_collect_liner_and_constant_terms( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super nit: nobody call it batch_normal please rename to batch_norm
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
VitalyFedyunin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good. Still couple things do to:
- Tests
- Benchmarking with different shapes ( especially different channels = 1, 3, 10, 1000 )
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Need to fix lint issues in the python (see inline comments inside of DIFF)
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| /// directory. | ||
| template<typename scalar_t> | ||
| void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input, | ||
| void batch_norm_cpu_inference_collect_liner_and_constant_terms( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: did you mean linear?
| output2 = m(input2).permute(0, 1, 3, 2) | ||
| # channels last case | ||
| input3 = input1.contiguous(memory_format=torch.channels_last) | ||
| for name, param in m.named_parameters(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this ends up doing anything -- because you passed affine=False, this module won't actually have any parameters.
It would be good to test the affine=True case as well, but then the dim() == 4 won't pass (it's 1 dimensional).
Also, there's no need to call data.dim() -- it can make sense to assign to .data, but for function calls you are just creating a temporary.
| e1.fill_diagonal_(v, wrap=True) | ||
| self.assertEqual(e1, e2) | ||
|
|
||
| def test_batch_norm_cpu_inference(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be good to test this on float32s and float64s. You can search around the test framework for examples of that.
But since you are specifying magic values you should check that the magic works on all the combinations we care about.
|
@glaringlee merged this pull request in f2a35db. |
channels last version for batch_norm_cpu_inference_contiguous
Benchmark:
The benchmark test uses a fixed batch size n=20, channel number in [1,3,10,100,1000], height and width size in [1,4,16,64,256], height and width size are always the same in this test.
We use the following code to do this benchmark.
It tests contiguous, channels last and non-contiguous tensor in each loop and print out the benchmark. It also compare the outputs within each loop to make sure the correctness of the new change.
Sample output:
Benchmark n=20 c=100 h=256 w=256 -> title line
101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> contiguous tensor
100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each) -> channels last tensor
1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -> non-contiguous tensor
True True -> 1st output compare with 2nd output, 1st output compare 3rd output, expect True
Benchmark Before this change:
Benchmark n=20 c=1 h=1 w=1
10.1 µs ± 158 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.2 µs ± 305 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.7 µs ± 784 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=4 w=4
10.2 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.1 µs ± 98 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=16 w=16
11 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
17.3 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=64 w=64
24.2 µs ± 536 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
23.9 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=1 h=256 w=256
539 µs ± 7.85 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
539 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.42 ms ± 33 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=1 w=1
10 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.97 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=4 w=4
10.4 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
16.1 µs ± 601 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
19.1 µs ± 658 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=16 w=16
13.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
25.3 µs ± 558 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.4 µs ± 625 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=64 w=64
51.1 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
159 µs ± 7.3 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
199 µs ± 1.88 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=256 w=256
1.25 ms ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
2.95 ms ± 203 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.14 ms ± 42.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=10 h=1 w=1
9.97 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 852 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11.7 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=4 w=4
11.2 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
29.7 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39.4 µs ± 396 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=16 w=16
19.7 µs ± 632 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68.3 µs ± 912 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
90.3 µs ± 4.76 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=64 w=64
325 µs ± 5.01 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
918 µs ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
991 µs ± 44.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=10 h=256 w=256
9.47 ms ± 73.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
34.7 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
91.5 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=1 w=1
11.8 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.1 µs ± 800 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12 µs ± 533 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=100 h=4 w=4
26.7 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
231 µs ± 8.03 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
335 µs ± 15.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=16 w=16
178 µs ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
1.45 ms ± 187 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.52 ms ± 94.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=64 w=64
6.9 ms ± 554 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
30.3 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
27 ms ± 272 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=256 w=256
98.9 ms ± 818 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.29 s ± 12.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.32 s ± 9.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=1 w=1
18.6 µs ± 2.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
18.7 µs ± 947 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
15.8 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1000 h=4 w=4
111 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2.07 ms ± 22.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.19 ms ± 163 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=1000 h=16 w=16
3.87 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
25.6 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
27 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=1000 h=64 w=64
70.1 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
467 ms ± 26.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
444 ms ± 25.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=256 w=256
2.39 s ± 19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
19.2 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
22.1 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark After this change:
Benchmark n=20 c=1 h=1 w=1
10.4 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 149 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.7 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=4 w=4
11.8 µs ± 1.44 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
13.6 µs ± 142 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=16 w=16
11.9 µs ± 198 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.1 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
18.2 µs ± 205 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1 h=64 w=64
27.6 µs ± 2.4 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.2 µs ± 8.69 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68.9 µs ± 1.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=1 h=256 w=256
601 µs ± 49 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
597 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.48 ms ± 24.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=1 w=1
10.8 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.6 µs ± 194 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.5 µs ± 137 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=4 w=4
11.6 µs ± 551 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11.7 µs ± 266 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
19.9 µs ± 340 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=3 h=16 w=16
13.7 µs ± 223 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
24.7 µs ± 424 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.7 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=3 h=64 w=64
53.3 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
212 µs ± 4.68 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
204 µs ± 5.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=3 h=256 w=256
1.49 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.27 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.08 ms ± 290 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=10 h=1 w=1
10.7 µs ± 166 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.8 µs ± 225 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
10.8 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=10 h=4 w=4
11.6 µs ± 129 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.9 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
43.7 µs ± 3.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=16 w=16
20.7 µs ± 576 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.2 µs ± 795 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
92.5 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
True True
Benchmark n=20 c=10 h=64 w=64
342 µs ± 9.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
622 µs ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.03 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=10 h=256 w=256
9.49 ms ± 130 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.9 ms ± 408 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
90.5 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=1 w=1
12 µs ± 575 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
11 µs ± 182 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=100 h=4 w=4
22.3 µs ± 451 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
18.7 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
323 µs ± 6.22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=16 w=16
211 µs ± 22 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
222 µs ± 20.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.5 ms ± 59.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
True True
Benchmark n=20 c=100 h=64 w=64
7.2 ms ± 1e+03 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
6.51 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
27.4 ms ± 695 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=100 h=256 w=256
101 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
100 ms ± 898 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.3 s ± 10.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=1 w=1
16.9 µs ± 589 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.5 µs ± 113 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
16.5 µs ± 168 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
True True
Benchmark n=20 c=1000 h=4 w=4
116 µs ± 6.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
67 µs ± 1.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.23 ms ± 80 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
True True
Benchmark n=20 c=1000 h=16 w=16
3.53 ms ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.53 ms ± 125 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
27 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
True True
Benchmark n=20 c=1000 h=64 w=64
68.6 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
68 ms ± 288 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
425 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
Benchmark n=20 c=1000 h=256 w=256
2.51 s ± 97.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.84 s ± 471 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
21.5 s ± 933 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
True True
The channel last batch normalization is getting faster with this change and the previous existing code/logic is not affected based on the benchmark above.