Previous discussion: #755
High Priority
Operator support
- [New operator] torch.linalg.adjoint()
- torch.linalg.eig (both CPU and CUDA, waiting on TH to ATen port)
- Add kernel for vectorization on CPU for torch.complex and torch.polar (see the sketch below)
- torch.arange
- Deprecate the current linspace, logspace automatic dtype inference and get rid of the unnecessary warning
- Conjugate transpose variants of linear algebra operators; e.g., see the issue for torch.triangular_solve
- torch.special
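For reference, a minimal sketch of the documented behavior of torch.complex and torch.polar mentioned above (the vectorization item concerns only their CPU kernels, not the Python API):

```python
import torch

# torch.complex builds a complex tensor from real and imaginary parts;
# torch.polar builds one from magnitude and angle (abs * exp(i * angle)).
real = torch.tensor([1.0, -2.0])
imag = torch.tensor([3.0, 4.0])
z1 = torch.complex(real, imag)          # tensor([1.+3.j, -2.+4.j])

magnitude = torch.tensor([1.0, 2.0])
angle = torch.tensor([0.0, 3.14159265 / 2])
z2 = torch.polar(magnitude, angle)      # ~tensor([1.+0.j, 0.+2.j])
```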
TH to ATen ports that would greatly help unblock some of the work listed here:
- mode: Migrate mode from the TH to Aten (CPU) #24731, Migrate _mode from the TH to Aten (CUDA) #24526
- renorm: Migrate renorm and renorm_ from the TH to Aten (CUDA) #24616
- put_: Migrate put_ from the TH to Aten (CPU) #24751, Migrate put_ from the TH to Aten (CUDA) #24614
- take: Migrate take from the TH to Aten (CUDA) #24640, Migrate take from the TH to Aten (CPU) #24772
- nonzero (CPU): Migrate nonzero from the TH to Aten (CPU) #24745
- index_fill_: Migrate index_fill_ from the TH to Aten (CUDA) #24577, Migrate index_fill_ from the TH to Aten (CPU) #24714
- index_copy: Migrate _index_copy_ from the TH to Aten (CUDA) #24523, Migrate _index_copy_ from the TH to Aten (CPU) #24670
- masked_scatter
- masked_fill: Migrate masked_fill from TH to ATen (CUDA) #49543
Complex support for real-valued torch.nn loss functions
Tasks listed on #46642 along with a sample PR for torch.nn.L1Loss.
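Not the sample PR's actual code, but a minimal sketch of what a complex-capable L1-style loss can look like: the complex absolute difference is real-valued, so the usual backward applies.

```python
import torch

# Hedged sketch (not the PR's implementation): an L1-style loss for complex
# inputs can reduce the elementwise complex absolute difference, which is a
# real tensor and therefore has a well-defined gradient.
def l1_loss_complex(input, target):
    return (input - target).abs().mean()

pred = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
target = torch.randn(4, dtype=torch.cdouble)
loss = l1_loss_complex(pred, target)   # real scalar
loss.backward()                        # gradient flows back to the complex input
```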
Complex Autograd is supported but untested for the following ops:
- torch.dist
- torch.masked_fill
- torch.Tensor.masked_scatter_
- torch.Tensor.scatter_
Enhance testing
- Migrate the op tests from method_tests to OpInfo-based tests (the current complex allow lists, separate_complex_tests and complex_list, are shown at lines 5099 to 5122 in a9e4bb5). After this migration is complete, remove the method_test generation logic for complex types in add_test (lines 5131 to 5313 in a9e4bb5).
- Add R -> C tests for complex functions. (contact @anjali411 if you'd like to work on this task; a gradcheck sketch follows this list)
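As a starting point for the R -> C testing task, a small sketch using torch.autograd.gradcheck on torch.complex, which maps two real inputs to a complex output (this assumes the complex gradcheck support from #43208):

```python
import torch
from torch.autograd import gradcheck

# Sketch of an R -> C check: torch.complex takes two double tensors and
# returns a cdouble tensor. gradcheck compares analytical and numerical
# Jacobians for the function under test.
real = torch.randn(3, dtype=torch.double, requires_grad=True)
imag = torch.randn(3, dtype=torch.double, requires_grad=True)

assert gradcheck(torch.complex, (real, imag))
```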
Complex Autograd support for the following ops (forward already supported for complex):
- torch.sigmoid
- torch.eig
- torch.lu
- torch.cross (also add tests for function correctness)
- torch.cumsum
- torch.cumprod
- torch.linalg.det
- torch.Tensor.index
- torch.lerp
- torch.prod
- torch.Tensor.put_ (waiting on TH to ATen port)
- rsub
- torch.symeig
- index_put
- torch.unfold
- torch.diag
- torch.Tensor.masked_fill_
- torch.trace
- torch.polar
Linear Algebra Ops
- torch.logdet
- torch.trapz
- torch.pca_lowrank
- torch.svd_lowrank
- torch.ormqr
- torch.orgqr
- torch.matrix_power
- torch.cholesky_inverse
- torch.chain_matmul
- torch.ger / torch.outer
- torch.solve [CUDA]
- torch.bmm [CUDA, waiting on TH to ATen port]
- torch.lu_solve
- torch.cholesky_solve [CUDA]
- torch.baddbmm
Other ops:
- torch.isnan and torch.isfinite
- torch.linspace
- torch.logspace
- torch.trace
- torch.mean
- torch.dot [CPU and CUDA]
- torch.vdot
- torch.bmm [CPU]
- torch.symeig [CPU and CUDA]
- torch.addmv (added here and here)
- torch.cholesky [CUDA]
- torch.svd [CUDA]
- torch.qr [CUDA]
- torch.lu, torch.lu_unpack [CUDA]
- torch.det
- torch.inverse [CUDA]
- torch.triangular_solve [CUDA]
Spectral Op Migration: #42175
complex views
[Factory Functions]
- torch.complex, torch.polar (#35312)
- torch.rand
- torch.randn
- torch.from_numpy(complex_array), blocked on: Migrate set_ from the TH to Aten (CPU) #24759, Migrate set_ from the TH to Aten (CUDA) #24623
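A brief sketch of the factory functions listed above with complex dtypes (assumed behavior once these items land; torch.from_numpy is the one noted as blocked on the set_ ports):

```python
import torch
import numpy as np

# Sketch of the tracked factory functions with complex dtypes.
z1 = torch.complex(torch.rand(2), torch.rand(2))    # cfloat from two real tensors
z2 = torch.randn(2, dtype=torch.cfloat)             # complex normal samples
z3 = torch.rand(2, dtype=torch.cfloat)              # uniform real and imaginary parts
z4 = torch.from_numpy(np.array([1 + 2j, 3 - 4j]))   # complex128, shares memory with the ndarray
```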
Other functions:
- torch.sgn
- torch.index_select [CUDA]
- [Disabled] torch.remainder
- [Disabled] torch.clamp
Complex Number Support for torch.nn.distributed: #45760
Complex Autograd Guide:
To understand and obtain the formula for complex derivatives, check out: https://pytorch.org/docs/master/notes/autograd.html#what-are-complex-derivatives
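A small worked example of the convention described in the linked note (hedged sketch): for a real-valued loss L(z), backward() fills z.grad with the gradient obtained by treating the real and imaginary parts as independent real parameters, i.e. dL/dx + i*dL/dy, which is proportional to the conjugate Wirtinger derivative dL/dz*.

```python
import torch

# Worked example: L = sum(|z|^2) = sum(x**2 + y**2).
# dL/dx = 2x and dL/dy = 2y, so the expected gradient is 2x + 2iy = 2 * z,
# i.e. twice the conjugate Wirtinger derivative dL/dz* = z.
z = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j], requires_grad=True)
loss = z.abs().pow(2).sum()   # real-valued loss
loss.backward()
print(z.grad)                 # expected: tensor([2.+4.j, 6.-2.j]) == 2 * z
```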
The list of operators for which complex autograd is supported and tested can be found here:
pytorch/tools/autograd/gen_variable_type.py, line 154 in 72bc3d9: GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
If an operator already has a backward formula defined in derivatives.yaml, then to enable complex backward for that operator, it would have to be added to GRADIENT_IMPLEMENTED_FOR_COMPLEX in gen_variable_type.py.
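For illustration, a hedged, abridged sketch of that step for a hypothetical operator 'foo' (the real set in tools/autograd/gen_variable_type.py is much longer):

```python
# Abridged sketch of the allowlist in tools/autograd/gen_variable_type.py.
# 'foo' stands in for a hypothetical operator whose derivatives.yaml formula
# has been verified to be correct for complex inputs.
GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
    't', 'view', 'reshape', 'add', 'mul', 'conj',   # existing entries (abridged)
    'foo',                                          # newly enabled operator
}
```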
To run a common_methods test for complex, add an entry to complex_list (line 4927 in b61671c):
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone', ...]
To run a common_methods test only for complex, add an entry to separate_complex_tests (line 4920 in b61671c):
separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log', ...]
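Putting the two lists together, a hedged, abridged sketch of opting a hypothetical op 'foo' into these tests (the real lists are at the permalinks above):

```python
# Abridged sketch of the two allow lists used by the method_tests machinery.
# Entries in separate_complex_tests run only for complex dtypes; entries in
# complex_list run for both torch.double and torch.cdouble.
separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos', 'div', 'log']
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
                'foo'] + separate_complex_tests   # 'foo' is the hypothetical new entry
```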
Autograd tasks:
- Gradcheck logic for C -> C, C -> R, R -> C functions. (Complex gradcheck logic #43208)
- Disable complex autograd for operators not tested for complex. (Add allowlist for complex backward #45461)
Other tasks:
- type promotion (e.g. Float/Long/... tensor + a complex number = complex tensor) (Multiplying a complex scalar to a non-complex tensor outputs tensor with only real values #33780); see the sketch after this list
- Python API for ComplexFloat and ComplexDouble Storage
- Disable complex sign (see torch.sign doesn't work for complex tensors #36323)
- Disable comparison ops for complex (clamp, sort, min, max, ...) (see Min and Max with complex inputs exhibit behavior incompatible with NumPy #36374, Inconsistency between torch.clamp() and numpy.clip() behavior for complex numbers #33568, and Comparison ops for Complex Tensors #36444)
- Implement isclose for complex (see torch.isclose for complex tensors #36462, Implement Complex Support that allows Complex Tests to run in test_torch.py #36029, torch.isclose() undocumented #35471, and #36456)
- Complex Printing for scientific notation #38285, #33494
- testing in test_torch.py: Implement Complex Support that allows Complex Tests to run in test_torch.py #36029
- conj() is not supported on real cuda tensors
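A short sketch of the type-promotion behavior requested in the first item of this list (expected output is hypothetical until promotion is in place; see #33780 for the original report):

```python
import torch

# Desired behavior: combining a real tensor with a complex Python scalar
# should promote the result to a complex dtype instead of silently dropping
# the imaginary part (the bug reported in #33780).
t = torch.tensor([1.0, 2.0])   # float32
out = t * (2 + 3j)             # expected dtype: torch.complex64
print(out, out.dtype)
```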
Discussions:
NumPy parity:
- Inconsistency between torch.abs and np.abs #33567
- torch.abs(complex) is divergent from NumPy on vectorized NaN values
c10::complex tracker: #35284 (comment)