1919
2020namespace at { namespace native {
2121
22- // Common code for all FFT functions
23- static inline Tensor _fft (
24- const Tensor &self, int64_t signal_ndim, bool complex_input,
25- const bool complex_output, bool inverse, IntArrayRef signal_sizes,
26- fft_norm_mode normalization, bool onesided);
27-
2822namespace {
2923
3024// Promote inputs to FFT functions
@@ -416,139 +410,6 @@ Tensor fft_ifftshift(const Tensor& x, c10::optional<IntArrayRef> dim_opt) {
416410}
417411
418412
419- // This is a pass-through wrapper function that does the size check and
420- // inferences. The actual forward implementation function is called
421- // at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU).
422- static inline Tensor _fft (const Tensor &self, const int64_t signal_ndim,
423- const bool complex_input, const bool complex_output,
424- const bool inverse, IntArrayRef signal_sizes,
425- const fft_norm_mode normalization, const bool onesided) {
426-
427- TORCH_CHECK (signal_ndim >= 1 && signal_ndim <= 3 ,
428- " Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=" ,
429- signal_ndim);
430- TORCH_CHECK (at::isFloatingType (self.scalar_type ()),
431- " Expected an input tensor of floating types, but got input=" ,
432- self.toString (), self.sizes ());
433-
434- auto signal_tensor_ndim = signal_ndim + static_cast <int64_t >(complex_input); // add complex dim
435- if (self.dim () < signal_tensor_ndim) {
436- std::ostringstream ss;
437- ss << " Given signal_ndim=" << signal_ndim << " , expected an input tensor "
438- << " of at least " << signal_tensor_ndim << " D" ;
439- if (complex_input) {
440- ss << " (complex input adds an extra dimension)" ;
441- }
442- ss << " , but got input=" << self.toString () << self.sizes ();
443- AT_ERROR (ss.str ());
444- }
445-
446- auto self_shape = self.sizes ();
447- auto batch_ndim = self.dim () - signal_tensor_ndim;
448-
449- Tensor input = self;
450- // flatten the batch dims
451- if (batch_ndim == 0 ) {
452- // slightly faster path for non-batch mode
453- input = input.unsqueeze (0 );
454- } else if (batch_ndim > 1 ) {
455- std::vector<int64_t > flatten_input_shape (signal_tensor_ndim + 1 );
456- std::copy (self_shape.begin () + batch_ndim, self_shape.end (), flatten_input_shape.begin () + 1 );
457- flatten_input_shape[0 ] = -1 ;
458- input = input.reshape (flatten_input_shape);
459-
460- }
461-
462- // now we assume that input is batched as [ B x signal_dims... ]
463-
464- if (complex_input) {
465- TORCH_CHECK (input.size (signal_ndim + 1 ) == 2 ,
466- " Expected an input tensor with a last dimension of size 2 "
467- " representing real + imaginary components, but got input " ,
468- self.toString (), self.sizes ());
469- }
470-
471- // build signal_sizes and output_size
472- TORCH_CHECK (signal_sizes.size () == 0 || static_cast <int64_t >(signal_sizes.size ()) == signal_ndim,
473- " Expected signal_sizes to be empty (default) or of signal_ndim=" ,
474- signal_ndim, " D, but got signal_sizes=" , signal_sizes);
475- std::vector<int64_t > output_sizes (signal_ndim + 1 + static_cast <int64_t >(complex_output));
476- output_sizes[0 ] = input.size (0 ); // batch size
477- std::vector<int64_t > checked_signal_sizes (signal_ndim);
478- for (int64_t i = 0 ; i < signal_ndim; i++) {
479- int64_t input_size = input.size (i + 1 );
480- if (i == signal_ndim - 1 && onesided && complex_input && !complex_output) {
481- // If last dim and complex-to-real onesided, input is only half of
482- // signal, and we need to infer basing on signal_sizes, if given
483- // See native/SpectralOpsUtils.h for detailed description.
484- int64_t inferred_size;
485- if (signal_sizes.size () > 0 ) {
486- inferred_size = infer_ft_complex_to_real_onesided_size (input_size, signal_sizes[i]);
487- } else {
488- inferred_size = infer_ft_complex_to_real_onesided_size (input_size);
489- }
490- checked_signal_sizes[i] = inferred_size;
491- output_sizes[i + 1 ] = inferred_size;
492- } else {
493- if (i == signal_ndim - 1 && onesided && !complex_input && complex_output) {
494- // if last dim and real-to-complex onesided, output should be only
495- // half of the signal, and we need to infer using input_size
496- output_sizes[i + 1 ] = infer_ft_real_to_complex_onesided_size (input_size);
497- } else {
498- output_sizes[i + 1 ] = input_size;
499- }
500- checked_signal_sizes[i] = input_size;
501- TORCH_CHECK (signal_sizes.size () == 0 || signal_sizes[i] == checked_signal_sizes[i],
502- " Expected given signal_sizes=" , signal_sizes," to have same "
503- " shape with input at signal dimension " , i, " , but got "
504- " signal_sizes=" , signal_sizes, " and input=" , self.toString (),
505- self.sizes ());
506- }
507- }
508- if (complex_output) {
509- output_sizes[signal_ndim + 1 ] = 2 ;
510- }
511-
512- Tensor output = at::_fft_with_size (input, signal_ndim, complex_input,
513- complex_output, inverse,
514- checked_signal_sizes,
515- static_cast <int64_t >(normalization),
516- onesided,
517- output_sizes);
518-
519- // unflatten the batch dims
520- if (batch_ndim == 0 ) {
521- // slightly faster path for non-batch mode
522- output = output.squeeze (0 );
523- } else if (batch_ndim > 1 ) {
524- auto output_ndim = self.dim () + static_cast <int64_t >(complex_output) - static_cast <int64_t >(complex_input);
525- std::vector<int64_t > unflatten_output_shape (output_ndim);
526- std::copy (self_shape.begin (), self_shape.begin () + batch_ndim, unflatten_output_shape.begin ());
527- std::copy (output_sizes.begin () + 1 , output_sizes.end (), unflatten_output_shape.begin () + batch_ndim);
528- output = output.reshape (unflatten_output_shape);
529- }
530- return output;
531- }
532-
533- // Wrapper to preserve the historic signature of _fft_with_size
534- // NOTE: This is only used for torchscript backwards compatibility and the new
535- // signature with normalization modes should be used in all other cases
536- Tensor _fft_with_size (const Tensor& input, int64_t signal_ndim,
537- bool complex_input, bool complex_output,
538- bool inverse, IntArrayRef checked_signal_sizes,
539- bool normalized, bool onesided,
540- IntArrayRef output_sizes) {
541- fft_norm_mode norm;
542- if (normalized) {
543- norm = fft_norm_mode::by_root_n;
544- } else {
545- norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none;
546- }
547- return at::_fft_with_size (
548- input, signal_ndim, complex_input, complex_output, inverse,
549- checked_signal_sizes, static_cast <int64_t >(norm), onesided, output_sizes);
550- }
551-
552413// We call the following methods via CUDA hooks because they are really only
553414// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
554415int64_t _cufft_get_plan_cache_max_size (int64_t device_index) {
@@ -567,36 +428,6 @@ void _cufft_clear_plan_cache(int64_t device_index) {
567428 detail::getCUDAHooks ().cuFFTClearPlanCache (device_index);
568429}
569430
570- static Tensor fft (const Tensor& self, const int64_t signal_ndim, const bool normalized) {
571- return _fft (self, signal_ndim, /* complex_input */ true ,
572- /* complex_output */ true , /* inverse */ false , {},
573- normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
574- /* onesided */ false );
575- }
576-
577- static Tensor ifft (const Tensor& self, const int64_t signal_ndim, const bool normalized) {
578- return _fft (self, signal_ndim, /* complex_input */ true ,
579- /* complex_output */ true , /* inverse */ true , {},
580- normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
581- /* onesided */ false );
582- }
583-
584- static Tensor rfft (const Tensor& self, const int64_t signal_ndim, const bool normalized,
585- const bool onesided) {
586- return _fft (self, signal_ndim, /* complex_input */ false ,
587- /* complex_output */ true , /* inverse */ false , {},
588- normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
589- onesided);
590- }
591-
592- static Tensor irfft (const Tensor& self, const int64_t signal_ndim, const bool normalized,
593- const bool onesided, IntArrayRef signal_sizes) {
594- return _fft (self, signal_ndim, /* complex_input */ true ,
595- /* complex_output */ false , /* inverse */ true , signal_sizes,
596- normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
597- onesided);
598- }
599-
600431template <typename Stream, typename T>
601432static Stream& write_opt (Stream& SS, const optional<T>& value) {
602433 if (value) {
0 commit comments