
torch.matmul does not work for complex numbers #46546

@aneeshnema

Description


🐛 Bug

I was trying to implement torch.nn.Linear for complex numbers by reusing the source of torch.nn.Linear and modifying it. When I give my new layer input, it throws a RuntimeError.
On further investigation, it seems that complex matrix multiplication is not supported.

To Reproduce

Short version:

import torch
x = torch.rand(1, 5, dtype=torch.cfloat)
torch.matmul(x, x.T)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-32-ff0d7b177171> in <module>()
      1 import torch
      2 x = torch.rand(1, 5, dtype=torch.cfloat)
----> 3 torch.matmul(x, x.T)

RuntimeError: _th_addmm_out not supported on CPUType for ComplexFloat
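As a workaround until complex matmul is implemented, the product can be assembled from four real matmuls via (A + iB)(C + iD) = (AC - BD) + i(AD + BC). A minimal sketch (complex_matmul is a name I made up, not a PyTorch API; it relies on Tensor.real/Tensor.imag and torch.view_as_complex, which exist in 1.6):

import torch

def complex_matmul(a, b):
    # (A + iB) @ (C + iD) = (AC - BD) + i(AD + BC), using real matmuls only
    real = torch.matmul(a.real, b.real) - torch.matmul(a.imag, b.imag)
    imag = torch.matmul(a.real, b.imag) + torch.matmul(a.imag, b.real)
    # stack the parts into a trailing dim of size 2, then reinterpret as complex
    return torch.view_as_complex(torch.stack([real, imag], dim=-1))

x = torch.rand(1, 5, dtype=torch.cfloat)
complex_matmul(x, x.T)  # succeeds where torch.matmul(x, x.T) raises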

Long version (what I was trying to do):

import torch
import torch.nn.functional as F
import torch.nn as nn
import math

class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features, bias = True):
        super(ComplexLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.complex64))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.complex64))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

layer = ComplexLinear(5, 5)
x = torch.rand(1, 5, dtype=torch.cfloat)
layer(x)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-37-95e898b019a1> in <module>()
     33 layer = ComplexLinear(5, 5)
     34 x = torch.rand(1, 5, dtype=torch.cfloat)
---> 35 layer(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1672     if input.dim() == 2 and bias is not None:
   1673         # fused op is marginally faster
-> 1674         ret = torch.addmm(bias, input, weight.t())
   1675     else:
   1676         output = input.matmul(weight.t())

RuntimeError: _th_addmm not supported on CPUType for ComplexFloat
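The traceback points at the fused addmm path in F.linear; the matmul fallback on line 1676 fails the same way. As a stopgap, forward can avoid both by using the complex_matmul helper sketched above (again an illustration, not an official API; it assumes elementwise complex addition works for the bias, which it does on 1.6):

    def forward(self, input):
        # avoid the unsupported complex addmm/matmul by multiplying
        # real and imaginary parts separately
        out = complex_matmul(input, self.weight.t())
        if self.bias is not None:
            out = out + self.bias
        return out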

Expected behavior

Matrix multiplication should work for complex tensors.
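For comparison, the equivalent computation in NumPy works fine:

import numpy as np

x = np.random.rand(1, 5) + 1j * np.random.rand(1, 5)
print(x @ x.T)  # (1, 1) complex result; NumPy handles complex matmul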

Environment

Google Colab

PyTorch version: 1.6.0+cu101
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0

Python version: 3.6 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: 10.1.243
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.18.5
[pip3] torch==1.6.0+cu101
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.3.1
[pip3] torchvision==0.7.0+cu101
[conda] Could not collect

Additional context

I checked the docs, and they say torch.matmul is supported for complex tensors: https://pytorch.org/docs/stable/complex_numbers.html
There is a similar issue on the PyTorch forums that didn't receive any reply: https://discuss.pytorch.org/t/complex-matrix-multiplication/91865

cc @ezyang @anjali411 @dylanbespalko @mruberry @vishwakftw @jianyuh @nikitaved @pearu @heitorschueroff

Labels

function request: A request for a new function or the addition of new arguments/modes to an existing function.
module: complex: Related to complex number support in PyTorch.
module: linear algebra: Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply (matmul).
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.
