[C++ API] AdaptiveLogSoftmaxWithLoss #29076
We might be able to change
output = output.index_copy_(0, row_indices, local_logprob.squeeze(1));
to
output.index_copy_(0, row_indices, local_logprob.squeeze(1));
It was confusing for me as well. As mentioned in another comment, .index_copy_() does not update the calling Tensor, which it apparently should.
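For reference, the in-place contract that the suggestion above relies on can be modeled in plain C++. This is only a sketch under stated assumptions: `Vec` and `index_copy_inplace` are hypothetical stand-ins built on `std::vector`, not LibTorch types or calls, and the sketch covers the 1-D case only. It shows the expected behavior (the receiver is mutated and a reference to it is returned), which is why reassigning the result would be redundant when the operation behaves as intended.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 1-D tensor; NOT a LibTorch type.
using Vec = std::vector<float>;

// Models an in-place index copy along dim 0: dest[indices[i]] = source[i].
// Like an in-place tensor method, it mutates `dest` and returns a
// reference to that same object rather than a new one.
Vec& index_copy_inplace(Vec& dest,
                        const std::vector<std::size_t>& indices,
                        const Vec& source) {
    assert(indices.size() == source.size());
    for (std::size_t i = 0; i < indices.size(); ++i) {
        assert(indices[i] < dest.size());
        dest[indices[i]] = source[i];
    }
    return dest;
}
```

With this model, calling `index_copy_inplace(output, rows, values)` updates `output` directly, so `output = index_copy_inplace(output, rows, values)` would only copy the object onto itself.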
8241bfa to 514988a
@yf225 Thanks for the feedback. I have committed the suggested changes, except for using the target mask for indexing (multi-dimensional tensor indexing), which is currently not supported in the C++ API. The details are mentioned inline.
2dfb76f to 20d1d54
@yf225 I have committed a few updates and rebased on the latest code. Can you have a look?
183206d to a485104
yf225 left a comment
Thanks so much for the awesome work! Apologies for the delay as I was waiting for C++ tensor multi-dim indexing to be available, so that we don't need to use narrow / index_select :D
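To make the trade-off concrete, here is a minimal sketch of the row-selection step that mask indexing would replace, assuming targets are grouped by a half-open cluster range [low, high). It is written in plain C++ over `std::vector`; `rows_in_cluster` and `select_rows` are hypothetical helpers for illustration, not LibTorch functions, and the row-major matrix layout is also an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row positions whose target falls in the cluster range [low, high).
// Plays the role of building a target mask and taking its nonzero indices.
std::vector<std::size_t> rows_in_cluster(const std::vector<std::int64_t>& target,
                                         std::int64_t low, std::int64_t high) {
    std::vector<std::size_t> rows;
    for (std::size_t i = 0; i < target.size(); ++i) {
        if (target[i] >= low && target[i] < high) {
            rows.push_back(i);
        }
    }
    return rows;
}

// Gathers the selected rows of a row-major (num_rows x num_cols) matrix,
// analogous to an index_select along dim 0.
std::vector<float> select_rows(const std::vector<float>& input,
                               std::size_t num_cols,
                               const std::vector<std::size_t>& rows) {
    assert(num_cols > 0 && input.size() % num_cols == 0);
    std::vector<float> out;
    out.reserve(rows.size() * num_cols);
    for (std::size_t r : rows) {
        assert((r + 1) * num_cols <= input.size());
        auto first = input.begin() + static_cast<std::ptrdiff_t>(r * num_cols);
        auto last = first + static_cast<std::ptrdiff_t>(num_cols);
        out.insert(out.end(), first, last);
    }
    return out;
}
```

The two-step form (collect indices, then gather) is what an explicit index_select-style workaround looks like; direct mask indexing would express the same selection in a single expression.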
facebook-github-bot left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
💊 CircleCI build failures summary and remediations
As of commit 521cd95 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns
The following build failures do not appear to be due to upstream breakages:
Implemented AdaptiveLogSoftmaxWithLoss and some tests for modules. Reference #25883