@@ -119,25 +119,12 @@ Tensor fft_c2r(Tensor input, c10::optional<int64_t> n_opt,
119119 if (n_opt) {
120120 input = resize_fft_input (input, dim, n/2 + 1 );
121121 }
122- // _fft only operates on the last dim, so transpose the selected dim to the end
123- const bool must_transpose = (dim != input_dim - 1 );
124- if (must_transpose) {
125- input = at::transpose (input, -1 , dim);
126- }
127122 const auto norm = norm_from_string (norm_str, forward);
128123 if (forward) {
129124 // FIXME: _fft does not support complex_output=false with inverse=false
130125 input = at::conj (input);
131126 }
132- auto out = _fft (at::view_as_real (input),
133- /* signal_ndim=*/ 1 , /* complex_input=*/ true ,
134- /* complex_output=*/ false , /* inverse=*/ true ,
135- /* signal_sizes=*/ {n}, /* normalization=*/ norm,
136- /* onesided=*/ true );
137- if (must_transpose) {
138- out = at::transpose (out, -1 , dim);
139- }
140- return out;
127+ return at::_fft_c2r (input, dim, static_cast <int64_t >(norm), n);
141128}
142129
143130// Real to complex FFT
@@ -153,22 +140,11 @@ Tensor fft_r2c(Tensor input, c10::optional<int64_t> n_opt,
153140 if (n_opt) {
154141 input = resize_fft_input (input, dim, n);
155142 }
156- // _fft only operates on the last dim, so transpose the selected dim to the end
157- const bool must_transpose = (dim != input_dim - 1 );
158- if (must_transpose) {
159- input = at::transpose (input, -1 , dim);
160- }
143+
161144 const auto norm = norm_from_string (norm_str, forward);
162- auto out = _fft (input, /* signal_ndim=*/ 1 , /* complex_input=*/ false ,
163- /* complex_output=*/ true , /* inverse=*/ false ,
164- /* signal_sizes=*/ {n}, /* normalization=*/ norm,
165- /* onesided=*/ onesided);
166- out = at::view_as_complex (out);
167- if (must_transpose) {
168- out = at::transpose (out, -1 , dim);
169- }
145+ auto out = at::_fft_r2c (input, dim, static_cast <int64_t >(norm), onesided);
170146 if (!forward) {
171- // FIXME: _fft does not support complex_input=false with inverse=true
147+ // FIXME: _fft_r2c doesn't support native r2c IFFT
172148 out = at::conj (out);
173149 }
174150 return out;
@@ -186,22 +162,8 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
186162 if (n_opt) {
187163 input = resize_fft_input (input, dim, n);
188164 }
189- // _fft only operates on the last dim, so transpose the selected dim to the end
190- const bool must_transpose = (dim != input_dim - 1 );
191- if (must_transpose) {
192- input = at::transpose (input, -1 , dim);
193- }
194165 const auto norm = norm_from_string (norm_str, forward);
195- auto out = _fft (at::view_as_real (input),
196- /* signal_ndim=*/ 1 , /* complex_input=*/ true ,
197- /* complex_output=*/ true , /* inverse=*/ !forward,
198- /* signal_sizes=*/ {}, /* normalization=*/ norm,
199- /* onesided=*/ false );
200- out = at::view_as_complex (out);
201- if (must_transpose) {
202- out = at::transpose (out, -1 , dim);
203- }
204- return out;
166+ return at::_fft_c2c (input, dim, static_cast <int64_t >(norm), forward);
205167}
206168
207169// Dimensions to transform, and the signal shape in those dimensions
@@ -277,44 +239,12 @@ Tensor fftn_c2c(
277239 const Tensor& input, IntArrayRef shape, IntArrayRef dim,
278240 c10::optional<std::string> norm_str, bool forward) {
279241 TORCH_CHECK (input.is_complex (), " Expected a complex input tensor to FFT" );
280- const auto input_dim = input.dim ();
281-
282242 Tensor x = resize_fft_input (input, dim, shape);
283- x = at::view_as_real (x);
284-
285- const int64_t transform_ndim = dim.size ();
286243 const auto norm = norm_from_string (norm_str, forward);
287- // _fft_with_size only supports 3 dimensions being transformed at a time.
288- // This limit is inherited from cuFFT.
289- constexpr int64_t max_signal_ndim = 3 ;
290-
291- // Transform n dimensions, up to 3 at a time
292- // TODO: rewrite _fft_with_size to transform more than 3 dimensions at once.
293- for (int64_t i = 0 ; i < transform_ndim; i += max_signal_ndim) {
294- const int64_t signal_ndim = std::min (transform_ndim - i, max_signal_ndim);
295- DimVector source_dim (signal_ndim);
296- DimVector dest_dim (signal_ndim);
297-
298- for (int64_t j = 0 ; j < signal_ndim; ++j) {
299- source_dim[j] = dim[i + j];
300- dest_dim[j] = j + (input_dim - signal_ndim);
301- }
302-
303- // _fft operates on up-to the last 3 dims, so move selected dims to the end
304- x = at::movedim (x, source_dim, dest_dim);
305-
306- x = _fft (x, signal_ndim, /* complex_input=*/ true , /* complex_output=*/ true ,
307- /* inverse=*/ !forward, /* signal_sizes=*/ {}, /* normalization=*/ norm,
308- /* onesided=*/ false );
309-
310- // Move transform dims back to their original order
311- x = at::movedim (x, dest_dim, source_dim);
312- }
313-
314- return at::view_as_complex (x);
244+ return at::_fft_c2c (x, dim, static_cast <int64_t >(norm), forward);
315245}
316246
317- }
247+ } // namespace (anonymous)
318248
319249// torch.fft.fft, analogous to NumPy's numpy.fft.fft
320250Tensor fft_fft (const Tensor& self, c10::optional<int64_t > n, int64_t dim,
@@ -370,44 +300,36 @@ Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,
370300
371301Tensor fft_rfftn (const Tensor& self, c10::optional<IntArrayRef> s,
372302 c10::optional<IntArrayRef> dim,
373- c10::optional<std::string> norm) {
303+ c10::optional<std::string> norm_str) {
304+ TORCH_CHECK (!self.is_complex (), " rfftn expects a real-valued input tensor, but got " , self.scalar_type ());
374305 auto desc = canonicalize_fft_shape_and_dim_args (self, s, dim);
375306 TORCH_CHECK (desc.shape .size () > 0 , " rfftn must transform at least one axis" );
376-
377- const auto last_dim = desc.dim .back ();
378- const auto last_shape = desc.shape .back ();
379- desc.shape .pop_back ();
380- desc.dim .pop_back ();
381-
382- // rfft on last dim to get hermitian complex shape
383- auto x = native::fft_rfft (self, last_shape, last_dim, norm);
384- // Normal fft on remaining dims
385- return fftn_c2c (x, desc.shape , desc.dim , norm, /* forward=*/ true );
307+ Tensor input = promote_tensor_fft (self, /* require_complex=*/ false );
308+ Tensor x = resize_fft_input (input, desc.dim , desc.shape );
309+ const auto norm = norm_from_string (norm_str, /* forward=*/ true );
310+ return at::_fft_r2c (x, desc.dim , static_cast <int64_t >(norm), /* onesided=*/ true );
386311}
387312
388313Tensor fft_irfftn (const Tensor& self, c10::optional<IntArrayRef> s,
389314 c10::optional<IntArrayRef> dim,
390- c10::optional<std::string> norm ) {
315+ c10::optional<std::string> norm_str ) {
391316 auto desc = canonicalize_fft_shape_and_dim_args (self, s, dim);
392317 TORCH_CHECK (desc.shape .size () > 0 , " irfftn must transform at least one axis" );
393318
394- const auto last_dim = desc.dim .back ();
395- const auto last_shape = [&]() -> c10::optional<int64_t > {
396- // If shape is defaulted in the last dimension,
397- // pass nullopt to irfft and let it calculate the default size
319+ const auto last_dim_size = [&] {
320+ // Fixup default shape handling in the last dimension,
398321 if (!s.has_value () || (s->back () == -1 )) {
399- return c10::nullopt ;
322+ const auto last_dim = desc.dim .back ();
323+ return 2 * (self.sizes ()[last_dim] - 1 );
400324 }
401325 return desc.shape .back ();
402326 }();
403- desc.shape .pop_back ();
404- desc.dim .pop_back ();
405-
406- // Normal ifft for all but last dim
407- Tensor x = promote_tensor_fft (self, /* require_complex=*/ true );
408- x = fftn_c2c (x, desc.shape , desc.dim , norm, /* forward=*/ false );
409- // Then 1d irfft on last dim to get real output
410- return native::fft_irfft (x, last_shape, last_dim, norm);
327+ desc.shape .back () = last_dim_size / 2 + 1 ;
328+
329+ Tensor input = promote_tensor_fft (self, /* require_complex=*/ true );
330+ Tensor x = resize_fft_input (input, desc.dim , desc.shape );
331+ const auto norm = norm_from_string (norm_str, /* forward=*/ false );
332+ return at::_fft_c2r (x, desc.dim , static_cast <int64_t >(norm), last_dim_size);
411333}
412334
413335Tensor fft_fft2 (const Tensor& self, c10::optional<IntArrayRef> s,
0 commit comments