-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add .expand() method to distribution classes #11341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
test/test_distributions.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/distributions/gumbel.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/distributions/bernoulli.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
I think this addresses all comments, and is ready for another review. One thing that is worth doing as suggested by @fritzo (either here or in a separate PR) is to call |
|
@neerajprad Calling |
torch/distributions/distribution.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
@vishwakftw - We can support |
torch/distributions/distribution.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_distributions.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/distributions/distribution.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Can you either create an issue describing this performance bug in |
Will create an issue, and point to it. It is probably not a bug though, and may just be due to the overhead of having to deal with non contiguous parameters. |
81ee7cb to
806e268
Compare
|
@vishwakftw, @fritzo : I opened a separate issue for discussion here - #11389. |
fritzo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, looks great!
I'm looking forward to cheaper .expand() methods in Pyro (and to the deletion of ReshapedDistribution).
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Thanks for reviewing, @fritzo, @vishwakftw! |
|
@soumith - Can we get this merged into master? I would like to test the integration on the Pyro side once it is landed. |
|
sure, i'm merging it now. |
Summary: This PR: - adds a `.expand` method for `TransformedDistribution` along the lines of #11341. - uses this method to simplify `.expand` in distribution classes that subclass off of `TransformedDistribution`. - restores testing of `TransformedDistribution` fixtures. - fixes some bugs wherein we were not setting certain attributes in the expanded instances, and adds tests for `.mean` and `.variance` which use these attributes. There are many cases where users directly use `TransformedDistribution` rather than subclassing off it. In such cases, it seems rather inconvenient to have to write a separate class just to define a `.expand` method. The default implementation should suffice in these cases. cc. fritzo, vishwakftw, alicanb Pull Request resolved: #11607 Differential Revision: D9818225 Pulled By: soumith fbshipit-source-id: 2c4b3812b9a03e6985278cfce0f9a127ce536f23
This adds a
.expandmethod for distributions that is akin to thetorch.Tensor.expandmethod for tensors. It returns a new distribution instance with batch dimensions expanded to the desiredbatch_shape. Since this callstorch.Tensor.expandon the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.e.g.
Motivation
We have already been using the
.expandmethod in Pyro in our patch oftorch.distributions. We use this in our models to enable dynamic broadcasting. This has also been requested by a few users on the distributions slack, and we believe will be useful to the larger community.Note that currently, there is no convenient and efficient way to expand distribution instances:
TransformedDistribution(or wrap over another distribution instance. e.g.OneHotCategoricaluses aCategoricalinstance) under the hood, or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them and construct new instances.__init__.pythat can be avoided.The
.expandmethod allows for a safe and efficient way to expand distribution instances. Additionally, this bypasses__init__.py(using__new__and populating relevant attributes) since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings as compared to constructing new instances via__init__(that said, thesampleandlog_probmethods will probably be the rate determining steps in many applications).e.g.
cc. @fritzo, @apaszke, @vishwakftw, @alicanb