Speedup bernoulli_scalar_cuda_kernel with grid-stride loop #20626
Conversation
After these PRs, will the CUDA RNGs be thread- and stream-safe in PyTorch?
    if (std::is_same<scalar_t, double>::value) {
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
So @ngimel told me that curand's uniform with double is actually a lie, because there is not enough precision. I don't know if things have changed or not.
How many bits of randomness does it provide? 53 bits is standard, so that it generates all the rationals x / 2^53 for x in {0, 1, ..., 2^53 - 1}.
That doesn't generate all valid doubles in [0, 1), but basically nobody does that outside of toy programs.
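Concretely, the standard 53-bit construction keeps the top 53 bits of a 64-bit random word and scales them down; a minimal sketch (not curand's internal code; the function name is illustrative):

```c++
#include <cstdint>

// Map a 64-bit random word to a double in [0, 1) with 53 bits of
// randomness: keep the top 53 bits and scale by 2^-53, producing the
// rationals x / 2^53 for x in {0, 1, ..., 2^53 - 1}.
inline double bits_to_double53(std::uint64_t r) {
  return (r >> 11) * 0x1.0p-53;  // 2^-53 as a hex float literal
}
```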
curand_uniform_double is a lie (uses 32 bits), curand_uniform2_double is not (uses 53 bits per value).
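For illustration, a minimal CUDA sketch contrasting the two calls on a Philox state (the kernel name and output layout are assumptions, not PyTorch code):

```cuda
#include <curand_kernel.h>

// Sketch: curand_uniform_double derives one double from a 32-bit draw,
// while curand_uniform2_double consumes 64 bits and returns two doubles
// with 53 bits of randomness each.
__global__ void uniform_double_demo(double* out, unsigned long long seed) {
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/0, /*offset=*/0, &state);
  out[0] = curand_uniform_double(&state);         // only 32 bits of entropy
  double2 pair = curand_uniform2_double(&state);  // 53 bits per component
  out[1] = pair.x;
  out[2] = pair.y;
}
```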
That is the intention. These PRs achieve thread and stream-safety by replacing curandStateMTGP with curandStatePhilox: #19508 (comment)
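For intuition, a sketch of why Philox gives that safety (assuming one state per thread; this is not PyTorch's actual generator code): Philox is counter-based, so a state is derived purely from (seed, subsequence, offset), and each thread builds its own with no shared mutable RNG state to race on.

```cuda
#include <curand_kernel.h>

// Sketch: each thread derives an independent Philox subsequence from its
// global index. Nothing is shared between threads or CUDA streams, which
// is what makes the generator thread- and stream-safe.
__global__ void per_thread_philox(float* out, long long n,
                                  unsigned long long seed,
                                  unsigned long long offset) {
  long long tid = blockIdx.x * (long long)blockDim.x + threadIdx.x;
  if (tid >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/tid, /*offset=*/offset, &state);
  out[tid] = curand_uniform(&state);
}
```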
I checked locally and got the following output: So maybe it fixed the issue? I didn't triage and find the source of that bug. Maybe the tensor iterator magic in the normal PR has fixed this.
Stack from ghstack:
Differential Revision: D15454046
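The core change is moving the Bernoulli kernel to a grid-stride loop, so a fixed-size grid covers inputs of any length with each thread processing multiple elements. A minimal sketch of the pattern (illustrative only; the kernel name and parameters are assumptions, not the PR's actual code):

```cuda
#include <curand_kernel.h>

// Sketch of a grid-stride loop: the grid size is capped, and each thread
// strides through the tensor, amortizing RNG state setup over many
// elements and keeping the GPU fully occupied for any input size.
__global__ void bernoulli_scalar_sketch(float* out, long long n, float p,
                                        unsigned long long seed) {
  long long idx = blockIdx.x * (long long)blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/idx, /*offset=*/0, &state);
  long long stride = (long long)gridDim.x * blockDim.x;
  for (long long i = idx; i < n; i += stride) {
    out[i] = curand_uniform(&state) < p ? 1.0f : 0.0f;
  }
}
```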
Effective Bandwidth Benchmark
[Before/after effective-bandwidth plots for the float and double types]