@syed-ahmed syed-ahmed commented May 17, 2019

Stack from ghstack:

Differential Revision: D15454049

Effective Bandwidth Benchmark

Float Type

Before:

geometric, size, elements 65536 forward 4.827976226806641e-06 bandwidth (GB/s) 54.296870507456795
geometric, size, elements 131072 forward 5.9986114501953125e-06 bandwidth (GB/s) 87.40156023656598
geometric, size, elements 262144 forward 9.603500366210938e-06 bandwidth (GB/s) 109.18685479404171
geometric, size, elements 524288 forward 1.6007423400878906e-05 bandwidth (GB/s) 131.01121570163838
geometric, size, elements 1048576 forward 2.911090850830078e-05 bandwidth (GB/s) 144.08014778391484
geometric, size, elements 2097152 forward 5.525588989257812e-05 bandwidth (GB/s) 151.81382502947878
geometric, size, elements 4194304 forward 0.00010294198989868164 bandwidth (GB/s) 162.9773818877273
geometric, size, elements 8388608 forward 0.0001985597610473633 bandwidth (GB/s) 168.98908330170744
geometric, size, elements 16777216 forward 0.00038609743118286135 bandwidth (GB/s) 173.8132879941806
geometric, size, elements 33554432 forward 0.0007671475410461426 bandwidth (GB/s) 174.9568639912085

After:

geometric, size, elements 65536 forward 5.98907470703125e-06 bandwidth (GB/s) 43.7703673477707
geometric, size, elements 131072 forward 5.676746368408203e-06 bandwidth (GB/s) 92.3571295905922
geometric, size, elements 262144 forward 6.127357482910156e-06 bandwidth (GB/s) 171.13021443984437
geometric, size, elements 524288 forward 7.076263427734375e-06 bandwidth (GB/s) 296.3643201552561
geometric, size, elements 1048576 forward 1.0535717010498046e-05 bandwidth (GB/s) 398.1033275495814
geometric, size, elements 2097152 forward 1.7604827880859376e-05 bandwidth (GB/s) 476.49474659848323
geometric, size, elements 4194304 forward 2.9888153076171875e-05 bandwidth (GB/s) 561.333313478494
geometric, size, elements 8388608 forward 5.422115325927734e-05 bandwidth (GB/s) 618.8439378916895
geometric, size, elements 16777216 forward 0.00010248422622680665 bandwidth (GB/s) 654.8213951626288
geometric, size, elements 33554432 forward 0.00019872665405273437 bandwidth (GB/s) 675.388657046396

Double Type

Before:

geometric, size, elements 65536 forward 7.531642913818359e-06 bandwidth (GB/s) 34.80568622272872
geometric, size, elements 131072 forward 7.486343383789062e-06 bandwidth (GB/s) 70.03258775643313
geometric, size, elements 262144 forward 1.2500286102294922e-05 bandwidth (GB/s) 83.8841600439443
geometric, size, elements 524288 forward 2.1970272064208986e-05 bandwidth (GB/s) 95.45407511891482
geometric, size, elements 1048576 forward 4.1151046752929686e-05 bandwidth (GB/s) 101.9246004890846
geometric, size, elements 2097152 forward 7.607698440551757e-05 bandwidth (GB/s) 110.26472809812907
geometric, size, elements 4194304 forward 0.00013311147689819335 bandwidth (GB/s) 126.03883895625013
geometric, size, elements 8388608 forward 0.00026131629943847655 bandwidth (GB/s) 128.40543078293493
geometric, size, elements 16777216 forward 0.0005186843872070312 bandwidth (GB/s) 129.38284948456277
geometric, size, elements 33554432 forward 0.0010293865203857423 bandwidth (GB/s) 130.38613323759532

After:

geometric, size, elements 65536 forward 6.048679351806641e-06 bandwidth (GB/s) 43.33904721229799
geometric, size, elements 131072 forward 7.328987121582031e-06 bandwidth (GB/s) 71.5362152098894
geometric, size, elements 262144 forward 1.009225845336914e-05 bandwidth (GB/s) 103.89904349407041
geometric, size, elements 524288 forward 1.6951560974121092e-05 bandwidth (GB/s) 123.71438849800283
geometric, size, elements 1048576 forward 3.087997436523438e-05 bandwidth (GB/s) 135.82601949054973
geometric, size, elements 2097152 forward 5.675792694091797e-05 bandwidth (GB/s) 147.7962366161136
geometric, size, elements 4194304 forward 0.00010924100875854492 bandwidth (GB/s) 153.57983408119776
geometric, size, elements 8388608 forward 0.0002037382125854492 bandwidth (GB/s) 164.69385675957594
geometric, size, elements 16777216 forward 0.0003897523880004883 bandwidth (GB/s) 172.18332989384
geometric, size, elements 33554432 forward 0.0007770538330078125 bandwidth (GB/s) 172.72642164375063
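
The bandwidth column in the logs above can be reproduced from the element count and the forward time. The following is a minimal sketch, not the benchmark script used for this PR (that script is not shown here); the function names `effective_bandwidth_gbs` and `speedup` are hypothetical, and the 4-bytes-per-element default is an assumption inferred from the reported numbers. The Double Type rows also match only with 4 bytes per element, so those figures appear to use the same byte count rather than 8.

```python
def effective_bandwidth_gbs(elements, seconds, bytes_per_element=4):
    """Bytes produced per second, in GB/s (1 GB = 1e9 bytes).

    bytes_per_element=4 is an assumption inferred from the logs above.
    """
    return elements * bytes_per_element / seconds / 1e9


def speedup(before_seconds, after_seconds):
    """Ratio of the forward time before the change to the time after it."""
    return before_seconds / after_seconds


# Float Type, Before, 65536 elements (time taken from the log above).
print(effective_bandwidth_gbs(65536, 4.827976226806641e-06))

# Float Type, 33554432 elements: Before time divided by After time.
print(speedup(0.0007671475410461426, 0.00019872665405273437))
```

Applied to the largest size (33554432 elements), the same ratio gives a speedup of about 3.9x for float and about 1.3x for double.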

Move THCTensor_(geometric) to ATen

gh-metadata: pytorch pytorch 20625 gh/syed-ahmed/6/head
@syed-ahmed syed-ahmed requested a review from ezyang May 29, 2019 20:09
auto p = static_cast<float>(p_);
auto geometric_func = [p] __device__ (float rand) {
  // use __logf fast approximation for peak bandwidth
  return static_cast<scalar_t>(::ceil(__logf(rand) / __logf(static_cast<float>(1.0)-p)));
Contributor

This used to be ceilf; I guess it doesn't make a difference?

Collaborator Author

Yeah. ceil with a float argument resolves to the float overload, which calls ceilf. FYI, these overloads are defined in /usr/local/cuda-10.0/include/crt/math_functions.hpp.
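
For context, the lambda under review turns a uniform random number into a geometric variate by inverse transform sampling: ceil(log(u) / log(1 - p)). Below is a minimal CPU-side sketch of that formula in Python, using `math.log` in place of the CUDA `__logf` intrinsic. The function names `geometric_from_uniform` and `sample_geometric` are hypothetical and are not part of ATen or this PR.

```python
import math
import random


def geometric_from_uniform(u, p):
    """Map a uniform sample u in (0, 1] to a geometric variate.

    Same formula as the reviewed lambda: ceil(log(u) / log(1 - p)).
    Note that u == 1.0 maps to 0, which is outside the usual support {1, 2, ...}.
    """
    return math.ceil(math.log(u) / math.log(1.0 - p))


def sample_geometric(p, n, seed=0):
    """Draw n geometric samples with success probability p."""
    rng = random.Random(seed)
    # rng.random() lies in [0, 1), so 1.0 - rng.random() lies in (0, 1]
    # and log(0) is never evaluated.
    return [geometric_from_uniform(1.0 - rng.random(), p) for _ in range(n)]


samples = sample_geometric(0.25, 100_000)
# The sample mean should be close to the theoretical mean 1 / p = 4.
print(sum(samples) / len(samples))
```

The kernel uses `__logf`, a faster but less accurate single-precision approximation, which is the trade-off the code comment refers to; this sketch uses double-precision `math.log` and does not model that difference.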

@syed-ahmed syed-ahmed closed this Jun 3, 2019
@syed-ahmed syed-ahmed deleted the gh/syed-ahmed/6/head branch June 3, 2019 19:59
Labels

module: cuda (related to torch.cuda and CUDA support in general)
module: internals (related to internal abstractions in c10 and ATen)
open source

4 participants