
Conversation

@syed-ahmed (Collaborator) commented May 23, 2019

Stack from ghstack:

Differential Revision: D15535503

Summary:

This PR removes curandStateMTGP32 usages since it is not stream-safe.
Main changes are:

  • It modifies THCTensor_(getRNGState) and THCTensor_(setRNGState) so they no longer read/write curandStateMTGP32.
  • It modifies RRelu.cu and the CUDA multinomial kernels to use curandStatePhilox (see the sketch after this list).
  • It deletes the new_state.clone() call from torch.cuda.random.py to get a performance boost.
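
For context, here is a minimal sketch of the per-thread Philox pattern this PR moves to. The kernel name, arguments, and launch shape are illustrative assumptions, not the PR's actual code:

#include <cstdint>
#include <curand_kernel.h>

// Each thread owns its own curandStatePhilox4_32_10_t, so no block-wide MTGP32
// state is shared and every thread can draw randoms independently.
__global__ void philox_uniform_kernel(float* out, int64_t n,
                                      uint64_t seed, uint64_t offset) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, idx, offset, &state);  // seed, subsequence (= thread id), offset
  for (int64_t i = idx; i < n; i += (int64_t)gridDim.x * blockDim.x) {
    float4 rand = curand_uniform4(&state); // one call yields 4 uniforms
    out[i] = rand.x;                       // only the first is consumed here
  }
}

Because each curand_uniform4 call consumes four Philox values, the host-side offset bookkeeping advances in multiples of 4, which is what the offset discussion in the review below is about.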

@ngimel (Collaborator) left a comment

Good riddance, mtgp!

// search due to divergence. It seems possible to compute multiple
// values and limit divergence though later on. However, no matter
// what, all block threads must participate in the curand_uniform
// call to update the generator state.
Collaborator

This comment is no longer valid (without MTGP, individual threads can participate in the RNG call).

// The warp determines the sample
int sample = sampleBase + threadIdx.y;

// All threads participate in this
Collaborator

Another invalid comment

// each thread will utilize one random, however, since we have to use
// curand_uniform4 (See Note [Register spilling in curand call for CUDA < 10]),
// offset is 4.
uint64_t offset = gen->state.philox_seed_offset.fetch_add(4);
Collaborator

Note that NUM_BLOCKS will be set to 64 in most cases (that's a poor choice, but that's for the next PR), so you'll have a grid-stride loop inside the kernel and generate multiple randoms per thread; adjust the offset accordingly.

Collaborator Author

Updated the offset calc to (numel / (block_size * grid.x)) * 4.
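
A hedged sketch of the host-side bookkeeping this refers to (variable names such as numel, block_size, and grid_x are illustrative, not the PR's exact code): each thread makes ceil(numel / (block_size * grid_x)) trips through the grid-stride loop, and each trip consumes one curand_uniform4 call, i.e. 4 Philox values.

// Reserve enough of the Philox sequence for every value this launch may draw.
uint64_t threads = (uint64_t)block_size * grid_x;
uint64_t loops_per_thread = (numel + threads - 1) / threads;  // ceil division
uint64_t counter_offset = loops_per_thread * 4;               // 4 values per curand_uniform4
uint64_t offset = gen->state.philox_seed_offset.fetch_add(counter_offset);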

@ezyang (Contributor) commented May 29, 2019

@syed-ahmed could you rebase this stack on master? (I can do it myself, but if I do you'll have to force update your own local branch pointer--let me know if you'd prefer me to do it)

@syed-ahmed (Collaborator Author)

@ezyang rebased :).

@syed-ahmed syed-ahmed requested a review from ezyang May 29, 2019 18:23
@ezyang (Contributor) commented May 29, 2019

Sorry, you rebased on top of broken master. Once the breakage is reverted we'll need another rebase :/

@ezyang (Contributor) commented May 29, 2019

A little more text in the PR description would have been appreciated for this poor reviewer ^^

THArgCheck(THByteTensor_nElement(rng_state) == total_size, 1, "RNG state is wrong size");
THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");
THCudaCheck(cudaMemcpy(THByteTensor_data(rng_state), gen->state.gen_states,
states_size, cudaMemcpyDeviceToHost));
Contributor

It might be a good idea to fill in this memory with deterministic garbage, so that if someone tries to (improperly) use it, the failure won't be a random error.

Collaborator Author

Filled in the memory with -1 and verified locally that torch.cuda.get_rng_state() gives 255 in the first few elements.
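
For context, a minimal sketch of what the fill could look like in THCTensor_(getRNGState) (assumed, not the exact diff): the legacy gen_states region of the returned byte tensor is set to 0xFF so anything that still reads it sees deterministic garbage.

// rng_state is a host THByteTensor; states_size is the byte count of the legacy
// MTGP32 region that is no longer copied from the device.
memset(THByteTensor_data(rng_state), -1, states_size);  // every byte becomes 0xFF (255)
// torch.cuda.get_rng_state() then reports 255 in those leading bytes, as verified above.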

THArgCheck(THByteTensor_isContiguous(rng_state), 1, "RNG state must be contiguous");

THCudaCheck(cudaMemcpy(gen->state.gen_states, THByteTensor_data(rng_state),
states_size, cudaMemcpyHostToDevice));
Contributor

Ditto here

Collaborator Author

Is this necessary? Since I set all the gen_states memory to -1 in getRNGState, this function simply won't touch that value. If I were to do a cudaMemcpy or memset here, I would need to allocate gen_states (which I deleted, i.e. the initializeGenerator function).

Contributor

You're right, please don't do that :) This can be kept as is (I just saw something that looked similar to the previous pattern.)

template <typename T>
__global__ void
-sampleMultinomialWithReplacement(curandStateMtgp32* state,
+sampleMultinomialWithReplacement(std::pair<uint64_t, uint64_t> seeds,
Contributor

To be fair, the second element of this pair isn't really a seed, it's an offset, right?

Collaborator Author

That's true. A little hand-wavy here, I agree. But you could interpret it as: since the seed decides where an RNG sequence starts from, the offset just gives finer control over that for the Philox engine. So "seed" for Philox could be an umbrella term for the actual seed value + offset 🤷‍♂️. If you want I can change the name (seed_and_offset maybe?), but then we'd have to change the variable name everywhere and use it like seed_and_offset.first, seed_and_offset.second.
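
For illustration, this is how such a (seed, offset) pair is typically consumed on the device with cuRAND's Philox engine; a sketch using the names from the diff above, not necessarily the PR's exact code:

curandStatePhilox4_32_10_t state;
curand_init(seeds.first,                            // seed: selects the random sequence
            blockIdx.x * blockDim.x + threadIdx.x,  // subsequence: one per thread
            seeds.second,                           // offset: skip values already handed out
            &state);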

Contributor

OK, if you like it, let's keep it :)

@ezyang (Contributor) commented May 29, 2019

How can I tell if the offset calculations were done right? Do tests cover this at all? It seems very fiddly.

@syed-ahmed (Collaborator Author)

> How can I tell if the offset calculations were done right? Do tests cover this at all? It seems very fiddly.

The philox offset calculation for RRelu.cu should be good, since it runs the exact same way as the kernels tested in cuda_distributions_test.cu. I can add a test for multinomial.

@pytorchbot pytorchbot added the module: internals Related to internal abstractions in c10 and ATen label May 29, 2019
@ezyang (Contributor) commented May 30, 2019

Sorry, we need another rebase; master was a disaster yesterday.

@syed-ahmed syed-ahmed requested a review from ezyang May 30, 2019 16:12