Description
Tracking issue for tasks related to the torch.fft namespace, analogous to NumPy's numpy.fft namespace and SciPy's scipy.fft namespace.
PyTorch already has fft functions (fft, ifft, rfft, irfft, stft, istft), but they're inconsistent with NumPy and don't accept complex tensor inputs (for example, torch.fft() takes a signal_ndim argument and represents complex values as real tensors whose last dimension has size 2). The torch.fft namespace should be consistent with NumPy and SciPy where possible, plus provide a path towards removing PyTorch's existing fft functions in the 1.8 release (deprecating them in 1.7).
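A minimal sketch of the NumPy-consistency goal, assuming a PyTorch release where the torch.fft module has landed (1.8 or later, where it is imported by default) and NumPy installed. It shows torch.fft.fft taking a complex tensor directly and, with the same default dimension and normalization, matching numpy.fft.fft.

```python
import numpy as np
import torch
import torch.fft  # required on 1.7; harmless on 1.8+, where it is imported by default

# A complex input, which the legacy torch.fft() function could not take directly
x = torch.complex(torch.randn(8, dtype=torch.float64),
                  torch.randn(8, dtype=torch.float64))

# Same defaults as NumPy: transform over the last dimension, norm="backward"
y_torch = torch.fft.fft(x)
y_numpy = np.fft.fft(x.numpy())
print(np.allclose(y_torch.numpy(), y_numpy))  # True

# The norm argument uses NumPy's mode names, e.g. "ortho"
y_ortho = torch.fft.fft(x, norm="ortho")
print(np.allclose(y_ortho.numpy(), np.fft.fft(x.numpy(), norm="ortho")))  # True

# ifft inverts fft
x_back = torch.fft.ifft(y_torch)
print(torch.allclose(x_back, x))  # True
```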
While adding the torch.fft namespace infrastructure and deprecating PyTorch's current fft-related functions are the top priorities, PyTorch is also missing many helpful functions, listed below, which should (eventually) be added to the new namespace, too.
Tasks:
- Write doc preamble to fft module
- Write blogpost about the torch.fft module
- Write tutorial
- Create forum group
Completed:
- Fix erroneous stft() warning (see comment below) (fixed in Use new FFT operators in stft #47601)
- Remove torch.fft() (and torch.Tensor.fft() and torch.rfft() and torch.Tensor.rfft()) and import the torch.fft module by default (Remove deprecated spectral ops from torch namespace #48594)
- Review fft architecture to ensure (1) the functions used are in ATen and (2) the functions actually support complex values (so complex-specific kernels can be used) (new FFT operators in Improve torch.fft n-dimensional transforms #46911 use complex throughout)
- Rewrite _fft_with_size to handle transforming arbitrary dimensions without transposing/cloning (Improve torch.fft n-dimensional transforms #46911 still requires cloning, but this is required for best performance)
- Review operator performance
- Related issues: CUDA irfft may be doing unnecessary cloning of input #38413
- Add OpInfo based tests
- Implement torch function overrides for the fft module
- Add support for out arguments
- Implement torch.fft.fft2()
- Implement torch.fft.ifft2()
- Implement torch.fft.rfft2()
- Implement torch.fft.irfft2()
- Investigate gradgrad and ensure gradgradcheck passes
- stft crash (stft does not consistently check window device #30865)
- Implement torch.fft.fft()
- Implement torch.fft.ifft()
- Implement torch.fft.rfft()
- Implement torch.fft.irfft()
- Update torch.stft() and torch.istft() to handle complex tensors (like librosa does) and document compatibility with librosa (not SciPy)
- Create test_spectral_ops.py (Creates spectral ops test suite #42157)
- Create torch.fft namespace (Adds fft namespace #41911)
- Implement torch.fft.fftn()
- Implement torch.fft.ifftn()
- Implement torch.fft.fftfreq()
- Implement torch.fft.rfftfreq()
- Implement torch.fft.fftshift()
- Implement torch.fft.ifftshift()
- Deprecate torch.fft()
- Deprecate torch.ifft()
- Deprecate torch.rfft()
- Deprecate torch.irfft()
- Deprecate torch.stft() and torch.istft() returning non-complex tensors mimicking complex tensors
- PyTorch istft runs slower than torchaudio istft, especially at higher n_fft (#42213)
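Several of the completed items can be exercised together in a short sketch. It assumes PyTorch 1.8 or later (for torch.fft.fftfreq, torch.fft.fftshift, and torch.stft(..., return_complex=True)) plus NumPy. The signal sizes, hop length, and window are arbitrary choices for illustration, not values taken from this issue.

```python
import numpy as np
import torch

torch.manual_seed(0)

# rfft/irfft: a real signal of length n keeps n // 2 + 1 frequency bins
signal = torch.randn(16, dtype=torch.float64)
half_spectrum = torch.fft.rfft(signal)
print(half_spectrum.shape)  # torch.Size([9])

# Pass n so irfft restores the original length
recovered = torch.fft.irfft(half_spectrum, n=signal.numel())
print(torch.allclose(recovered, signal))  # True

# fft2 transforms the last two dimensions, like numpy.fft.fft2
image = torch.randn(4, 6, dtype=torch.float64)
print(np.allclose(torch.fft.fft2(image).numpy(), np.fft.fft2(image.numpy())))  # True

# fftn over chosen dims; ifftn over the same dims undoes it
volume = torch.randn(2, 3, 5, dtype=torch.float64)
vol_spec = torch.fft.fftn(volume, dim=(0, 2))
print(torch.allclose(torch.fft.ifftn(vol_spec, dim=(0, 2)).real, volume))  # True

# fftfreq gives bin frequencies for sample spacing d;
# fftshift moves the zero-frequency bin to the middle
freqs = torch.fft.fftfreq(8, d=0.5)
centered = torch.fft.fftshift(freqs)

# stft returns a genuinely complex tensor when return_complex=True
audio = torch.randn(1024)
window = torch.hann_window(256)
spec = torch.stft(audio, n_fft=256, hop_length=64, window=window, return_complex=True)
print(spec.is_complex(), spec.shape)  # True torch.Size([129, 17])

# istft inverts it; length= trims the reconstruction to the input size
audio_back = torch.istft(spec, n_fft=256, hop_length=64, window=window,
                         length=audio.numel())
```

Passing n to irfft matters because a 9-bin half spectrum could come from a length-16 or a length-17 signal; without n, irfft assumes the even length 2 * (bins - 1).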
cc @ezyang @gchanan @zou3519 @anjali411 @dylanbespalko @mruberry @rgommers @peterbell10