Skip to content

Conversation

@syed-ahmed
Copy link
Collaborator

@syed-ahmed syed-ahmed commented Jun 3, 2019

Stack from ghstack:

Resubmit of #20626

Effective Bandwidth Benchmark

Float Type

Before:

bernoulli, size, elements 65536 forward 5.810260772705078e-06 bandwidth (GB/s) 45.117424200902754
bernoulli, size, elements 131072 forward 5.700588226318359e-06 bandwidth (GB/s) 91.97085970522794
bernoulli, size, elements 262144 forward 7.650852203369141e-06 bandwidth (GB/s) 137.0534905298847
bernoulli, size, elements 524288 forward 1.1038780212402343e-05 bandwidth (GB/s) 189.98041084682507
bernoulli, size, elements 1048576 forward 1.817464828491211e-05 bandwidth (GB/s) 230.77772588765578
bernoulli, size, elements 2097152 forward 3.152847290039063e-05 bandwidth (GB/s) 266.06451972800966
bernoulli, size, elements 4194304 forward 5.8722496032714846e-05 bandwidth (GB/s) 285.7033868358262
bernoulli, size, elements 8388608 forward 0.0001120924949645996 bandwidth (GB/s) 299.3459286511284
bernoulli, size, elements 16777216 forward 0.0002196049690246582 bandwidth (GB/s) 305.58900510336235
bernoulli, size, elements 33554432 forward 0.0004137754440307617 bandwidth (GB/s) 324.3733525907877

After:

bernoulli, size, elements 65536 forward 5.7387351989746094e-06 bandwidth (GB/s) 45.679751881013715
bernoulli, size, elements 131072 forward 5.600452423095703e-06 bandwidth (GB/s) 93.61529397837378
bernoulli, size, elements 262144 forward 6.201267242431641e-06 bandwidth (GB/s) 169.09060019623223
bernoulli, size, elements 524288 forward 6.272792816162109e-06 bandwidth (GB/s) 334.3250863629039
bernoulli, size, elements 1048576 forward 8.275508880615235e-06 bandwidth (GB/s) 506.83336342310577
bernoulli, size, elements 2097152 forward 1.2857913970947266e-05 bandwidth (GB/s) 652.4081603714445
bernoulli, size, elements 4194304 forward 2.348184585571289e-05 bandwidth (GB/s) 714.4760298270282
bernoulli, size, elements 8388608 forward 4.356622695922851e-05 bandwidth (GB/s) 770.1936647257047
bernoulli, size, elements 16777216 forward 8.656024932861328e-05 bandwidth (GB/s) 775.2850126994326
bernoulli, size, elements 33554432 forward 0.0001675891876220703 bandwidth (GB/s) 800.8734328534002

Double Type

Before:

bernoulli, size, elements 65536 forward 5.733966827392578e-06 bandwidth (GB/s) 45.717739200665285
bernoulli, size, elements 131072 forward 6.6208839416503905e-06 bandwidth (GB/s) 79.18700956254952
bernoulli, size, elements 262144 forward 1.0859966278076171e-05 bandwidth (GB/s) 96.55425929975851
bernoulli, size, elements 524288 forward 1.7333030700683594e-05 bandwidth (GB/s) 120.99165092445668
bernoulli, size, elements 1048576 forward 3.1557083129882816e-05 bandwidth (GB/s) 132.91165038090057
bernoulli, size, elements 2097152 forward 5.902767181396485e-05 bandwidth (GB/s) 142.11314358523305
bernoulli, size, elements 4194304 forward 0.00011337995529174805 bandwidth (GB/s) 147.9733869785806
bernoulli, size, elements 8388608 forward 0.00022054195404052734 bandwidth (GB/s) 152.14534643070206
bernoulli, size, elements 16777216 forward 0.0004380941390991211 bandwidth (GB/s) 153.18366079491483
bernoulli, size, elements 33554432 forward 0.0008704972267150879 bandwidth (GB/s) 154.1851299245198

After:

bernoulli, size, elements 65536 forward 5.877017974853515e-06 bandwidth (GB/s) 44.60493418969575
bernoulli, size, elements 131072 forward 5.819797515869141e-06 bandwidth (GB/s) 90.08698302138468
bernoulli, size, elements 262144 forward 6.091594696044922e-06 bandwidth (GB/s) 172.1348928025049
bernoulli, size, elements 524288 forward 8.232593536376953e-06 bandwidth (GB/s) 254.73770698546193
bernoulli, size, elements 1048576 forward 1.3000965118408203e-05 bandwidth (GB/s) 322.6148183461581
bernoulli, size, elements 2097152 forward 2.2871494293212892e-05 bandwidth (GB/s) 366.7713133413114
bernoulli, size, elements 4194304 forward 4.316329956054687e-05 bandwidth (GB/s) 388.69169342501107
bernoulli, size, elements 8388608 forward 8.46099853515625e-05 bandwidth (GB/s) 396.5776835981966
bernoulli, size, elements 16777216 forward 0.00016601085662841796 bandwidth (GB/s) 404.2438269577137
bernoulli, size, elements 33554432 forward 0.00031869888305664063 bandwidth (GB/s) 421.14276244936264

Differential Revision: D15632935

Speedup bernoulli_scalar_cuda_kernel with grid-stride loop

gh-metadata: pytorch pytorch 21300 gh/syed-ahmed/12/head
Speedup bernoulli_scalar_cuda_kernel with grid-stride loop

gh-metadata: pytorch pytorch 21300 gh/syed-ahmed/12/head
Speedup bernoulli_scalar_cuda_kernel with grid-stride loop

gh-metadata: pytorch pytorch 21300 gh/syed-ahmed/12/head
@zou3519 zou3519 deleted the gh/syed-ahmed/12/head branch June 5, 2019 02:16
@facebook-github-bot
Copy link
Contributor

@ezyang merged this pull request in eadac84.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jun 5, 2019
Summary:
Pull Request resolved: pytorch/pytorch#21300
ghimport-source-id: c314c28cb693b554d6f24de235c11ba24ed6bf61

Reviewed By: jerryzh168

Differential Revision: D15632935

Pulled By: ezyang

fbshipit-source-id: 9bb24f17d78151bf50942905c967bdcfe1ff00cb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: cuda Related to torch.cuda, and CUDA support in general open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants