Cleanup gumbel_softmax #13339
Conversation
- Cleanup: faster and far fewer lines of code.
- Remove some references in the docstring. Was the linked ipynb the first to do the hard trick?
- Add references to the two papers published in parallel.
- Add a DeprecationWarning for `eps`. It probably shouldn't be there.
- Keep some code alternatives commented; I want to exploit the PyTorch tester.
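For context, the Gumbel-softmax sample with the straight-through hard trick can be sketched roughly as follows (an illustrative sketch, not the PR's exact code; the function name and defaults are mine):

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_sketch(logits, tau=1.0, hard=False, dim=-1):
    """Illustrative Gumbel-softmax relaxation with an optional
    straight-through hard sample (sketch, not the merged code)."""
    # Gumbel(0, 1) noise: if e ~ Exponential(1), then -log(e) ~ Gumbel(0, 1)
    gumbels = -torch.empty_like(logits).exponential_().log()
    y_soft = F.softmax((logits + gumbels) / tau, dim=dim)
    if hard:
        # Straight-through trick: one-hot values in the forward pass,
        # gradients flow through the soft sample in the backward pass
        index = y_soft.argmax(dim, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```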
TODO:
- Tests the same things as before, except the test for the variance of the gradient was removed; it was unclear how/why it worked.
- Add tests for 3-dimensional arrays.
- Rename `st` -> `straight_through`.
The only error with the original test was caused by flake8 (which is currently blocking other pull requests from merging).
Previously only `float` was tested, and only on CPU; now both `double` and `float` are tested.
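A hedged sketch of what such dtype/shape checks could look like against the public `F.gumbel_softmax` API (the shapes and loop here are illustrative, not the actual test suite):

```python
import torch
import torch.nn.functional as F

# Check shapes and dtypes for 2-D and 3-D inputs, float and double
for dtype in (torch.float, torch.double):
    for shape in [(5, 4), (5, 4, 3)]:
        logits = torch.randn(shape, dtype=dtype)
        y = F.gumbel_softmax(logits, tau=1.0, hard=True)
        assert y.shape == logits.shape
        assert y.dtype == dtype
        # hard samples are one-hot along the last dim, so rows sum to ~1
        assert torch.allclose(y.sum(dim=-1), torch.ones(shape[:-1], dtype=dtype))
```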
- Remove unnecessary unpacking of `*shape`.
- Remove commented `torch.distributions` code.
In addition, flake8 fixes.
I think this is ready to be merged.
Here is a comparison between implementations, with a full training example: https://gist.github.com/ragulpr/1f88aee9fcc585600280ba2d6f323368 (or via nbviewer if the link above isn't working).
Anyone who could review? I know @soumith reviewed the original PR 👍
It would be nice if we keep implementations of
Thanks @alicanb, I've been looking into how that would be possible. I think it's two things: consistent in values, or consistent in style. Long story short, the line-by-line translation using the reference implementation:

Alt a) [code elided] But this is going to be slow, and:

Alt b) [code elided] A third alternative is to impose the more legible 'style' of the `relaxed_categorical` `rsample` implementation but use:

Alt c) [code elided] For reference, this is what I committed:

Alt d) [code elided], and finally using [elided].

In the end, if we want to mimic in values I propose going with Alt b), otherwise Alt c).
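To make the value-vs-style trade-off concrete, here is a hedged sketch contrasting two standard ways of generating the Gumbel noise (the elided alternatives above are not recoverable, so the variable names and constants here are mine):

```python
import torch

logits = torch.randn(8, 5)
tau = 0.5

# Uniform-based Gumbel noise (as in the original reference implementation):
# g = -log(-log(u)), u ~ Uniform(0, 1); needs eps guards against log(0)
u = torch.rand_like(logits)
g_uniform = -torch.log(-torch.log(u + 1e-10) + 1e-10)

# Exponential-based Gumbel noise: if e ~ Exponential(1), then
# -log(e) ~ Gumbel(0, 1), sidestepping the double log on a uniform sample
g_exp = -torch.empty_like(logits).exponential_().log()

# Either noise source plugs into the same relaxed-softmax sample
y = torch.softmax((logits + g_exp) / tau, dim=-1)
```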
@alicanb Do you have an opinion about the implementation strategies?
facebook-github-bot
left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Add test coverage to `_test_gumbel_softmax_st_shapes` for cuda
facebook-github-bot
left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
It was done like that for legacy reasons, but really there's no point in it. This way is more py(torch)tonic
I'm not sure what's going on with the failures, because the mentioned code doesn't show up in the PR at all. Pushed a merge to master and trying again.
Fixes the JIT error `return statements can appear only at the end of the function body`. Unfortunately this uses the less nice `ret` pattern (but it seems to be standard).
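For illustration, the `ret` pattern looks roughly like this (a hedged sketch; the function name and body are illustrative, and newer TorchScript versions do allow early returns):

```python
import torch

# Older TorchScript only allowed `return` at the end of the function body,
# so each branch assigns into a single `ret` variable instead of returning.
@torch.jit.script
def hard_or_soft(y_soft: torch.Tensor, hard: bool) -> torch.Tensor:
    if hard:
        index = y_soft.argmax(dim=-1, keepdim=True)
        ret = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    else:
        ret = y_soft
    return ret
```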
Unfortunately, changes in the JIT mean that this patch no longer works. @wanchaol, could you take a look at this?
Yes, this will make the JIT standard lib break. @ragulpr, why did we end up choosing Alt b) instead of Alt c), given that Alt c) looks simpler and faster?
Sorry for the slow activity here; I was a little taken aback by all the JIT errors. The reason for choosing b) instead of c) was that @alicanb made the good point that the implementation should be similar to [elided].
facebook-github-bot
left a comment
@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Fixes #12643, amends to #3341.
- … (but apply softmax over … with `dim=-1`)
- … `dim` argument (… `torch.Distributions`)
- … `eps`. It's not needed anymore.

Controversies:

When `gumbel_softmax` was added (add gumbel_softmax, based on Eric Jang's implementation #3341), it was merged into `torch.nn.functional` before all the work with `Distributions` and `Pyro`, and there will probably be multiple other best practices for this in the future.

I've tested building on the `Distributions` API, but it was too slow; see below. I therefore propose not using `Distributions`, to keep it fast and simple, but adding a comment in the docstring that `gumbel_softmax` may be deprecated in the future.

Build on `torch.distributions.RelaxedOneHotCategorical`?

Pros:

- … `logsumexp` etc.
- … `torch.distributions.utils._finfo` to avoid overflow (the old implementation had an `eps` flag)

Cons:

- … `TransformedDistribution` and `Distribution`.
- … `dim` to apply softmax over.

Build on `torch.exponential_()` (as proposed)?

Pros:

- …

Timings:

- …

TODO:

Decide on which path to choose. I'll commit changes to the unit tests in a while to show that the new code passes both the old tests and the new tests. I'll also remove the commented code about `RelaxedOneHotCategorical`.
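The timing claim (the `Distributions` build being too slow) could be reproduced with a rough micro-benchmark along these lines (a sketch; the shapes, iteration count, and temperature are arbitrary choices of mine, not the PR's benchmark):

```python
import time
import torch
from torch.distributions import RelaxedOneHotCategorical

logits = torch.randn(1000, 10)
temperature = torch.tensor(0.5)

# Distributions-based path: construct the distribution and rsample each draw
t0 = time.perf_counter()
for _ in range(100):
    RelaxedOneHotCategorical(temperature, logits=logits).rsample()
t_dist = time.perf_counter() - t0

# Direct functional path: exponential-based Gumbel noise plus a softmax
t0 = time.perf_counter()
for _ in range(100):
    g = -torch.empty_like(logits).exponential_().log()
    torch.softmax((logits + g) / temperature, dim=-1)
t_direct = time.perf_counter() - t0

print(f"distributions: {t_dist:.4f}s, direct: {t_direct:.4f}s")
```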