
Complex autograd is supported for torch.Tensor.masked_scatter_ but untested #53608


Description

@anjali411

These skips (added at the time because masked_fill had not yet been ported to CUDA) should be removed now that the port is complete, and dispatch for complex types can be added to the masked_fill kernel on CUDA.
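Once that dispatch is in place, a quick smoke test along these lines should pass (a hedged sketch, not output from an actual run; the printed values follow deterministically from the zeros input and the fill value):

>>> x = torch.zeros(4, dtype=torch.cdouble, device='cuda')
>>> mask = torch.tensor((1, 0, 1, 0), dtype=torch.bool, device='cuda')
>>> x.masked_fill_(mask, 1 + 1j)
tensor([1.+1.j, 0.+0.j, 1.+1.j, 0.+0.j], device='cuda:0', dtype=torch.complex128)

For masked_scatter_ itself, the complex forward and backward already work on CUDA: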

>>> dest = torch.randn(10, dtype=torch.cdouble, device='cuda', requires_grad=True).clone()
>>> src = torch.randn(10, dtype=torch.cdouble, device='cuda', requires_grad=True)
>>> mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=torch.bool, device='cuda')
>>> dest.masked_scatter_(mask, src)
tensor([ 0.4441-0.2455j,  0.2870-0.1043j, -0.6185+0.6146j, -0.6428+0.5322j,
         0.7472+0.1586j, -0.5039+2.1577j,  0.1907+0.7764j,  0.2066+0.6968j,
        -0.2815+0.2045j, -1.3827+1.7941j], device='cuda:0',
        dtype=torch.complex128, grad_fn=<MaskedScatterBackward>)
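Since the gap is that complex autograd for masked_scatter_ is supported but untested, a minimal gradcheck along these lines (reusing the mask defined above, and using the out-of-place masked_scatter so the inputs can be leaf tensors) could serve as the missing test. This is a sketch of what such a test might look like, not the actual test added to the suite:

>>> from torch.autograd import gradcheck
>>> dest = torch.randn(10, dtype=torch.cdouble, device='cuda', requires_grad=True)
>>> src = torch.randn(10, dtype=torch.cdouble, device='cuda', requires_grad=True)
>>> gradcheck(lambda d, s: d.masked_scatter(mask, s), (dest, src))
True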

cc @ezyang @anjali411 @dylanbespalko @mruberry @aocsa

Labels

complex_autograd
module: complex (Related to complex number support in PyTorch)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
