
RuntimeError: torch.det and torch.lu do not support automatic differentiation for outputs with complex dtype #52891

@mzzhang95

Description


🐛 There are two bugs:

Error 1: RuntimeError: torch.det does not support automatic differentiation for outputs with complex dtype.
Error 2: To work around Error 1, I also tried to compute the determinant of a complex matrix via LU decomposition. However, lu.backward() only supports floating-point dtypes. I am sure that my matrix has full rank. These problems only occur when the input matrix is a tensor that requires autograd.

Is there any plan to support complex autograd for these operations in the future? Thanks a lot for your help!
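To make the LU route under Error 2 concrete: if P A = L U with L unit lower triangular, then det(A) = (-1)^s · ∏ U_kk, where s is the number of row swaps encoded in P. The identity holds for complex entries too. Below is a minimal pure-Python sketch of it (no PyTorch involved; `det_via_lu` is a hypothetical helper name, not a PyTorch API). It only illustrates the forward computation, not the missing complex backward.

```python
def det_via_lu(a):
    """Determinant of a square (possibly complex) matrix via LU with partial pivoting.

    With P A = L U and L unit lower triangular, det(A) = (-1)**swaps * prod(diag(U)).
    """
    n = len(a)
    u = [list(map(complex, row)) for row in a]  # working copy, reduced to U in place
    sign = 1
    for k in range(n):
        # Partial pivoting: choose the row with the largest |entry| in column k.
        p = max(range(k, n), key=lambda i: abs(u[i][k]))
        if u[p][k] == 0:
            # Exact-zero check keeps the sketch simple: no usable pivot means det(A) == 0.
            return 0j
        if p != k:
            u[k], u[p] = u[p], u[k]
            sign = -sign  # each row swap flips the sign of the determinant
        for i in range(k + 1, n):
            factor = u[i][k] / u[k][k]  # this multiplier is the entry L[i][k]
            for j in range(k, n):
                u[i][j] -= factor * u[k][j]
    det = complex(sign)
    for k in range(n):
        det *= u[k][k]  # product of U's diagonal
    return det


print(det_via_lu([[1 + 2j, 2], [3, 4 - 1j]]))  # approximately 7j (the exact value is 7j)
```

The all-ones matrix from the Error 2 reproducer has rank 1, so this sketch returns 0j for it: after the first elimination step, the second column has no nonzero pivot left.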

To Reproduce

Steps to reproduce Error 1:

import torch
x = torch.randn(4, 4, requires_grad=True, dtype=torch.cdouble)
torch.det(x)  # raises the RuntimeError below on the nightly build listed under Environment

Error message for Error 1:
RuntimeError: det does not support automatic differentiation for outputs with complex dtype.
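For reference, the derivative a complex backward for det has to produce is given by Jacobi's formula: ∂det(A)/∂A_ij is the (i, j) cofactor C_ij = det(A)·(A⁻¹)_ji. Because det is a polynomial in the matrix entries (and affine in each single entry), the formula holds unchanged for complex A. The pure-Python sketch below checks it with a complex finite difference; `det`, `cofactor` and `finite_difference` are my own illustrative helpers, not PyTorch APIs. Assuming PyTorch's convention that the backward of a holomorphic function multiplies the incoming gradient by the conjugate of the derivative, a complex det backward would return grad_out · conj(C_ij); treat that last step as an assumption.

```python
def det(a):
    """Determinant by Laplace (cofactor) expansion along the first row; fine for tiny matrices."""
    n = len(a)
    if n == 1:
        return a[0][0]
    total = 0j
    for j in range(n):
        minor = [row[:j] + row[j + 1:] for row in a[1:]]
        total += (-1) ** j * a[0][j] * det(minor)
    return total


def cofactor(a, i, j):
    """(i, j) cofactor: signed determinant of A with row i and column j removed."""
    minor = [row[:j] + row[j + 1:] for k, row in enumerate(a) if k != i]
    return (-1) ** (i + j) * det(minor)


def finite_difference(a, i, j, h=1e-3 + 1e-3j):
    """(det(A + h*E_ij) - det(A)) / h for a complex step h."""
    b = [list(row) for row in a]
    b[i][j] += h
    return (det(b) - det(a)) / h


A = [[1 + 2j, 2, 0.5j],
     [3, 4 - 1j, 1],
     [-1j, 2 + 2j, 3]]

for i in range(3):
    for j in range(3):
        # Jacobi's formula: d det(A) / d A_ij equals the (i, j) cofactor.
        # det is affine in A_ij, so the difference quotient is exact up to rounding.
        assert abs(finite_difference(A, i, j) - cofactor(A, i, j)) < 1e-9
```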

Steps to reproduce Error 2:

import torch
A = torch.ones(4, 4, requires_grad=True, dtype=torch.cdouble)
A_LU, pivots = A.lu()  # raises the ValueError below because A is complex and requires grad

Error message for Error 2:
ValueError: lu.backward works only with batches of squared full-rank matrices of floating types.

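Expected behavior

torch.det and the LU route should both support autograd for complex inputs, as they already do for real floating-point inputs.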

Environment

  • PyTorch version: nightly 1.9.0
  • OS: Windows 10
  • How PyTorch was installed: conda
  • Python version: 3.7
  • CUDA/cuDNN version: CUDA 10.2, cuDNN 8
  • GPU models and configuration: GTX 1080Ti

Additional context

cc @ezyang @anjali411 @dylanbespalko @mruberry @jianyuh @nikitaved @pearu @heitorschueroff @walterddr @IvanYashchuk

Labels: complex_autograd, module: complex, module: linear algebra, triaged
