Conversation

@zasdfgbnm (Collaborator) commented Nov 8, 2019

  • Building BinaryOpsKernel.cu takes extremely long. Split the original file into three pieces and copy the code into them.
  • Remove some useless logic.
  • Change some incorrect op names: *_cpu -> *_cuda.

@zasdfgbnm changed the title from "Building BinaryOpsKernel.cu takes a long time, parallelize it" to "[WIP] Building BinaryOpsKernel.cu takes a long time, parallelize it" Nov 8, 2019
@zasdfgbnm changed the title from "[WIP] Building BinaryOpsKernel.cu takes a long time, parallelize it" to "Improving BinaryOpsKernel.cu" Nov 8, 2019
@zasdfgbnm changed the title from "Improving BinaryOpsKernel.cu" to "[WIP] Improving BinaryOpsKernel.cu" Nov 8, 2019
}

void logical_xor_kernel_cuda(TensorIterator& iter) {
if (iter.common_dtype() == ScalarType::Bool) {
@zasdfgbnm (Collaborator, Author):

This logic is useless.

Collaborator:

Seems like so. Perhaps deleting this logic for all comparison ops will speed things up sufficiently without the need to split.

@zasdfgbnm (Collaborator, Author):

@xuhdev Even after the split, BinaryCompareKernel.cu still takes 2 min 44 s to compile.

Collaborator:

If every function in BinaryCompareKernel.cu were cut in half, the compile time might come down to something reasonable. I believe the functions in BinaryCompareKernel.cu might be the bottleneck.

@zasdfgbnm (Collaborator, Author):

Hmm, this PR already cuts it in half, but it still takes more than 2 minutes...


void lt_kernel_cuda(TensorIterator& iter) {
if (iter.common_dtype() == ScalarType::Bool) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.input_dtype(), "lt_cpu", [&]() {
@zasdfgbnm (Collaborator, Author):

This should be lt_cuda, not lt_cpu.

}

void lt_kernel_cuda(TensorIterator& iter) {
if (iter.common_dtype() == ScalarType::Bool) {
@zasdfgbnm (Collaborator, Author):

This logic is useless as well. With the dynamic-casting approach in TensorIterator, the computation is always done in the common dtype, the result is stored in the common dtype, and it is then dynamically cast to bool.
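To make the point concrete, here is a rough sketch (not the exact diff in this PR) of what a comparison kernel looks like once the Bool branch is dropped: a single dispatch on the common dtype is enough, and the dispatch label is also fixed to name the CUDA kernel. The gpu_kernel helper and the lambda body below are illustrative assumptions; the actual code in the PR may use a different helper (for example, a scalar-aware variant).

// Illustrative sketch only -- not the exact code from this PR.
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

namespace at { namespace native {

void lt_kernel_cuda(TensorIterator& iter) {
  // No ScalarType::Bool special case: TensorIterator already computes in the
  // common dtype and dynamically casts the boolean result into the output.
  // The dispatch label names the CUDA kernel ("lt_cuda", not "lt_cpu").
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> bool {
      return a < b;
    });
  });
}

}} // namespace at::native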

@zasdfgbnm changed the title from "[WIP] Improving BinaryOpsKernel.cu" to "Improving BinaryOpsKernel.cu" Nov 8, 2019
@zasdfgbnm (Collaborator, Author):

@ngimel @VitalyFedyunin Could you please take a look at this? You reviewed the dynamic casting of TensorIterator.

@VitalyFedyunin previously approved these changes Nov 8, 2019

@VitalyFedyunin (Contributor) left a comment:

It would be nice to add benchmark results for the changed operators, such as logical_xor_kernel_cuda.

@VitalyFedyunin dismissed their stale review November 8, 2019 22:12 ("wrong button pressed")

@zasdfgbnm (Collaborator, Author):

@VitalyFedyunin The benchmarks show that the performance changes very little:

import torch
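# Note: the %timeit calls below are IPython magics, so run this in IPython or Jupyter.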
print(torch.__version__)
print(torch.version.git_version)
print()
print('=' * 20)


for size in [10, 1000000, 100000000]:
    for dtype in [torch.float, torch.bool]:
        print('size:', size, ', dtype:', dtype)
        a = torch.randn(size, device='cuda').to(dtype)
        print('compare ops')
        torch.cuda.synchronize()
        %timeit a < a; torch.cuda.synchronize()
        print('logical_xor')
        torch.cuda.synchronize()
        %timeit torch.logical_xor(a, a); torch.cuda.synchronize()
        print()
    print('-' * 20)

before

1.4.0a0+1dd3c8e
1dd3c8e53909d6cf35ade5cf85cd7430e5c655f9

====================
size: 10 , dtype: torch.float32
compare ops
20.9 µs ± 315 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
20.2 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

size: 10 , dtype: torch.bool
compare ops
21 µs ± 192 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
19.8 µs ± 97.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

--------------------
size: 1000000 , dtype: torch.float32
compare ops
23.7 µs ± 421 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
23.9 µs ± 296 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

size: 1000000 , dtype: torch.bool
compare ops
20.7 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
21.3 µs ± 190 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

--------------------
size: 100000000 , dtype: torch.float32
compare ops
709 µs ± 354 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
logical_xor
713 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

size: 100000000 , dtype: torch.bool
compare ops
471 µs ± 118 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
logical_xor
471 µs ± 446 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

--------------------

after

1.4.0a0+309f6d6
309f6d6a9c53e9c0c091e8ea809b7582af9d185d

====================
size: 10 , dtype: torch.float32
compare ops
20.4 µs ± 59.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
19 µs ± 59.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

size: 10 , dtype: torch.bool
compare ops
20.5 µs ± 94.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
19 µs ± 54 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

--------------------
size: 1000000 , dtype: torch.float32
compare ops
23.6 µs ± 218 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
23.1 µs ± 66.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

size: 1000000 , dtype: torch.bool
compare ops
20.3 µs ± 40.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
logical_xor
20.5 µs ± 74 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

--------------------
size: 100000000 , dtype: torch.float32
compare ops
707 µs ± 212 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
logical_xor
712 µs ± 533 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

size: 100000000 , dtype: torch.bool
compare ops
472 µs ± 283 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
logical_xor
472 µs ± 210 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

--------------------

@facebook-github-bot (Contributor) left a comment:

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zasdfgbnm (Collaborator, Author) commented Nov 11, 2019

@VitalyFedyunin Is the internal failure real?

@zasdfgbnm deleted the split branch November 11, 2019 22:48
zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 11, 2019
Summary:
- Building `BinaryOpsKernel.cu` takes extremely long. Split the original file into three pieces and copy the code into them.
- Remove some useless logic.
- Change some incorrect op names: `*_cpu` -> `*_cuda`.
Pull Request resolved: pytorch/pytorch#29428

Differential Revision: D18408858

Pulled By: VitalyFedyunin

fbshipit-source-id: 29323b0bc40a928ae698345ad1ffe46c5851b012
@facebook-github-bot (Contributor):

@VitalyFedyunin merged this pull request in 01ad2bc.
