[sparse] gradcheck #14596
Conversation
facebook-github-bot left a comment:
@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/autograd/gradcheck.py (Outdated)
I am not sure this should be called gradcheck, because you are only checking nonzero values rather than the full tensor. :/ I know that we only calculate gradients for the present values, but technically these are not the gradients you would expect under the normal definition of those ops. So having gradcheck default to assuming this seems wrong to me. Could you add a flag to gradcheck that enables this, and clarify it in the doc?
I see, this is a valid point. Can I make it sparse_gradcheck and add docs for it? Although most of the code in this fn would duplicate gradcheck. What do you think? I guess you also know why I am only checking nnz values here: I don't know how to perturb zero values in a sparse tensor...
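For context on why only stored values get checked: the numerical side of gradcheck perturbs each input entry by ±eps, and for a sparse tensor only the entries in _values() are materialized to perturb. A minimal standalone sketch of that limitation (the function name is hypothetical, not the PR's code, and it assumes a 1-D values tensor):

import torch

def numerical_grad_nnz(fn, x, eps=1e-6):
    # Central differences over the stored (nnz) values only; the implicit
    # zeros are never touched, which is exactly the limitation discussed.
    values = x._values()  # shares storage with x, so writes propagate
    grad = torch.zeros_like(values)
    for i in range(values.numel()):
        orig = values[i].item()
        values[i] = orig + eps
        f_plus = fn(x).sum()
        values[i] = orig - eps
        f_minus = fn(x).sum()
        values[i] = orig
        grad[i] = (f_plus - f_minus) / (2 * eps)
    return grad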
I was thinking about adding a bool flag to gradcheck that would turn on/off checking only nonzero values for sparse inputs. I don't know if that's the best solution though.
yes, that's doable too. I can add an is_sparse flag to gradcheck
It should be called something like sparse_only_checknonzero. Maybe
something shorter. When this is false, seeing sparse inputs should raise
and tell people that normal sparse input gradcheck is not implemented.
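A minimal sketch of the raise-by-default gating proposed here, assuming the flag ends up as a gradcheck keyword (the name check_sparse_nnz below comes from the later docstring diff; the helper name is hypothetical):

import torch

def _reject_sparse_unless_opted_in(inputs, check_sparse_nnz):
    # When the flag is False, a sparse input is an error rather than a
    # silently weaker check.
    if check_sparse_nnz:
        return
    for t in inputs:
        if isinstance(t, torch.Tensor) and t.is_sparse:
            raise RuntimeError(
                'gradcheck on sparse inputs only checks the stored (nnz) '
                'values; pass check_sparse_nnz=True to opt in. Full dense '
                'gradcheck for sparse inputs is not implemented.')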
…On Sun, Dec 2, 2018 at 03:35, Wei Yang commented on this pull request, in torch/autograd/gradcheck.py:
> + def get_stride(size):
+ dim = len(size)
+ tmp = 1
+ stride = [0] * dim
+ for i in reversed(range(dim)):
+ stride[i] = tmp
+ tmp *= size[i]
+ return stride
+
+ x_nnz = x_tensor._nnz()
+ x_size = list(x_tensor.size())
+ x_indices = x_tensor._indices().t()
+ x_values = x_tensor._values().data
+ x_stride = get_stride(x_size)
+
+ for i in range(x_nnz):
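For context, the quoted get_stride helper computes row-major (C-contiguous) strides, which the surrounding loop can use to map each stored index tuple to a flat offset in the dense view. A standalone illustration (the assertions are mine, not from the PR):

def get_stride(size):
    # stride[i] = number of flattened elements skipped per step along dim i
    dim = len(size)
    tmp = 1
    stride = [0] * dim
    for i in reversed(range(dim)):
        stride[i] = tmp
        tmp *= size[i]
    return stride

assert get_stride([2, 3, 4]) == [12, 4, 1]
# Flat offset of index (1, 2, 3) in a 2x3x4 tensor:
assert sum(i * s for i, s in zip((1, 2, 3), get_stride([2, 3, 4]))) == 23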
Force-pushed e87e9da to c9a16ae
facebook-github-bot left a comment:
@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed c9a16ae to 4963e03
…se output yet; 2. impl backward for to_dense() to get around sparse output
Force-pushed 4963e03 to ec86fb1
…y for SparseTensor inputs
Force-pushed ec86fb1 to ebba9c5
facebook-github-bot left a comment:
@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/autograd/gradcheck.py (Outdated)
raise_exception (bool, optional): indicating whether to raise an exception if
    the check fails. The exception gives more information about the
    exact nature of the failure. This is helpful when debugging gradchecks.
check_sparse_nnz (bool, optional): if True, gradcheck allows for SparesTensor input,
SparesTensor -> SparseTensor
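With the typo fixed, the intended call pattern would look roughly like the following (a sketch, not the PR's test code; to_dense() is the op whose backward this PR implements to get around sparse outputs):

import torch
from torch.autograd import gradcheck

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([4.0, 5.0, 6.0], dtype=torch.double)
x = torch.sparse_coo_tensor(i, v, (2, 3), requires_grad=True)

# Only gradients at the three stored (nnz) values are compared.
assert gradcheck(lambda t: t.to_dense(), (x,), check_sparse_nnz=True)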
facebook-github-bot left a comment:
@weiyangfb has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…to_dense() to get around sparse output
…_gen_sparse() and also easily cover coalesced / uncoalesced test cases
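A hypothetical stand-in for the _gen_sparse() test helper mentioned above, showing one way to cover both coalesced and uncoalesced cases (the name and construction are assumptions, not the actual test-suite code):

import torch

def gen_sparse(size, nnz, coalesced=True):
    # Random COO tensor; a duplicated index column keeps it uncoalesced.
    indices = torch.stack([torch.randint(0, s, (nnz,)) for s in size])
    values = torch.randn(nnz, dtype=torch.double)
    if not coalesced and nnz > 1:
        indices[:, -1] = indices[:, 0]  # force a duplicate entry
    x = torch.sparse_coo_tensor(indices, values, size)
    return x.coalesce() if coalesced else x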