@syed-ahmed (Collaborator) commented on May 17, 2019

Stack from ghstack:

Differential Revision: D15454046

Effective Bandwidth Benchmark

Float Type

Before:

| elements | bernoulli forward (s) | bandwidth (GB/s) |
| --- | --- | --- |
| 65536 | 5.810260772705078e-06 | 45.117424200902754 |
| 131072 | 5.700588226318359e-06 | 91.97085970522794 |
| 262144 | 7.650852203369141e-06 | 137.0534905298847 |
| 524288 | 1.1038780212402343e-05 | 189.98041084682507 |
| 1048576 | 1.817464828491211e-05 | 230.77772588765578 |
| 2097152 | 3.152847290039063e-05 | 266.06451972800966 |
| 4194304 | 5.8722496032714846e-05 | 285.7033868358262 |
| 8388608 | 0.0001120924949645996 | 299.3459286511284 |
| 16777216 | 0.0002196049690246582 | 305.58900510336235 |
| 33554432 | 0.0004137754440307617 | 324.3733525907877 |

After:

| elements | bernoulli forward (s) | bandwidth (GB/s) |
| --- | --- | --- |
| 65536 | 5.7387351989746094e-06 | 45.679751881013715 |
| 131072 | 5.600452423095703e-06 | 93.61529397837378 |
| 262144 | 6.201267242431641e-06 | 169.09060019623223 |
| 524288 | 6.272792816162109e-06 | 334.3250863629039 |
| 1048576 | 8.275508880615235e-06 | 506.83336342310577 |
| 2097152 | 1.2857913970947266e-05 | 652.4081603714445 |
| 4194304 | 2.348184585571289e-05 | 714.4760298270282 |
| 8388608 | 4.356622695922851e-05 | 770.1936647257047 |
| 16777216 | 8.656024932861328e-05 | 775.2850126994326 |
| 33554432 | 0.0001675891876220703 | 800.8734328534002 |

Double Type

Before:

| elements | bernoulli forward (s) | bandwidth (GB/s) |
| --- | --- | --- |
| 65536 | 5.733966827392578e-06 | 45.717739200665285 |
| 131072 | 6.6208839416503905e-06 | 79.18700956254952 |
| 262144 | 1.0859966278076171e-05 | 96.55425929975851 |
| 524288 | 1.7333030700683594e-05 | 120.99165092445668 |
| 1048576 | 3.1557083129882816e-05 | 132.91165038090057 |
| 2097152 | 5.902767181396485e-05 | 142.11314358523305 |
| 4194304 | 0.00011337995529174805 | 147.9733869785806 |
| 8388608 | 0.00022054195404052734 | 152.14534643070206 |
| 16777216 | 0.0004380941390991211 | 153.18366079491483 |
| 33554432 | 0.0008704972267150879 | 154.1851299245198 |

After:

| elements | bernoulli forward (s) | bandwidth (GB/s) |
| --- | --- | --- |
| 65536 | 5.877017974853515e-06 | 44.60493418969575 |
| 131072 | 5.819797515869141e-06 | 90.08698302138468 |
| 262144 | 6.091594696044922e-06 | 172.1348928025049 |
| 524288 | 8.232593536376953e-06 | 254.73770698546193 |
| 1048576 | 1.3000965118408203e-05 | 322.6148183461581 |
| 2097152 | 2.2871494293212892e-05 | 366.7713133413114 |
| 4194304 | 4.316329956054687e-05 | 388.69169342501107 |
| 8388608 | 8.46099853515625e-05 | 396.5776835981966 |
| 16777216 | 0.00016601085662841796 | 404.2438269577137 |
| 33554432 | 0.00031869888305664063 | 421.14276244936264 |
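
For context, numbers like the ones above can be produced with a simple effective-bandwidth measurement along the following lines. This is only an illustrative sketch, not the actual benchmark script used for the tables: the probability, warm-up, and iteration counts are placeholders, timing uses CUDA events, and it assumes 4 bytes of traffic per element (which is what the reported GB/s figures appear to use; the real script may account for bytes differently).

```python
import torch

def effective_bandwidth(numel, dtype, iters=100, bytes_per_elem=4):
    # Illustrative sketch (not the original benchmark script): time an
    # in-place bernoulli_ on a CUDA tensor and report GB/s, assuming
    # `bytes_per_elem` bytes of memory traffic per element.
    x = torch.empty(numel, dtype=dtype, device='cuda')
    # Warm up so one-time launch/allocator overhead is excluded.
    for _ in range(10):
        x.bernoulli_(0.5)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        x.bernoulli_(0.5)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000.0 / iters  # elapsed_time is in ms
    return seconds, numel * bytes_per_elem / seconds / 1e9

for numel in [2 ** p for p in range(16, 26)]:  # 65536 .. 33554432 elements
    t, bw = effective_bandwidth(numel, torch.float)
    print(f'bernoulli, size, elements {numel} forward {t} bandwidth (GB/s) {bw}')
```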

@colesbury (Member):
After these PRs, will the CUDA RNGs be thread and stream-safe in PyTorch?

Does the THCTensor_(normal) PR happen to fix #17898?

An inline review thread on the following diff hunk:

if (std::is_same<scalar_t, double>::value) {
  distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
    gen,
    [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
Collaborator:

So, @ngimel told me that curand's uniform for doubles is actually a lie because there is not enough precision. I don't know if things have changed or not.

Member:

How many bits of randomness does it provide? 53 bits is standard, so that it generates all the rationals x / 2^53 for x in {0, 1, ..., 2^53 - 1}.

That doesn't generate all valid doubles in [0, 1), but basically nobody does that outside of toy programs.

Collaborator:

curand_uniform_double is a lie (uses 32 bits), curand_uniform2_double is not (uses 53 bits per value).
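
To make the bit-count difference concrete, here is a small host-side illustration using Python's `random` module (an illustrative sketch of the granularity only, not curand's actual mapping): a double derived from one 32-bit draw can only land on multiples of 2^-32, while one built from 53 bits covers the standard x / 2^53 grid described above.

```python
import random

rng = random.Random(0)

# Sketch of the granularity difference (not curand's exact mapping):
# one 32-bit draw can only produce multiples of 2**-32, whereas 53 random
# bits cover the standard x / 2**53 grid.
u32 = rng.getrandbits(32) / 2.0 ** 32
u53 = rng.getrandbits(53) / 2.0 ** 53

print(u32 * 2 ** 32 == int(u32 * 2 ** 32))  # True: u32 sits on the coarse 2**-32 grid
print(u53 * 2 ** 32 == int(u53 * 2 ** 32))  # almost surely False: u53 falls between those points
print(u53 * 2 ** 53 == int(u53 * 2 ** 53))  # True: u53 sits on the 2**-53 grid
```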

Speedup bernoulli_scalar_cuda_kernel with grid-stride loop

gh-metadata: pytorch pytorch 20626 gh/syed-ahmed/7/head
@syed-ahmed (Collaborator, Author) commented on May 21, 2019

@colesbury

> After these PRs, will the CUDA RNGs be thread and stream-safe in PyTorch?

That is the intention. These PRs achieve thread and stream-safety by replacing curandStateMTGP with curandStatePhilox: #19508 (comment)
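
For intuition about why a counter-based generator helps: with Philox, a kernel launch only needs an immutable (seed, offset) snapshot, and the host advances the offset once per launch, so concurrent threads or streams never mutate shared device-side RNG state. The toy sketch below illustrates that idea in plain Python; the class and method names and the lock-based reservation are illustrative only, not PyTorch's actual implementation.

```python
import threading

class CounterBasedGenerator:
    """Toy model of a Philox-style counter-based RNG state (illustrative only)."""

    def __init__(self, seed):
        self.seed = seed
        self.offset = 0
        self._lock = threading.Lock()

    def reserve(self, increment):
        # Each "launch" grabs an immutable (seed, offset) snapshot and bumps the
        # shared offset exactly once; nothing is mutated afterwards, so launches
        # on different streams/threads cannot race on generator state.
        with self._lock:
            snapshot = (self.seed, self.offset)
            self.offset += increment
        return snapshot

gen = CounterBasedGenerator(seed=42)
# Two launches (e.g. on different streams) get disjoint counter ranges.
print(gen.reserve(increment=1024))  # (42, 0)
print(gen.reserve(increment=1024))  # (42, 1024)
```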

> Does the THCTensor_(normal) PR happen to fix #17898?

I checked locally and got the following output:

Python 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34) 
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> x = torch.randn(65536, 32768, device='cuda')
>>> print(x)
tensor([[ 0.1168,  0.3966, -0.6651,  ..., -0.0352,  1.3890,  1.6474],
        [ 1.2551, -0.3203, -1.4046,  ..., -1.4736, -0.3075,  0.5145],
        [ 0.7321,  1.5310, -1.1256,  ..., -0.3323,  0.4313, -0.5451],
        ...,
        [ 0.1096, -0.4399, -0.3016,  ...,  0.5718,  0.7515, -0.1184],
        [-0.2639, -0.0976,  0.2093,  ...,  0.5791, -1.0212, -0.0307],
        [-0.2508, -0.0478,  1.4468,  ..., -0.5606,  0.5161,  2.1986]],
       device='cuda:0')

So maybe it fixed the issue? I didn't triage and find the source of that bug. Maybe the tensor iterator magic in the normal PR has fixed this.
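
A quick way to sanity-check that case beyond eyeballing the printout is to look at the sample statistics of the large tensor. This is just an illustrative follow-up check, not part of this PR, and the float32 tensor below needs roughly 8 GB of GPU memory.

```python
import torch

# Hypothetical follow-up check for the large-tensor case in #17898:
# the sample statistics of a huge standard-normal tensor should look
# like N(0, 1) rather than the degenerate output reported in the issue.
x = torch.randn(65536, 32768, device='cuda')
print(x.mean().item())       # expect a value very close to 0
print(x.std().item())        # expect a value very close to 1
print(x.abs().max().item())  # expect a few standard deviations, not 0 or inf
```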

@syed-ahmed requested a review from ezyang on May 29, 2019, 20:10