Remove curandStateMTGP32 usage #20886
Conversation
ngimel left a comment:
Good riddance, mtgp!
aten/src/THC/THCTensorRandom.cuh (Outdated)
```cpp
// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
```
This comment is no longer valid (w/o mtgp, individual threads can participate in rng call)
```cpp
// The warp determines the sample
int sample = sampleBase + threadIdx.y;
// ...
// All threads participate in this
```
Another invalid comment
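To illustrate the point (a minimal hypothetical kernel, not code from this PR): with the counter-based Philox generator each thread derives an independent state from (seed, subsequence, offset), so diverged threads can draw numbers freely, whereas MTGP32 kept one shared state per block and forced every thread to join each generator call.

```cuda
#include <cstdint>
#include <curand_kernel.h>

// Hypothetical illustration, not code from this PR: per-thread Philox state.
__global__ void per_thread_rng(float* out, int n, uint64_t seed, uint64_t offset) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  // Each thread owns its own subsequence; no block-wide coordination is needed.
  curand_init(seed, idx, offset, &state);
  if (idx < n) {
    // Safe under divergence: only this thread's counter advances.
    out[idx] = curand_uniform(&state);
  }
}
```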
aten/src/THCUNN/generic/RReLU.cu (Outdated)
```cpp
// each thread will utilize one random, however, since we have to use
// curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
// offset is 4.
uint64_t offset = gen->state.philox_seed_offset.fetch_add(4);
```
Note that NUM_BLOCKS will in most cases be set to 64 (a poor choice, but that's for the next PR), so you'll have a grid-stride loop inside the kernel generating multiple randoms per thread; adjust the offset accordingly.
Updated the offset calculation to (numel / (block_size * grid.x)) * 4.
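A rough sketch of that bookkeeping (illustrative only, not the exact PR code; block_size, grid and numel are the quantities discussed above, and the rounding up is an assumption for safety):

```cpp
// Illustrative sketch following the discussion above, not the exact PR code.
// In a grid-stride loop each thread handles about numel / (block_size * grid.x)
// elements, and every curand_uniform4 call consumes 4 values from the Philox
// counter, so the shared offset must be reserved accordingly (rounded up here
// as a safety assumption):
uint64_t iters_per_thread =
    (numel + block_size * grid.x - 1) / (block_size * grid.x);
uint64_t offset = gen->state.philox_seed_offset.fetch_add(iters_per_thread * 4);
```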
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
@syed-ahmed could you rebase this stack on master? (I can do it myself, but if I do you'll have to force-update your own local branch pointer; let me know if you'd prefer me to do it)
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
@ezyang rebased :).
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Sorry, you rebased on top of broken master. Once the breakage is reverted we'll need another rebase :/

A little more text in the PR description would have been appreciated for this poor reviewer ^^
```cpp
THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->state.gen_states,
                       states_size, cudaMemcpyDeviceToHost));
```
It might be a good idea to fill in this memory with deterministic garbage, so that if someone tries to use it (improperly) the failure won't be random.
Filled in the memory with -1 and verified locally that torch.cuda.get_rng_state() gives 255 in the first few elements.
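For concreteness, a sketch of what that fill could look like (illustrative fragment, not the exact PR code; rng_state and states_size are from the hunk above):

```cpp
// Illustrative fragment, not the exact PR code: fill the legacy gen_states
// slot of the serialized RNG state with a deterministic pattern (-1 byte-fills
// as 0xFF) instead of copying from device memory that no longer exists.
memset(THByteTensor_data(rng_state), -1, states_size);
// torch.cuda.get_rng_state() then reads those leading bytes back as 255.
```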
```cpp
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
// ...
THCudaCheck(cudaMemcpy(gen->state.gen_states, THByteTensor_data(rng_state),
                       states_size, cudaMemcpyHostToDevice));
```
Ditto here
Is this necessary? Since I set all the gen_states memory to -1 in getRNGState, this function just won't touch that value. If I were to do a cudaMemcpy or memset here, I'd need to allocate gen_states (which I deleted, i.e. the initializeGenerator function).
You're right, please don't do that :) This can be kept as is (I just saw something that looked similar to the previous pattern.)
```diff
 template <typename T>
 __global__ void
-sampleMultinomialWithReplacement(curandStateMtgp32* state,
+sampleMultinomialWithReplacement(std::pair<uint64_t, uint64_t> seeds,
```
To be fair, the second element of this pair isn't really a seed, it's an offset, right?
That's true, it's a little hand-wavy, I agree. But you could interpret it as: since the seed decides where an rng sequence starts, the offset just gives finer control over that starting point for the Philox engine. So "seed" for Philox could be an umbrella term for the actual seed value plus the offset 🤷♂️. If you want I can change the name (seed_and_offset maybe?), but then we'd have to change the variable name everywhere and use it like seed_and_offset.first, seed_and_offset.second.
OK, if you like it, let's keep it :)
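For reference, a hypothetical sketch of how a kernel consumes such a pair (mirroring the signature in the hunk above, but not the PR's actual kernel body): .first seeds the Philox engine, .second fast-forwards its counter so consecutive launches don't replay the same values.

```cuda
#include <cstdint>
#include <utility>
#include <curand_kernel.h>

// Hypothetical kernel mirroring the signature above, not the PR's actual body.
template <typename T>
__global__ void sampleKernel(std::pair<uint64_t, uint64_t> seeds, T* out, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seeds.first,   // seed: selects the Philox sequence
              idx,           // subsequence: one per thread
              seeds.second,  // offset: how far into the sequence to start
              &state);
  if (idx < n) {
    float4 r = curand_uniform4(&state);  // each call advances the counter by 4
    out[idx] = static_cast<T>(r.x);
  }
}
```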
How can I tell if the offset calculations were done right? Do tests cover this at all? It seems very fiddly.

The philox offset calculation for RReLU.cu should be good, since it runs the exact same way as the kernels tested in …
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Sorry, we need another rebase; master was a disaster yesterday.
Remove curandStateMTGP32 usage gh-metadata: pytorch pytorch 20886 gh/syed-ahmed/8/head
Stack from ghstack:
Differential Revision: D15535503
Summary:
This PR removes curandStateMTGP32 usage, since MTGP32 is not stream-safe.
Main changes are: