Address precision matrix instability of MVN distribution #21366
Conversation
@pytorchbot merge this please
facebook-github-bot left a comment
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Is this error fixed? I am getting this error:
@rojinsafavi This PR enhances the stability for precision_matrix. It looks like your
So I'm actually using the same thing that I use in R's kldiv function, and it works fine over there; is something different here?
This is the diagonal; there are no negative numbers:
I think a positive diagonal is not equivalent to positive definiteness. You can check if
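To illustrate the point above: a matrix can have an all-positive diagonal and still fail to be positive definite, which is exactly what makes the Cholesky factorization error out. A minimal NumPy sketch (the matrices here are made-up illustrative values, not the poster's data):

```python
import numpy as np

def is_positive_definite(m):
    # A symmetric matrix is positive definite iff all of its eigenvalues
    # are strictly positive; equivalently, iff its Cholesky factorization
    # succeeds without error.
    try:
        np.linalg.cholesky(m)
        return True
    except np.linalg.LinAlgError:
        return False

# Positive definite: positive diagonal AND Cholesky succeeds.
good = np.array([[2.0, 0.5],
                 [0.5, 1.0]])

# Positive diagonal, but eigenvalues are 3 and -1, so NOT positive definite.
bad = np.array([[1.0, 2.0],
                [2.0, 1.0]])

print(is_positive_definite(good))  # True
print(is_positive_definite(bad))   # False
```

This is why checking the diagonal alone is not enough before passing a covariance matrix to MultivariateNormal.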
Thanks for the help, I was not using the right covariance matrix. The result is a bit different from what I get in R, though. In R the KL is 1.955039, while in torch it is tensor(1.9948). Just to make sure that I'm doing everything as I should, here are the steps:
@rojinsafavi I think the way you use kl_divergence is correct. |
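For reference, the KL divergence between two multivariate normals that kl_divergence evaluates has a standard closed form, which can be useful for cross-checking a discrepancy like the R-vs-torch one above. A minimal NumPy sketch (the parameters below are made-up illustrative values):

```python
import numpy as np

def kl_mvn(mu0, cov0, mu1, cov1):
    """KL(N(mu0, cov0) || N(mu1, cov1)), standard closed form for Gaussians:
    0.5 * (tr(S1^-1 S0) + (mu1-mu0)^T S1^-1 (mu1-mu0) - k + ln det S1 - ln det S0)
    """
    k = mu0.shape[0]
    cov1_inv = np.linalg.inv(cov1)
    diff = mu1 - mu0
    # slogdet is numerically safer than log(det(...)) for near-singular inputs.
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    return 0.5 * (np.trace(cov1_inv @ cov0)
                  + diff @ cov1_inv @ diff
                  - k
                  + logdet1 - logdet0)

# Illustrative parameters; KL of a distribution against itself is 0.
mu = np.array([0.0, 0.0])
cov = np.array([[1.0, 0.2],
                [0.2, 1.0]])
print(kl_mvn(mu, cov, mu, cov))  # 0.0
```

If the hand-computed closed form agrees with torch but not with R, the difference most likely lies in how the inputs are prepared on each side rather than in kl_divergence itself.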
Currently, when the input to MVN is a precision matrix, we take its inverse to convert it to a covariance matrix. This, however, can easily make the covariance matrix not positive definite and hence trigger a Cholesky error.
For example, converting an ill-conditioned precision matrix this way will trigger

```
RuntimeError: cholesky_cpu: U(8,8) is zero, singular U.
```

This PR uses some math tricks (ref) to take the inverse of only a triangular matrix, hence increasing the stability.
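The trick can be sketched as follows (in NumPy for a self-contained illustration; the PR itself operates on torch tensors and uses a dedicated triangular solve). Flipping the precision matrix P along both axes before the Cholesky factorization yields P = U Uᵀ with U *upper* triangular, so the Cholesky factor L of Σ = P⁻¹ = U⁻ᵀ U⁻¹ is obtained by inverting only a triangular matrix:

```python
import numpy as np

def precision_to_scale_tril(P):
    # Cholesky of the flipped matrix: flip(P) = Lf @ Lf.T with Lf lower
    # triangular. Flipping back, this is equivalent to P = U @ U.T with
    # U = flip(Lf) upper triangular.
    Lf = np.linalg.cholesky(np.flip(P, (-2, -1)))
    L_inv = np.flip(Lf, (-2, -1)).T  # = U.T, lower triangular
    # Sigma = P^{-1} = U^{-T} @ U^{-1} = L @ L.T with L = (U.T)^{-1},
    # so only a triangular matrix is ever inverted. (The real
    # implementation uses a triangular solve for this step;
    # np.linalg.solve is a generic stand-in here.)
    return np.linalg.solve(L_inv, np.eye(P.shape[-1]))

# Illustrative positive definite precision matrix:
P = np.array([[4.0, 1.0],
              [1.0, 3.0]])
L = precision_to_scale_tril(P)
# L is lower triangular and L @ L.T recovers the covariance inv(P).
assert np.allclose(L @ L.T, np.linalg.inv(P))
```

Working with the triangular scale factor avoids ever materializing the inverse of the full precision matrix, which is where the loss of positive definiteness crept in.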
cc @fritzo, @neerajprad , @ssnl