
Conversation

@mrshenli (Contributor) commented Dec 14, 2018

Follow-up PR to #14904, and the stretch goal of #12653.

Directly calculates the coordinates in the original tensor from the column index in the result tensor. Each GPU thread handles one column (two numbers) of the output tensor.

The implementation detects and handles precision loss when computing the square root of an int64_t variable, and supports tensors with up to row * column = 2 ^ 59 elements.

Algorithm details are described in the comments of TensorFactories.cu.
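
For reference, a minimal sketch of the semantics the kernel implements, using the mask-based formulation from the tests rather than the kernel itself (the sizes are illustrative):

import torch

# tril_indices(row, col, offset) returns the coordinates of the lower-triangular
# elements as a 2 x N tensor: first row = row indices, second row = column indices.
row, col, offset = 4, 3, 0  # illustrative sizes
reference = torch.ones(row, col, dtype=torch.long) \
    .tril(offset).nonzero().transpose(0, 1)
# The CUDA kernel produces the same 2 x N result, but each GPU thread computes
# one output column directly from its flat index instead of materializing a mask.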

@zou3519

@mrshenli mrshenli changed the title from "[WIP] Implementing cuda kernel for tril_indices" to "Implementing cuda kernel for tril_indices" on Dec 14, 2018
@mrshenli (Contributor Author)

There seems to be a non-deterministic error; test_cuda.py occasionally fails.

08:19:52 ======================================================================
08:19:52 FAIL: test_tril_and_triu_indices (__main__.TestCuda)
08:19:52 ----------------------------------------------------------------------
08:19:52 Traceback (most recent call last):
08:19:52   File "/var/lib/jenkins/workspace/test/common_utils.py", line 290, in wrapper
08:19:52     method(*args, **kwargs)
08:19:52   File "test_cuda.py", line 2146, in test_tril_and_triu_indices
08:19:52     self._compare_trilu_indices(*test_args)
08:19:52   File "test_cuda.py", line 2137, in _compare_trilu_indices
08:19:52     torch.tril_indices(row, col, offset, dtype=dtype, device='cuda'))
08:19:52   File "/var/lib/jenkins/workspace/test/common_utils.py", line 414, in assertEqual
08:19:52     assertTensorsEqual(x, y)
08:19:52   File "/var/lib/jenkins/workspace/test/common_utils.py", line 406, in assertTensorsEqual
08:19:52     self.assertLessEqual(max_err, prec, message)
08:19:52 AssertionError: tensor(254, device='cuda:0') not less than or equal to 1e-05 :

@mrshenli (Contributor Author) commented Dec 14, 2018

hmm, I cannot reproduce this error... caught it

@mrshenli (Contributor Author) commented Dec 14, 2018

It seems to be a bug in the existing code:

>>> torch.ones(2028, 1, dtype=torch.long, device='cuda').tril(0)
tensor([[1],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]], device='cuda:0')
>>> torch.ones(2028, 1, dtype=torch.long, device='cpu').tril(0)
tensor([[1],
        [1],
        [1],
        ...,
        [1],
        [1],
        [1]])

UPDATE:
Issue reported in #15226.

@zou3519 (Contributor) commented Dec 14, 2018

Can you run some performance numbers comparing torch.tril_indices on the GPU against the CPU version and numpy.tril_indices? Just a quick sanity check that the GPU acceleration actually helps.

@mrshenli (Contributor Author)

Good point! Yes, will do.

@mrshenli (Contributor Author)

The speedup from the GPU is indeed substantial:

X axis: size of the matrix.
Y axis: the total time in seconds of running triu_indices 10 times.

[screenshot, 2018-12-14: plot of total runtime vs. matrix size for the CPU and GPU implementations]
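
For completeness, a minimal sketch of how such a comparison can be timed (the sizes, iteration count, and exact calls are illustrative, not the script behind the plot above):

import time
import numpy as np
import torch

def time_call(fn, iters=10):
    # Synchronize before and after so GPU kernel launches are fully counted.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.time() - start

row = col = 5000  # illustrative size
gpu = time_call(lambda: torch.tril_indices(row, col, 0, device='cuda'))
cpu = time_call(lambda: torch.tril_indices(row, col, 0, device='cpu'))
npy = time_call(lambda: np.tril_indices(row, k=0, m=col))
print(gpu, cpu, npy)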

torch.tril_indices(row, col, offset, dtype=dtype, device='cuda'))
x = torch.ones(row, col, dtype=dtype, device='cpu') \
.tril(offset).nonzero().transpose(0, 1).cuda()
torch.cuda.synchronize()
Contributor

You don't need to synchronize here. self.assertEqual should have a sync point somewhere

@mrshenli (Contributor Author)

It seems to be an unrelated error:

Dec 15 00:16:38 ======================================================================
Dec 15 00:16:38 ERROR: test_scatter_stress_cuda (__main__.ProcessGroupGlooTest)
Dec 15 00:16:38 ----------------------------------------------------------------------
Dec 15 00:16:38 Traceback (most recent call last):
Dec 15 00:16:38   File "test_c10d.py", line 451, in wrapper
Dec 15 00:16:38     self._join_processes(fn)
Dec 15 00:16:38   File "test_c10d.py", line 496, in _join_processes
Dec 15 00:16:38     self._check_return_codes(elapsed_time)
Dec 15 00:16:38   File "test_c10d.py", line 506, in _check_return_codes
Dec 15 00:16:38     raise RuntimeError('Process {} terminated or timed out after {} seconds'.format(i, elapsed_time))
Dec 15 00:16:38 RuntimeError: Process 0 terminated or timed out after 30.03047275543213 seconds
Dec 15 00:16:38 
Dec 15 00:16:38 ----------------------------------------------------------------------

self._compare_trilu_indices(3, 513, offset=1, dtype=torch.long)
self._compare_trilu_indices(513, 3, offset=1, dtype=torch.int)
self._compare_trilu_indices(513, 0, offset=1, dtype=torch.double)
for test_args in tri_tests_args:
Contributor

nit: You can deduplicate _compare_trilu_indices and use it in both TestTorch and TestCuda by having it take a "device" argument.

Here is an example of how to call a function defined in TestTorch from TestCuda:

_TestTorchMixin._test_neg(self, lambda t: t.cuda())
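
Following that suggestion, a minimal sketch of what a device-parameterized helper could look like (the exact name and assertion style are assumptions, not the final test code):

import torch

def _compare_trilu_indices(self, row, col, offset=0,
                           dtype=torch.long, device='cpu'):
    # Reference result: build a mask on CPU, read off its nonzero coordinates,
    # then move them to the target device.
    expected = torch.ones(row, col, dtype=dtype, device='cpu') \
        .tril(offset).nonzero().transpose(0, 1).to(device)
    actual = torch.tril_indices(row, col, offset, dtype=dtype, device=device)
    self.assertEqual(expected, actual)

# TestTorch would call it with device='cpu', TestCuda with device='cuda'.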

# explicitly convert 'cpu' tensor to 'cuda' to avoid a bug in tril
# and triu cuda kernel (see #15226)
x = torch.ones(row, col, dtype=dtype, device='cpu') \
.tril(offset).nonzero().transpose(0, 1).cuda()
Contributor

To help with the deduplication between this and the _compare_trilu_indices in test_torch.py, you can do

x = torch.ones(row, col, dtype=dtype, device='cpu') \
                     .tril(offset).nonzero().transpose(0, 1).to(device)

here, assuming we pass in a 'device' for _compare_trilu_indices

f <<= 1;
auto b = f - 1;
auto c = - (x << 1);
row = (int64_t) ::floor((-b + ::sqrt(b * b - 4.0 * c))/2);
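
For context, a small sketch of the inversion this line performs in the simplest case (assuming f equals 1 before the shift, i.e. a full first row with no offset): row r of the lower triangle starts at flat index r*(r+1)/2, so the row containing flat index x is the largest r with r*(r+1)/2 <= x, which the quadratic formula gives as floor((-1 + sqrt(1 + 8x)) / 2).

import math

def row_of_flat_index(x):
    # Largest r such that r * (r + 1) / 2 <= x (triangular-number inversion).
    return int(math.floor((-1 + math.sqrt(1 + 8 * x)) / 2))

# Brute-force check of the closed form on small inputs.
for x in range(200):
    r = 0
    while (r + 1) * (r + 2) // 2 <= x:
        r += 1
    assert row_of_flat_index(x) == r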
@zou3519 (Contributor) Dec 17, 2018

Does this accumulate in double or float? Instead of floor, would it be better to round the result of this to the nearest integer?

Contributor

I'm a little worried about the numeric stability: can we run a test on some input with b * b much larger than 4 * c?

Contributor Author

It accumulates in double. I will add some tests for large b*b and small 4 * c.

Regarding the instability, I did see cases where the result was incorrectly rounded up when I was using float. Should I add back the check (whether the number of elements up to row is larger than x), or do you know a more stable way to calculate the row index? I thought about binary search, but it would be much slower than calling sqrt directly.

would it be better to round the result of this to the nearest integer

Do you mean using the __double2ll_rd method?

Contributor

Double precision is good; double arithmetic on GPUs is slower than float arithmetic but it is better that this kernel is correct.

What exactly is unstable here? Is the problem taking the square root of a very big number?

Do you mean using the __double2ll_rd method?

I'm not sure what __double2ll_rd is, but llround might work too: https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH__DOUBLE.html#group__CUDA__MATH__DOUBLE_1g6e401c3a6f291b874fc95b8480bcad02

Contributor Author

When using single precision, I encountered an error while calculating indices for a 5k x 5k matrix. The result should have been something like 4999.999998, but it got rounded up to 5000 because float only has 7-8 digits of precision.
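
A minimal sketch of the effect (the numbers are illustrative, not the exact intermediate values in the kernel): float32 carries roughly 7 significant decimal digits, so a root that is really just below an integer can round up to it, and the subsequent floor lands one row too far.

import numpy as np

exact = 4999.999998                 # true root, just below an integer
as_float32 = np.float32(exact)      # rounds to 5000.0 (~7 significant digits)
as_float64 = np.float64(exact)      # keeps 4999.999998 (~15-16 digits)
print(np.floor(as_float32))         # 5000.0 -> off-by-one row index
print(np.floor(as_float64))         # 4999.0 -> correct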

@zou3519 (Contributor) left a comment

This is great! I had two minor comments:

  1. please try to deduplicate the testing code more; I think we can keep one version of it in test/test_torch.py and call it directly from test/test_cuda.py
  2. I'm a little worried about the numeric stability -- can you try a test case where b^2 is much greater than 4c and check the correctness?

@mrshenli (Contributor Author)

I added code to explicitly handle precision loss when converting int64_t to double [code].
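
The core issue, shown in a minimal sketch (an illustration of the idea, not the actual resolve_root_int code): a double mantissa has 53 bits, so int64_t values above 2^53 may not survive the conversion to double, which is when the sqrt-based fast path can no longer be trusted and a fallback (the exact binary search path mentioned later in this thread) is needed.

def fits_in_double_exactly(x: int) -> bool:
    # True iff the value round-trips through double without losing bits.
    return int(float(x)) == x

print(fits_in_double_exactly(2**53))      # True: still exactly representable
print(fits_in_double_exactly(2**53 + 1))  # False: rounds to 2**53
print(fits_in_double_exactly(2**59 - 1))  # False: well beyond the exact range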

@mrshenli (Contributor Author)

Even with the latest code, we can only handle cases where row * col < power(2, 59), as we do this in the code. Will this be OK? (I will add notes in the docs.)

@zou3519

@mrshenli (Contributor Author) commented Dec 19, 2018

hmmm, our Jenkins server does not have enough RAM to test large tensors. Let me reduce the test size a bit.

Dec 18 23:18:56 ======================================================================
Dec 18 23:18:56 ERROR: test_large_trilu_indices (__main__.TestTorch)
Dec 18 23:18:56 ----------------------------------------------------------------------
Dec 18 23:18:56 Traceback (most recent call last):
Dec 18 23:18:56   File "test_torch.py", line 3785, in test_large_trilu_indices
Dec 18 23:18:56     _compare_large_trilu_indices(self, *test_args, device='cpu')
Dec 18 23:18:56   File "/var/lib/jenkins/workspace/test/common_methods_invocations.py", line 858, in _compare_large_trilu_indices
Dec 18 23:18:56     row, col, offset, dtype=dtype, device=device)[:, -100:-1]
Dec 18 23:18:56 RuntimeError: $ Torch: not enough memory: you tried to allocate 7GB. Buy new RAM! at /var/lib/jenkins/workspace/aten/src/TH/THGeneral.cpp:201

@mrshenli (Contributor Author) commented Dec 19, 2018

Encountered an unrelated error:

03:57:58 Running test_nn ... [2018-12-19 03:57:58.674092]
03:58:02 terminate called after throwing an instance of 'ihipException'
03:58:02   what():  std::exception
03:58:02 Traceback (most recent call last):
03:58:02   File "test/run_test.py", line 431, in <module>
03:58:02     main()
03:58:02   File "test/run_test.py", line 423, in main
03:58:02     raise RuntimeError(message)
03:58:02 RuntimeError: test_nn failed! Received signal: SIGIOT

The following larger-tensor tests, which would hit the precision loss problem if we did not handle it specifically, pass on my local server:

# row, col[, offset]
[
    (536870901, 1),
    (1, 536870901),
    (268435455, 2, 1),
    (2, 268435455, 1)
]

@mrshenli (Contributor Author)

Dec 19 06:49:05 ======================================================================
Dec 19 06:49:05 ERROR: setUpClass (__main__.TestHub)
Dec 19 06:49:05 ----------------------------------------------------------------------
Dec 19 06:49:05 Traceback (most recent call last):
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connection.py", line 171, in _new_conn
Dec 19 06:49:05     (self._dns_host, self.port), self.timeout, **extra_kw)
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/util/connection.py", line 56, in create_connection
Dec 19 06:49:05     for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/socket.py", line 745, in getaddrinfo
Dec 19 06:49:05     for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
Dec 19 06:49:05 socket.gaierror: [Errno -3] Temporary failure in name resolution
Dec 19 06:49:05 
Dec 19 06:49:05 During handling of the above exception, another exception occurred:
Dec 19 06:49:05 
Dec 19 06:49:05 Traceback (most recent call last):
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connectionpool.py", line 600, in urlopen
Dec 19 06:49:05     chunked=chunked)
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connectionpool.py", line 343, in _make_request
Dec 19 06:49:05     self._validate_conn(conn)
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connectionpool.py", line 849, in _validate_conn
Dec 19 06:49:05     conn.connect()
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connection.py", line 314, in connect
Dec 19 06:49:05     conn = self._new_conn()
Dec 19 06:49:05   File "/opt/conda/lib/python3.6/site-packages/urllib3/connection.py", line 180, in _new_conn
Dec 19 06:49:05     self, "Failed to establish a new connection: %s" % e)
Dec 19 06:49:05 urllib3.exceptions.NewConnectionError: <urllib3.connection.VerifiedHTTPSConnection object at 0x7f917ece2dd8>: Failed to establish a new connection: [Errno -3] Temporary failure in name 

@zou3519 (Contributor) commented Dec 19, 2018

@mrshenli

Even with the latest code, we can only handle cases where row * col < power(2, 59), as we do this in the code. Will this be OK? (I will add notes in the docs.)

That's fine, I doubt users will pick something so large. A lot of the code in our codebase uses int instead of int64, so I think users would hit other problems before they run into this one :)

@zou3519 (Contributor) Dec 19, 2018

How long does this take? It might be better if we don't run all of these in our CI if it takes too long / is too resource intensive (pick your favorite one!)

Contributor Author

On the gpudevserver, it takes ~2 minutes in total.

Contributor

Actually, I realize it is good to have these to test resolve_root_int's fallback binary search path. Let's leave them in then, unless they take too long to run.

Contributor

Okay, 2 minutes is too long for a unit test to be running -- could you keep one of these test cases and note the rest of them in a comment in the code somewhere?

Contributor Author

Sure, I will comment out all but one of the tests, and add descriptions in TensorFactories.cu to remind developers who would like to make changes.

@zou3519 (Contributor) left a comment

lgtm, thank you @mrshenli!

Left a minor comment about testing but the code looks good

@mrshenli mrshenli changed the title from "Implementing cuda kernel for tril_indices" to "Implementing cuda kernel for tril_indices and triu_indices" on Dec 19, 2018
@facebook-github-bot (Contributor) left a comment

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Dec 20, 2018
Summary:
Follow-up PR to #14904, and the stretch goal of #12653.

Directly calculates the coordinates in the original tensor from the column index in the result tensor. Each GPU thread handles one column (two numbers) of the output tensor.

The implementation detects and handles precision loss when computing the square root of an `int64_t` variable, and supports tensors with up to `row * column = 2 ^ 59` elements.

Algorithm details are described in [comments of TensorFactories.cu](https://github.com/pytorch/pytorch/blob/23ddb6f58a1c8a7a660a793f174cf014230176c6/aten/src/ATen/native/cuda/TensorFactories.cu#L109-L255).

zou3519
Pull Request resolved: pytorch/pytorch#15203

Reviewed By: zou3519

Differential Revision: D13517695

Pulled By: mrshenli

fbshipit-source-id: 86b305d22cac08c8962a3b0cf8e9e620b7ec33ea
@mrshenli mrshenli deleted the tril_indices_cuda branch January 13, 2019 16:49
@ezyang ezyang added the merged label Jun 25, 2019