
@syed-ahmed (Collaborator) commented May 17, 2019

Stack from ghstack:

## Effective Bandwidth Benchmark

### Float Type

#### Before:

```
normal, size, elements 65536 forward 4.956722259521484e-06 bandwidth (GB/s) 52.88656218258779
normal, size, elements 131072 forward 5.285739898681641e-06 bandwidth (GB/s) 99.18914098114568
normal, size, elements 262144 forward 7.548332214355469e-06 bandwidth (GB/s) 138.91492454529376
normal, size, elements 524288 forward 1.1980533599853516e-05 bandwidth (GB/s) 175.0466273076219
normal, size, elements 1048576 forward 2.091646194458008e-05 bandwidth (GB/s) 200.52645667862762
normal, size, elements 2097152 forward 3.9961338043212894e-05 bandwidth (GB/s) 209.91809610901498
normal, size, elements 4194304 forward 7.39765167236328e-05 bandwidth (GB/s) 226.79110538115253
normal, size, elements 8388608 forward 0.0001377725601196289 bandwidth (GB/s) 243.5494555001696
normal, size, elements 16777216 forward 0.0002710080146789551 bandwidth (GB/s) 247.62686107087774
normal, size, elements 33554432 forward 0.0005375170707702637 bandwidth (GB/s) 249.69947058177252
```

#### After:

```
normal, size, elements 65536 forward 6.198883056640625e-06 bandwidth (GB/s) 42.288908760615385
normal, size, elements 131072 forward 6.756782531738281e-06 bandwidth (GB/s) 77.59432800112916
normal, size, elements 262144 forward 7.560253143310547e-06 bandwidth (GB/s) 138.6958849291706
normal, size, elements 524288 forward 7.550716400146485e-06 bandwidth (GB/s) 277.7421225831386
normal, size, elements 1048576 forward 1.1034011840820313e-05 bandwidth (GB/s) 380.1250225673293
normal, size, elements 2097152 forward 1.802682876586914e-05 bandwidth (GB/s) 465.34019427102237
normal, size, elements 4194304 forward 2.8417110443115234e-05 bandwidth (GB/s) 590.3913430460946
normal, size, elements 8388608 forward 4.8711299896240235e-05 bandwidth (GB/s) 688.8428777608927
normal, size, elements 16777216 forward 9.685993194580078e-05 bandwidth (GB/s) 692.8444265018856
normal, size, elements 33554432 forward 0.00018213510513305663 bandwidth (GB/s) 736.9130069787966
```

### Double Type

#### Before:

```
normal, size, elements 65536 forward 5.8841705322265624e-06 bandwidth (GB/s) 44.55071425348461
normal, size, elements 131072 forward 8.018016815185547e-06 bandwidth (GB/s) 65.38873789925661
normal, size, elements 262144 forward 1.2989044189453124e-05 bandwidth (GB/s) 80.72772597474304
normal, size, elements 524288 forward 2.2075176239013673e-05 bandwidth (GB/s) 95.00046465285668
normal, size, elements 1048576 forward 4.1041374206542965e-05 bandwidth (GB/s) 102.19696784254678
normal, size, elements 2097152 forward 7.57598876953125e-05 bandwidth (GB/s) 110.72624650312186
normal, size, elements 4194304 forward 0.00013725996017456056 bandwidth (GB/s) 122.22949779865557
normal, size, elements 8388608 forward 0.0002614736557006836 bandwidth (GB/s) 128.32815569921402
normal, size, elements 16777216 forward 0.0005080199241638184 bandwidth (GB/s) 132.0988819689674
normal, size, elements 33554432 forward 0.0009479570388793945 bandwidth (GB/s) 141.58629821311564
```

#### After:

```
normal, size, elements 65536 forward 5.991458892822265e-06 bandwidth (GB/s) 43.75294977222444
normal, size, elements 131072 forward 7.293224334716797e-06 bandwidth (GB/s) 71.88699756626349
normal, size, elements 262144 forward 8.094310760498048e-06 bandwidth (GB/s) 129.54481623281296
normal, size, elements 524288 forward 1.2805461883544922e-05 bandwidth (GB/s) 163.7701177100726
normal, size, elements 1048576 forward 2.2592544555664064e-05 bandwidth (GB/s) 185.64991604491345
normal, size, elements 2097152 forward 3.801822662353516e-05 bandwidth (GB/s) 220.6470092112881
normal, size, elements 4194304 forward 6.761550903320313e-05 bandwidth (GB/s) 248.1267425164457
normal, size, elements 8388608 forward 0.00013209104537963867 bandwidth (GB/s) 254.02503177684966
normal, size, elements 16777216 forward 0.0002667689323425293 bandwidth (GB/s) 251.56176699703818
normal, size, elements 33554432 forward 0.0004705166816711426 bandwidth (GB/s) 285.25604559501795
```

```cpp
normal_cuda_(output, 0, 1, gen);
at::mul_out(output, output, std);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "normal_out_cuda_cadd", [&] {
  auto ones = at::ones_like(output);
```
Collaborator:

In C++, you can use the add variant with a Scalar other: `- func: add(Tensor self, Scalar other, Scalar alpha=1) -> Tensor`. In any case, creating a full-size ones tensor is wasteful; you could broadcast a single-element ones tensor.
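
A minimal sketch of what the suggestion amounts to (the helper name is hypothetical, not from this PR): the Scalar overload shifts every element in one pointwise kernel, with no `ones_like` allocation.

```cpp
#include <ATen/ATen.h>

// Hedged sketch: rather than materializing a full-size ones tensor for a
// Tensor-Tensor add, use the Scalar overload
//   add(Tensor self, Scalar other, Scalar alpha=1)
// which launches one pointwise kernel and allocates nothing extra.
void shift_by_mean_sketch(at::Tensor& output, double mean) {
  output.add_(mean);  // in-place: output += mean
}
```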

Collaborator:

Also, I think addcmul would allow you to have a single pointwise operation here instead of separate mul and add.
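
A hedged sketch of the fusion being suggested (assuming the 2019 ATen out-argument-first signature): one addcmul_out pass replaces the separate mul and add kernels, roughly halving the memory traffic over output.

```cpp
#include <ATen/ATen.h>

// Hypothetical sketch: output <- mean + output * std in a single
// pointwise kernel, instead of a mul pass followed by an add pass.
void fused_scale_shift_sketch(at::Tensor& output, const at::Tensor& mean,
                              const at::Tensor& std) {
  at::addcmul_out(output, mean, output, std);  // value defaults to 1
}
```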

```cpp
auto std = std::get<0>(expand_inplace(output, std_.to(kCUDA)));
auto mean = std::get<0>(expand_inplace(output, mean_.to(kCUDA)));
normal_cuda_(output, 0, 1, gen);
at::mul_out(output, output, std);
```
Collaborator:

addcmul?

```cpp
  });
}

void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
```
Contributor:

OK, so this is the second time I'm reviewing the transform from GENERATE_KERNEL2 to this code, and the thing I find a bit remarkable is that the code has gotten a lot longer. Is this fundamental to how distribution_nullary_kernel works, or is there something we can factor out here to reduce how much code we have to write in all of these cases?

Collaborator (author):

I think the following is roughly how the new code maps onto the old:

- distribution_elementwise_grid_stride_kernel -> GENERATE_KERNEL2
- overloads of normal_cuda/normal_out_cuda -> dispatches in CUDAType::_th_normal_out in CUDAType.cpp
- *kernel_cuda -> THCTensor_(*) code in THCTensorRandom.cu

The added bit of code is distribution_nullary_kernel, which handles the TensorIterator plumbing. The original plan was to modify Loops.cuh to have an elementwise_kernel and launch_kernel specialized to run in grid-stride mode, but the added complexity of distributions (needing a transform lambda, the idx value for unrolling, and recursively calculating the philox offset at the beginning of distribution_nullary_kernel) made me copy bits from the gpu_nullary_kernel code.

As for the std::is_same checks: even though code generation with the macros looks like fewer lines of code, I am usually against macros because they are very unreadable; it took a lot of bookkeeping to get every type right just by reading the code. I do agree with you about using template specialization instead of the is_same checks, though. For now this just seemed the easy way. For reference, this is the old macro code:

```cpp
GENERATE_KERNEL2(generate_normal, float, double mean, double stdv, float, curand_normal, (x * stdv) + mean)
GENERATE_KERNEL2(generate_normal, double, double mean, double stdv, double, curand_normal_double, (x * stdv) + mean)

GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))

GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))

GENERATE_KERNEL2(generate_normal, at::Half, double mean, double stdv, float, curand_normal, (ScalarConvert<float, at::Half>::to((x * stdv) + mean)))
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))

#include <THC/generic/THCTensorRandom.cu>
#include <THC/THCGenerateAllTypes.h>

#include <THC/generic/THCTensorRandom.cu>
#include <THC/THCGenerateBoolType.h>

#undef GENERATE_KERNEL1
#undef GENERATE_KERNEL2
.
.
.
#if defined(THC_REAL_IS_DOUBLE)
GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, ceil(log(x) / log(1-p)))
#else
GENERATE_KERNEL1(generate_geometric, scalar_t, double p, float, curand_uniform, (ScalarConvert<float, scalar_t>::to(ceilf(logf(x) / log(1-p)))))
#endif
#if defined(THC_REAL_IS_LONG) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_FLOAT)
#define CURAND64(STATE) (((uint64_t)curand(STATE)) << 32) | (uint64_t)curand(STATE)
GENERATE_KERNEL2(generate_random, scalar_t, int32_t base, uint32_t range, uint32_t, curand, \
    static_cast<scalar_t>(static_cast<int32_t>((x % range) + base)))
GENERATE_KERNEL2(generate_random_64, scalar_t, int64_t base, uint64_t range, uint64_t, CURAND64, \
    static_cast<scalar_t>(static_cast<int64_t>((x % range) + base)))
#elif defined(THC_REAL_IS_HALF)
GENERATE_KERNEL2(generate_random, scalar_t, int32_t base, uint32_t range, uint32_t, curand,
    (ScalarConvert<int32_t, scalar_t>::to(static_cast<int32_t>(x % range + base))))
#else
GENERATE_KERNEL2(generate_random, scalar_t, int32_t base, uint32_t range, uint32_t, curand,
    static_cast<scalar_t>(static_cast<int32_t>(x % range + base)))
#endif
```
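
For readers unfamiliar with the pattern named above, here is a minimal, hedged sketch of a grid-stride loop; the real distribution_elementwise_grid_stride_kernel additionally threads curand state, unrolling, and a transform lambda through this loop.

```cpp
// Simplified grid-stride kernel: a fixed-size grid covers any n because
// each thread processes indices idx, idx + stride, idx + 2*stride, ...
// which keeps the per-thread philox offset bookkeeping simple.
__global__ void grid_stride_fill(float* out, int64_t n, float value) {
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  int64_t stride = (int64_t)gridDim.x * blockDim.x;
  for (int64_t i = idx; i < n; i += stride) {
    out[i] = value;  // a distribution kernel would write transform(curand(...))
  }
}
```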

Contributor:

👍 SGTM

```cpp
}

Tensor& normal_out_cuda(Tensor& output, const Tensor& mean_, double std, Generator* gen) {
  auto mean = std::get<0>(expand_inplace(output, mean_.to(kCUDA)));
```
Contributor:

This looks wrong. We never do implicit conversions from CPU tensors to CUDA tensors. Is there a reason you added this line here?

Contributor:

I don't see any indication that this happened in the old code either.

Collaborator (author):

I was just following how the bernoulli kernel was ported from THC to ATen. I don't really know if it's needed or not:

```cpp
Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
  auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
```

Collaborator (author):

Maybe we were doing the expand like this? `THCTensor_(resizeAs)(state, self, means);`

Collaborator (author):

Let me try removing them and see what happens.

Collaborator (author):

OK, looks like removing expand_inplace works fine. Curious why it was there in bernoulli, though. @ssnl

Collaborator:

Frankly speaking, I don't know. I copied that behavior when I moved bernoulli to ATen. However, maybe it could be related to broadcasting with torch.bernoulli?

Collaborator:

That's right: bernoulli_tensor_cuda_kernel cannot handle an unexpanded p; at::mul and at::add, however, can.
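
A small, hedged illustration of the distinction (shapes invented for the example): TensorIterator-backed ops broadcast on their own, so they need no explicit expand, while the raw bernoulli kernel indexes p element-for-element and does.

```cpp
#include <ATen/ATen.h>

// at::mul / at::add broadcast automatically; no expand_inplace required.
void broadcast_sketch() {
  at::Tensor out = at::randn({4, 3}, at::kCUDA);
  at::Tensor p = at::full({1}, 0.5, out.options());
  out.mul_(p);  // p is broadcast across out's 12 elements
}
```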

Collaborator (author):

Removed expand_inplaces.

```cpp
auto mean = std::get<0>(expand_inplace(output, mean_.to(kCUDA)));
normal_cuda_(output, 0, std, gen);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(output.scalar_type(), "normal_out_cuda_cadd", [&] {
  at::add_out(output, output, mean, static_cast<scalar_t>(1));
```
Contributor:

nit: If possible, it's better to use the in-place versions of these functions. I don't think it matters here, but the reason to prefer a true in-place op is that it guarantees aliasing is handled correctly (in principle, someone can implement an _out kernel incorrectly so that it breaks when the output aliases an input).
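
A hedged sketch of the two spellings (2019 out-argument-first signature assumed); both compute output = output + mean, but only the first carries the aliasing guarantee:

```cpp
#include <ATen/ATen.h>

void add_mean_variants(at::Tensor& output, const at::Tensor& mean) {
  // Preferred: a true in-place op, guaranteed to handle aliasing.
  output.add_(mean);
  // Same math, but output aliases an input, so correctness depends on the
  // _out kernel having been written to tolerate that:
  // at::add_out(output, output, mean);
}
```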

Collaborator (author):

Updated with in-place add_

```cpp
}

Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std_, Generator* gen) {
  auto std = std::get<0>(expand_inplace(output, std_.to(kCUDA)));
```
Contributor:

One thing I'm not sure about here is why the explicit expand is necessary. Don't functions like addcmul broadcast automatically?

Contributor:

Is the problem perhaps that you're calling the at::native function directly, and thus bypassing the logic that handles expansion?

Collaborator (author):

As above, I also don't know whether the explicit expand is necessary. By the way, for addcmul, this is what I want to do: `output = mean_tensor + output * std`. The schema `func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor` means addcmul returns a new tensor, so I could have done `mean_tensor.addcmul(output, std)`, but then I would have to copy mean_tensor into output to return it. addcmul_ does the same thing except in place, and addcmul_out was the only API I found that accumulates `mean_tensor + output * std` directly in output. That's why I used it. Any comments?

Collaborator (author):

Updated with addcmul/addcmul_.

@syed-ahmed syed-ahmed requested a review from ezyang May 22, 2019 22:41
@syed-ahmed (Collaborator, author):

@ezyang Added the review changes.

```cpp
Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std, Generator* gen) {
  normal_cuda_(output, 0, 1, gen);
  auto mean_tensor = at::full({1}, mean, output.options());
  output = mean_tensor.addcmul_(output, std);
```
Collaborator:

I'm not sure this is how *_out functions are supposed to work, because it reassigns output. @ezyang?

Contributor:

All in-place functions always return exactly the same reference they were passed. So what Syed has done here is harmless, albeit also not necessary.

Collaborator:

This does not fill the memory of output with the correct results: mean_tensor contains the correct results, and output is merely reassigned to it. It's the same pattern as below, so if the one below is wrong, this is wrong too. Am I missing something?
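
A hedged illustration of the point: at::Tensor is a reference-counted handle, so assigning to it rebinds the handle to new storage and never writes the original memory.

```cpp
#include <ATen/ATen.h>

void rebind_vs_fill_sketch() {
  at::Tensor output = at::zeros({3});
  at::Tensor alias = output;  // shares storage with output
  output = at::ones({3});     // rebinds output's handle only
  // alias still reads all zeros: the original storage was never filled.
  // An _out function must write through the handle instead, e.g. with
  // at::addcmul_out(output, mean, output, std).
}
```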

```cpp
    return static_cast<scalar_t>(rand * std + mean);
  };
  if (std::is_same<scalar_t, double>::value) {
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
```
Contributor:

At some point (doesn't have to be this PR), I'd really appreciate a proper documentation comment on this function :)

Contributor:

distribution_nullary_kernel, that is.


```cpp
Tensor& normal_out_cuda(Tensor& output, const Tensor& mean, const Tensor& std, Generator* gen) {
  normal_cuda_(output, 0, 1, gen);
  output = mean.addcmul(output, std);
```
Contributor:

This, however, looks wrong. You need to actually literally fill the memory of output with the result of running addcmul here, and this ain't gonna do it. addcmul_out is probably the right thing to use here.

@ezyang (Contributor) left a review:

Error in addcmul "inplace" invocation

@ezyang (Contributor) commented May 24, 2019 via email

@syed-ahmed (Collaborator, author):

@ezyang @ngimel Updated with addcmul_out.

@syed-ahmed syed-ahmed requested a review from ezyang May 24, 2019 18:51
@zou3519 zou3519 deleted the gh/syed-ahmed/2/head branch May 31, 2019 01:03
@facebook-github-bot:

@ezyang merged this pull request in 26d16ae.

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 31, 2019
…ddevs} to ATen (#20621)

Summary:
Pull Request resolved: pytorch/pytorch#20621
ghimport-source-id: f461d7f1eb6b5a8306dd8175cbb0a7fcc9f64c76

Differential Revision: D15454048

Pulled By: ezyang

fbshipit-source-id: 8bfc57bf015b85f57ed99a54176926386aab4e34
@ezyang (Contributor) commented May 31, 2019

Hmm, rather mysteriously, this commit seems to have broken Windows:

```
05:02:24 ======================================================================
05:02:24 FAIL: test_Conv2d_groups_nobias_v2 (__main__.TestNN)
05:02:24 ----------------------------------------------------------------------
05:02:24 Traceback (most recent call last):
05:02:24   File "test_nn.py", line 4804, in test_Conv2d_groups_nobias_v2
05:02:24     self.assertEqual(output, torch.cat([output1, output2], 1))
05:02:24   File "C:\Jenkins\workspace\pytorch-builds\pytorch-win-ws2016-cuda9-cudnn7-py3-test1\test\common_utils.py", line 503, in assertEqual
05:02:24     assertTensorsEqual(x, y)
05:02:24   File "C:\Jenkins\workspace\pytorch-builds\pytorch-win-ws2016-cuda9-cudnn7-py3-test1\test\common_utils.py", line 495, in assertTensorsEqual
05:02:24     self.assertLessEqual(max_err, prec, message)
05:02:24 AssertionError: tensor(1.5259e-05, device='cuda:0', dtype=torch.float16,
05:02:24        grad_fn=<MaxBackward1>) not less than or equal to 1e-05 :
```

@syed-ahmed syed-ahmed restored the gh/syed-ahmed/2/head branch May 31, 2019 20:24
@syed-ahmed syed-ahmed removed the merged label May 31, 2019
@syed-ahmed syed-ahmed reopened this May 31, 2019
@syed-ahmed syed-ahmed changed the base branch from gh/syed-ahmed/2/base to master May 31, 2019 21:21
@syed-ahmed syed-ahmed force-pushed the gh/syed-ahmed/2/head branch from 4bceadc to c22456e Compare May 31, 2019 21:42
@ngimel (Collaborator) commented May 31, 2019

This was likely a flaky test, and using different random inputs triggered its flakiness.

@syed-ahmed syed-ahmed force-pushed the gh/syed-ahmed/2/head branch from b6bbc74 to ba39ee4 Compare May 31, 2019 23:39
@pytorchbot pytorchbot added the module: nn Related to torch.nn label May 31, 2019
@syed-ahmed (Collaborator, author):

Checking if starting from a known seed helps this test on Windows.

@syed-ahmed (Collaborator, author) commented Jun 2, 2019

@ezyang This is ready to be landed again. I changed test_Conv2d_groups_nobias_v2 to use a manual_seed to tame the flakiness a bit. The ROCm failure seems irrelevant, as it's showing a segfault with the tb-nightly install. Also, please only land this one for now (note I am asking to merge gh/syed-ahmed/2/head into master). I botched the following PRs by doing something bad with ghstack. I'll restore them to how they were and ping you again for them.

@syed-ahmed syed-ahmed requested a review from ezyang June 2, 2019 02:13
@ezyang (Contributor) commented Jun 3, 2019

@syed-ahmed I think ghstack is choking because the PR was reverted but you've reused the same pull request to resubmit. I'm not exactly sure how to unwedge it, but as a start, could you try running ghstack again on this commit only?

@ezyang (Contributor) commented Jun 3, 2019

Actually, did you even use ghstack to push updates here? The branch pointer on this PR is messed up.

@syed-ahmed (Collaborator, author):

I did a series of unfortunate things. I didn't realize gh/syed-ahmed/2/orig was deleted. I just restored the head and base of the normal and Cauchy PRs and then ran ghstack, and that pushed commits from exponential onward to the end of the stack.

Now I have just restored the reverted normal PR in gh/syed-ahmed/2/head and changed the base of this PR to master, so that we can let this one be merged the non-stack way. I'll do the same thing for the Cauchy PR. I'll try to fix the stack starting from the exponential PR.

@ezyang (Contributor) commented Jun 3, 2019 via email

@ezyang (Contributor) commented Jun 3, 2019

Hmm, I have to make a fresh PR for this branch.

@ezyang ezyang closed this Jun 3, 2019
facebook-github-bot pushed a commit that referenced this pull request Jun 3, 2019
…ddevs} to ATen (#21287)

Summary:
## Effective Bandwidth Benchmark
- using https://gist.github.com/syed-ahmed/f8b7384d642f4bce484228b508b4bc68
- on V100
(Float- and double-type benchmark tables are identical to those in the PR description above.)

Resubmit of #20621
Pull Request resolved: #21287

Differential Revision: D15603695

Pulled By: ezyang

fbshipit-source-id: f8c5032678d503d45ac99fb1475a929df7c2b361
@syed-ahmed syed-ahmed deleted the gh/syed-ahmed/2/head branch June 3, 2019 16:47

Labels: Merged; module: cuda (Related to torch.cuda, and CUDA support in general); module: internals (Related to internal abstractions in c10 and ATen); module: nn (Related to torch.nn); open source
