Complex Numbers Support #33152

@anjali411

Description

Previous discussion: #755

High Priority

Operator support

TH to ATen ports that would greatly help unblock some of the work listed here:

Complex support for real-valued torch.nn loss functions
Tasks are listed on #46642, along with a sample PR for torch.nn.L1Loss.
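
Not the actual torch.nn implementation, but a minimal sketch of the idea behind complex support for an L1-style loss: the element-wise complex modulus of the difference is real, so the reduced loss is real-valued and can be backpropagated through complex inputs (the helper name l1_loss_complex is hypothetical):

    import torch

    def l1_loss_complex(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # |input - target| is a real tensor for complex inputs, so mean() gives a real scalar loss
        return (input - target).abs().mean()

    inp = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
    tgt = torch.randn(4, dtype=torch.cdouble)
    l1_loss_complex(inp, tgt).backward()   # inp.grad is a complex tensor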

Complex Autograd supported but untested:

Enhance testing

  • Migrate the op tests (shown here:

    pytorch/test/test_autograd.py, lines 5099 to 5122 in a9e4bb5:

    separate_complex_tests = ['view_as_real', 'real', 'imag', 'div', 'pow', 'rsqrt', '__rdiv__', 'add', 'sub']
    # NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly
    # for non-holomorphic functions
    complex_list_filter = []
    # TODO: Add back 'sgn' to complex_list; removed because of Windows test failure with 11.2
    # See: https://github.com/pytorch/pytorch/issues/51980
    if _get_torch_cuda_version() != (11, 2):
        complex_list_filter.append('sgn')
    # allow list for complex
    complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
                    'repeat', 'expand', 'rot90', 'transpose',
                    'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                    'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
                    'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'mul',
                    '__rmul__', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
                    'bmm', 'mv', 'ger', 'diagonal', 'fill_', 'sub',
                    'mean', 'inverse', 'solve', 'addcmul',
                    'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr',
                    'narrow', 'swapaxes', 'swapdims', 'tensor_split', 'tile',
                    'baddbmm', 'addbmm', 'addmv'] + complex_list_filter + separate_complex_tests

    ) from method_tests to OpInfo based tests. After this migration is complete, remove the method_test generation logic for complex types:

    pytorch/test/test_autograd.py, lines 5131 to 5313 in a9e4bb5:

    def add_test(
            name,
            self_size,
            args,
            variant_name='',
            check_ad=(),  # only used in test_jit
            dim_args_idx=(),
            skipTestIf=(),
            output_process_fn=lambda x: x,
            kwargs=None):
        kwargs = kwargs if kwargs else {}
        basic_test_name = 'test_' + name
        if variant_name != '':
            basic_test_name += '_' + variant_name

        if name in separate_complex_tests and 'complex' in variant_name:
            run_only_complex = True
        else:
            run_only_complex = False

        for dtype in [torch.double, torch.cdouble]:
            for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
                test_name = basic_test_name
                new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
                test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)

                if dtype.is_complex:
                    # TODO: remove this. this is temporary while we ramp up the complex support.
                    if name in complex_list:
                        if name in separate_complex_tests and 'complex' not in variant_name:
                            continue
                        if not run_only_complex:
                            test_name = test_name + '_complex'
                    else:
                        continue
                elif run_only_complex:
                    continue

                new_args = tuple(new_args)

                # for-loop bodies don't define scopes, so we have to save the variables
                # we want to close over in some way
                def do_test(self, device, dtype=dtype, name=name, self_size=self_size, args=new_args, test_name=test_name,
                            output_process_fn=output_process_fn):
                    def check(name):
                        is_magic_method = name[:2] == '__' and name[-2:] == '__'
                        is_inplace = name[-1] == "_" and not is_magic_method
                        self_variable = create_input((self_size,), dtype=dtype, device=device)[0][0]
                        # FixMe: run grad checks on inplace self
                        if is_inplace:
                            self_variable.requires_grad = False
                        # need to record this because methods can change the size (e.g. unsqueeze)
                        args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace,
                                                                      call_kwargs=kwargs, dtype=dtype, device=device)
                        self_tensor = deepcopy(self_variable)
                        args_tensor = deepcopy(unpack_variables(args_variable))

                        if not exclude_tensor_method(name, test_name):
                            output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                            output_tensor = getattr(self_tensor, name)(*args_tensor, **kwargs_variable)
                            if not isinstance(output_tensor, torch.Tensor) and not isinstance(output_tensor, tuple):
                                if dtype.is_complex:
                                    output_tensor = torch.tensor((output_tensor, ), dtype=torch.cfloat, device=device)
                                else:
                                    output_tensor = torch.tensor((output_tensor, ), dtype=torch.float, device=device)
                            self.assertEqual(unpack_variables(output_variable), output_tensor)
                            # TODO: check that both have changed after adding all inplace ops

                            def fn(*inputs):
                                output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
                                return output_process_fn(output)

                            if not is_inplace and name not in EXCLUDE_GRADCHECK:
                                check_batched_grad = test_name not in EXCLUDE_BATCHED_GRAD_TESTS
                                run_grad_and_gradgrad_checks(self, name, test_name, fn,
                                                             output_variable, (self_variable,) + args_variable,
                                                             check_batched_grad=check_batched_grad)

                        # functional interface tests
                        torch_fn = getattr_qualified(torch, name)
                        if torch_fn is not None and name not in EXCLUDE_FUNCTIONAL:
                            def fn(*inputs):
                                output = torch_fn(*inputs, **kwargs)
                                return output_process_fn(output)

                            f_args_variable = (self_variable,) + args_variable
                            f_args_tensor = (self_tensor,) + args_tensor
                            # could run the gradchecks again, but skip since we did it for the methods above.
                            run_gradcheck = exclude_tensor_method(name, test_name) and not is_inplace and name not in EXCLUDE_GRADCHECK
                            run_functional_checks(self, test_name, name, fn,
                                                  run_gradcheck, f_args_variable, f_args_tensor)

                        # check for correct type of input and input.grad
                        if not is_inplace:
                            self_variable = create_input((self_size,), requires_grad=True, dtype=dtype)[0][0]
                            args_variable, kwargs_variable = create_input(args, requires_grad=False, call_kwargs=kwargs, dtype=dtype)
                            if hasattr(self_variable, name):
                                attribute_result = getattr(self_variable, name)
                                if callable(attribute_result):
                                    output_variable = attribute_result(*args_variable, **kwargs_variable)
                                else:
                                    self.assertTrue(len(args_variable) == 0)
                                    self.assertTrue(len(kwargs_variable) == 0)
                                    output_variable = attribute_result
                            else:
                                self_and_args_variable = (self_variable,) + args_variable
                                output_variable = torch_fn(*self_and_args_variable, **kwargs_variable)
                            if isinstance(output_variable, torch.autograd.Variable):
                                if output_variable.is_sparse:
                                    rand = randn_like(output_variable.to_dense()).to_sparse()
                                else:
                                    rand = randn_like(output_variable)
                                output_variable.backward(rand)
                                self.assertTrue(type(self_variable) == type(self_variable.grad))
                                self.assertTrue(self_variable.size() == self_variable.grad.size())

                            # compare grads to inplace grads
                            inplace_name = name + '_'
                            # can't broadcast inplace to left hand side
                            skip_inplace = ('broadcast_lhs' in test_name or
                                            'broadcast_all' in test_name or
                                            'atanh' in test_name or
                                            'acosh' in test_name or
                                            'asinh' in test_name or
                                            'abs_complex' in test_name or
                                            'abs_scalar_complex' in test_name)
                            if hasattr(torch.ones(1), inplace_name) and not skip_inplace:
                                output_variable = getattr(self_variable, name)(*args_variable, **kwargs_variable)
                                if not isinstance(output_variable, tuple):
                                    output_variable = (output_variable,)
                                inplace_self_variable = deepcopy(self_variable)
                                inplace_self_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                                   for i in (inplace_self_variable,))
                                inplace_args_variable = deepcopy(args_variable)
                                inplace_args_variable_copy = tuple(i.clone() if isinstance(i, torch.Tensor) else i
                                                                   for i in inplace_args_variable)

                                inplace_output_variable = (
                                    getattr(inplace_self_variable_copy[0], inplace_name)(*inplace_args_variable_copy,
                                                                                         **kwargs_variable))
                                if not isinstance(inplace_output_variable, tuple):
                                    inplace_output_variable = (inplace_output_variable,)
                                self.assertEqual(inplace_output_variable, output_variable)
                                # Check that gradient is the same
                                for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                                    (self_variable,) + args_variable):
                                    if not isinstance(inp_i, torch.Tensor):
                                        assert not isinstance(i, torch.Tensor)
                                        continue
                                    if inp_i.grad is not None:
                                        with torch.no_grad():
                                            inp_i.grad.zero_()
                                    if i.grad is not None:
                                        with torch.no_grad():
                                            i.grad.zero_()
                                for i_o, o in zip(inplace_output_variable, output_variable):
                                    if dtype.is_complex:
                                        grad = randn_like(i_o).to(torch.cdouble)
                                    else:
                                        grad = randn_like(i_o).double()
                                    i_o.backward(grad)
                                    o.backward(grad)
                                for inp_i, i in zip((inplace_self_variable,) + inplace_args_variable,
                                                    (self_variable,) + args_variable):
                                    if not isinstance(inp_i, torch.Tensor):
                                        continue
                                    self.assertEqual(inp_i.grad, i.grad)

                    check(name)
                    inplace_name = name + '_'
                    # can't broadcast inplace to left hand side
                    broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name
                    # skip C -> R inplace tests
                    skip_c_to_r_inplace = 'abs_complex' in test_name or 'abs_scalar_complex' in test_name
                    skip_inplace = broadcast_skip_inplace or skip_c_to_r_inplace
                    if hasattr(torch.ones(1), inplace_name) and not skip_inplace:
                        check(inplace_name)

                assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
                for skip in skipTestIf:
                    do_test = skip(do_test)

                setattr(TestAutogradDeviceType, test_name, do_test)
  • Add R->C tests for complex functions. (contact @anjali411 if you'd like to work on this task)
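
    For the R->C task above, a minimal sketch of what such a test could exercise, assuming torch.autograd.gradcheck's complex support (real_to_complex is a made-up example function, not a torch op):

    import torch
    from torch.autograd import gradcheck

    # Hypothetical R -> C function: real (double) input, complex (cdouble) output.
    def real_to_complex(x):
        return torch.complex(x, 2 * x)

    x = torch.randn(3, dtype=torch.double, requires_grad=True)
    # gradcheck compares analytical and numerical gradients and, in recent PyTorch
    # versions, supports functions with complex inputs and/or outputs.
    assert gradcheck(real_to_complex, (x,))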

Complex autograd support for the following ops already supported for complex:


Linear Algebra Ops

Other ops:

Spectral Op Migration: #42175

complex views
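
As an illustration of what the complex view ops tracked here do (using only the documented torch.view_as_real / torch.view_as_complex semantics):

    import torch

    z = torch.randn(3, dtype=torch.cfloat)
    r = torch.view_as_real(z)    # real tensor of shape (3, 2) that shares storage with z
    assert torch.view_as_complex(r).eq(z).all()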

Factory Functions:

Other functions:

Complex number support for torch.distributed: #45760

Complex Autograd Guide:

To understand complex derivatives and how to derive the backward formulas for them, check out: https://pytorch.org/docs/master/notes/autograd.html#what-are-complex-derivatives
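
A small illustration of the convention described in those notes: for a real-valued loss L of a complex input z = x + iy, the gradient autograd populates is dL/dx + i*dL/dy, i.e. twice the conjugate Wirtinger derivative, so plain gradient descent on complex parameters works as expected:

    import torch

    z = torch.tensor(3.0 + 4.0j, requires_grad=True)
    L = (z * z.conj()).real    # L = |z|^2 = x^2 + y^2, a real-valued loss
    L.backward()
    print(z.grad)              # tensor(6.+8.j) == 2 * z, since dL/dx + i*dL/dy = 2x + 2iy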

The list of operators for which complex autograd is supported and tested can be found in gen_variable_type.py:

    GRADIENT_IMPLEMENTED_FOR_COMPLEX = {

If an operator has an entry in derivatives.yaml, then enabling complex backward for that operator requires adding it to GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable_type.py.
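
For example, a hypothetical edit (the op name 'foo' and the entries shown are illustrative only; the real set is much longer) just adds the operator's name to that set:

    GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
        't', 'view', 'add', 'mul', 'conj',   # existing entries (small illustrative subset)
        'foo',                               # hypothetical op being enabled for complex backward
    }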

To run a common_methods test for complex, add an entry to complex_list in test_autograd.py:

    complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',

To run a common_methods test only for complex, add an entry to separate_complex_tests:

    separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log',

Autograd tasks:

Other tasks:

Discussions:

  1. Comparison ops for complex numbers

NumPy parity:

c10::complex tracker: #35284 (comment)

cc @ezyang @anjali411 @dylanbespalko @mruberry
