-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Fix the slowness of mvn's log_prob #17294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
t-vi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems good. Thank you!
|
|
||
| bx_batch_shape = bx.shape[:-1] | ||
| # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), | ||
| # we are going to make bx have shape (..., i, 1, n) to apply _batch_trtrs_lower |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this description still accurate?
| # Reshape bx with the shape (..., 1, i, j, 1, n) | ||
| bx_new_shape = bx.shape[:outer_batch_dims] | ||
| for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): | ||
| bx_new_shape += (sx // sL, sL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clever with the two dims!
I'm not entirely convinced that amending the broadcasting semantics is a good idea, though, unless you have a specific use case. People will start to depend on it in obscure fashions and when we get a batch triangular solver, you won't be able to replace this code.
Here (+ the reshape) will cause stride 0 dimensions of L to be expanded, but I guess we're not too concerned about people having used expand beforehand.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that when we have batch version of triangular solver, replacing the method batch_trtrs_lower might be enough. If batch triangular solver also handles broadcasting, then we can remove these reshape+permute mechanism.
Here (+ the reshape) will cause stride 0 dimensions of L to be expanded, but I guess we're not too concerned about people having used expand beforehand.
Yeah, agree. In mvn, the math involving scale_tril mostly depends on unbroadcasted (unexpaned) version. So to get a better performance, users should not expand scale_tril/covariance_matrix before creating mvn distribution.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR addresses the slowness of MVN's log_prob as reported in #17206.
@t-vi I find it complicated to handle permutation dimensions if we squeeze singleton dimensions of bL, so I leave it as-is and keep the old approach. What do you think?