Conversation

@neerajprad (Contributor) commented Sep 6, 2018

This adds a .expand method for distributions that is akin to the torch.Tensor.expand method for tensors. It returns a new distribution instance with batch dimensions expanded to the desired batch_shape. Since this calls torch.Tensor.expand on the distribution's parameters, it does not allocate new memory for the expanded distribution instance's parameters.

e.g.

>>> d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
>>> d.sample().shape
  torch.Size([100, 1])
>>> d.expand([100, 10]).sample().shape
  torch.Size([100, 10])
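
Since the expanded parameters are views, a quick way to check that no new memory is allocated (assuming the parameters are exposed as .loc / .scale, as for Normal) is:

>>> d.expand([100, 10]).loc.data_ptr() == d.loc.data_ptr()
  True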

Motivation

We have already been using an .expand method in Pyro in our patch of torch.distributions, where it enables dynamic broadcasting in our models. This has also been requested by a few users on the distributions Slack, and we believe it will be useful to the larger community.

Note that currently, there is no convenient and efficient way to expand distribution instances:

  • Many distributions use TransformedDistribution under the hood or wrap another distribution instance (e.g. OneHotCategorical wraps a Categorical instance), or have lazy parameters. This makes it difficult to collect all the relevant parameters, broadcast them, and construct new instances.
  • In the few cases where this is possible, the resulting implementation would be inefficient, since it would go through a lot of broadcasting and args-validation logic in __init__ that can be avoided.

The .expand method provides a safe and efficient way to expand distribution instances. Additionally, it bypasses __init__ (using __new__ and populating the relevant attributes), since we do not need to do any broadcasting or args validation (which was already done when the instance was first created). This can result in significant savings compared to constructing new instances via __init__ (that said, the sample and log_prob methods will probably be the rate-determining steps in many applications).

e.g.

>>> a = dist.Bernoulli(torch.ones([10000, 1]), validate_args=True)

>>> %timeit a.expand([10000, 100])
15.2 µs ± 224 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

>>> %timeit dist.Bernoulli(torch.ones([10000, 100]), validate_args=True)
11.8 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
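
To illustrate the mechanism, here is a rough standalone sketch for a simple distribution like Normal (a hypothetical helper written for this description, not the exact code in the PR):

import torch
import torch.distributions as dist

def expand_normal(d, batch_shape):
    # Hypothetical helper: expand a Normal without re-running __init__.
    batch_shape = torch.Size(batch_shape)
    new = dist.Normal.__new__(dist.Normal)
    new.loc = d.loc.expand(batch_shape)      # view over the original storage
    new.scale = d.scale.expand(batch_shape)  # view over the original storage
    # Skip Normal.__init__ (broadcasting/validation already happened when
    # `d` was constructed); just set the shapes on the base Distribution.
    super(dist.Normal, new).__init__(batch_shape, validate_args=False)
    return new

d = dist.Normal(torch.zeros(100, 1), torch.ones(100, 1))
print(expand_normal(d, [100, 10]).sample().shape)  # torch.Size([100, 10])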

cc. @fritzo, @apaszke, @vishwakftw, @alicanb

@neerajprad (Contributor, PR author)

I think this addresses all comments, and is ready for another review.

One thing worth doing, as suggested by @fritzo (either here or in a separate PR), is to call .contiguous() inside .expand(), which would result in faster .sample() and .log_prob(). I am debating whether we should do this by default or provide a kwarg contiguous=False.

@vishwakftw (Contributor)

@neerajprad Calling .contiguous() will probably not satisfy the objective of this PR: .contiguous() creates a contiguous copy of the tensor, which means more memory allocation, which is not intended if I understand correctly.

@neerajprad (Contributor, PR author)

Calling .contiguous() will probably not satisfy the objective of this PR: .contiguous() creates a contiguous copy of the tensor, which means more memory allocation, which is not intended if I understand correctly.

@vishwakftw - We can support Distribution.expand by either reusing the existing params or creating new ones altogether (even that should be faster than going through __init__), though it is preferable not to allocate additional memory whenever possible. .contiguous() respects that: it does not create a new copy if the existing tensor is already contiguous. In many applications memory may not be an issue, but the higher time taken to draw or score samples from expanded instances might be (especially as it is completely avoidable), e.g. when we expand an instance and then call .sample multiple times. In such cases, the overhead of using non-contiguous parameters is high enough that it is better to pay the one-time cost of calling .contiguous(). I am happy to defer this to a separate PR / discuss it in another issue though.
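
For context, both points are easy to check and measure directly (no numbers pasted here, but the expanded, non-contiguous input is typically the slower one to sample from):

>>> p = torch.ones(10000, 1)
>>> p.contiguous().data_ptr() == p.data_ptr()   # already contiguous: no copy
  True
>>> p_expanded = p.expand(10000, 100)           # stride-0 view, no new memory
>>> p_contig = p_expanded.contiguous()          # materializes a (10000, 100) copy
>>> %timeit torch.bernoulli(p_expanded)
>>> %timeit torch.bernoulli(p_contig)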

@fritzo (Collaborator) commented Sep 7, 2018

the higher time taken to draw or score samples from expanded instances

Can you either create an issue describing this performance bug in torch.bernoulli(x.expand(...)), or point to an existing issue, to provide context? I think @vishwakftw and I are missing the context of the timing experiments you've performed regarding .contiguous().

@neerajprad (Contributor, PR author)

Can you either create an issue describing this performance bug in torch.bernoulli(x.expand(...)), or point to an existing issue, to provide context? I think @vishwakftw and I are missing the context of the timing experiments you've performed regarding .contiguous().

Will create an issue and point to it. It is probably not a bug though, and may just be due to the overhead of having to deal with non-contiguous parameters.

@neerajprad (Contributor, PR author)

@vishwakftw, @fritzo: I opened a separate issue for discussion here: #11389.

@fritzo (Collaborator) left a review

LGTM, looks great!

I'm looking forward to cheaper .expand() methods in Pyro (and to the deletion of ReshapedDistribution).

@facebook-github-bot left a comment

soumith has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@neerajprad (Contributor, PR author)

Thanks for reviewing, @fritzo, @vishwakftw!

@neerajprad (Contributor, PR author)

@soumith - Can we get this merged into master? I would like to test the integration on the Pyro side once it has landed.

@soumith (Contributor) commented Sep 11, 2018

sure, i'm merging it now.

facebook-github-bot pushed a commit that referenced this pull request Sep 14, 2018
Summary:
This PR:
 - adds a `.expand` method for `TransformedDistribution` along the lines of #11341.
 - uses this method to simplify `.expand` in distribution classes that subclass `TransformedDistribution`.
 - restores testing of `TransformedDistribution` fixtures.
 - fixes some bugs wherein we were not setting certain attributes in the expanded instances, and adds tests for `.mean` and `.variance` which use these attributes.

There are many cases where users use `TransformedDistribution` directly rather than subclassing it. In such cases, it seems rather inconvenient to have to write a separate class just to define a `.expand` method. The default implementation should suffice in these cases.

cc. fritzo, vishwakftw, alicanb
Pull Request resolved: #11607

Differential Revision: D9818225

Pulled By: soumith

fbshipit-source-id: 2c4b3812b9a03e6985278cfce0f9a127ce536f23
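
As a usage sketch (illustrative shapes and transform, not taken from the commit above), the default `.expand` works on a directly constructed `TransformedDistribution`:

>>> import torch
>>> import torch.distributions as dist
>>> base = dist.Normal(torch.zeros(5, 1), torch.ones(5, 1))
>>> log_normal = dist.TransformedDistribution(base, [dist.ExpTransform()])
>>> log_normal.expand([5, 3]).sample().shape
  torch.Size([5, 3])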