@fehiepsi
Contributor

This PR addresses the slowness of MVN's log_prob as reported in #17206.

@t-vi I found it complicated to handle the permuted dimensions if we squeeze the singleton dimensions of bL, so I left that part as-is and kept the old approach. What do you think?

Collaborator

@t-vi t-vi left a comment


Seems good. Thank you!


bx_batch_shape = bx.shape[:-1]
# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# we are going to make bx have shape (..., i, 1, n) to apply _batch_trtrs_lower
Collaborator


Is this description still accurate?

# Reshape bx with the shape (..., 1, i, j, 1, n)
bx_new_shape = bx.shape[:outer_batch_dims]
for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
    bx_new_shape += (sx // sL, sL)
Collaborator


Clever with the two dims!

I'm not entirely convinced that amending the broadcasting semantics is a good idea, though, unless you have a specific use case. People will start to depend on it in obscure fashions and when we get a batch triangular solver, you won't be able to replace this code.

Here (+ the reshape) will cause stride 0 dimensions of L to be expanded, but I guess we're not too concerned about people having used expand beforehand.
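
To make the two-dims split concrete, here is a minimal pure-Python sketch of the shape computation from the snippet under review. The helper name `split_batch_shape` is hypothetical, and the way `outer_batch_dims` is derived here is an assumption, since the quoted lines do not show how it is computed. It works on plain shape tuples rather than tensors.

```python
from math import prod


def split_batch_shape(bL_shape, bx_shape):
    """Split each batch dim of bx that lines up with a batch dim of bL
    into two dims: (sx // sL, sL).

    Hypothetical helper for illustration; it mirrors the loop in the
    reviewed snippet but is not the PR's actual function.
    """
    n = bx_shape[-1]
    bL_batch_dims = len(bL_shape) - 2  # drop the trailing (n, n) matrix dims
    # Assumption: the leading dims of bx that bL does not have are "outer" dims.
    outer_batch_dims = len(bx_shape) - 1 - bL_batch_dims
    new_shape = tuple(bx_shape[:outer_batch_dims])
    for sL, sx in zip(bL_shape[:-2], bx_shape[outer_batch_dims:-1]):
        # sL == 1 gives (sx, 1); sL == sx gives (1, sx)
        new_shape += (sx // sL, sL)
    return new_shape + (n,)


# bL.shape = (i, 1, n, n) and bx.shape = (..., i, j, n) with i=3, j=2, n=4
bL_shape = (3, 1, 4, 4)
bx_shape = (5, 3, 2, 4)
new_shape = split_batch_shape(bL_shape, bx_shape)

# A reshape must preserve the number of elements.
assert prod(new_shape) == prod(bx_shape)
print(new_shape)  # (5, 1, 3, 2, 1, 4), i.e. (..., 1, i, j, 1, n)
```

Each pair of new dims holds the part of bx that varies independently of bL (`sx // sL`) and the part that matches bL's own batch size (`sL`), which is what lets the later permute group them separately.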

Contributor Author


I think that once we have a batch version of the triangular solver, replacing the method _batch_trtrs_lower might be enough. If the batch triangular solver also handles broadcasting, then we can remove this reshape+permute mechanism.

> Here (+ the reshape) will cause stride 0 dimensions of L to be expanded, but I guess we're not too concerned about people having used expand beforehand.

Yeah, agreed. In mvn, the math involving scale_tril mostly depends on the unbroadcasted (unexpanded) version. So to get better performance, users should not expand scale_tril/covariance_matrix before creating the mvn distribution.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
