Skip to content

Commit fc0a3a1

Browse files
peterbell10 authored and facebook-github-bot committed
Improve torch.fft n-dimensional transforms (#46911)
Summary: Pull Request resolved: #46911
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D25420647
Pulled By: mruberry
fbshipit-source-id: bf7e6a2ec41f9f95ffb05c128ee0f3297e34aae2
1 parent f5e9ffb commit fc0a3a1

File tree

9 files changed

+535
-106
lines changed

9 files changed

+535
-106
lines changed

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 24 additions & 102 deletions
Original file line number | Diff line number | Diff line change
@@ -119,25 +119,12 @@ Tensor fft_c2r(Tensor input, c10::optional<int64_t> n_opt,
119119
if (n_opt) {
120120
input = resize_fft_input(input, dim, n/2 + 1);
121121
}
122-
// _fft only operates on the last dim, so transpose the selected dim to the end
123-
const bool must_transpose = (dim != input_dim - 1);
124-
if (must_transpose) {
125-
input = at::transpose(input, -1, dim);
126-
}
127122
const auto norm = norm_from_string(norm_str, forward);
128123
if (forward) {
129124
// FIXME: _fft does not support complex_output=false with inverse=false
130125
input = at::conj(input);
131126
}
132-
auto out = _fft(at::view_as_real(input),
133-
/*signal_ndim=*/1, /*complex_input=*/true,
134-
/*complex_output=*/false, /*inverse=*/true,
135-
/*signal_sizes=*/{n}, /*normalization=*/norm,
136-
/*onesided=*/true);
137-
if (must_transpose) {
138-
out = at::transpose(out, -1, dim);
139-
}
140-
return out;
127+
return at::_fft_c2r(input, dim, static_cast<int64_t>(norm), n);
141128
}
142129

143130
// Real to complex FFT
@@ -153,22 +140,11 @@ Tensor fft_r2c(Tensor input, c10::optional<int64_t> n_opt,
153140
if (n_opt) {
154141
input = resize_fft_input(input, dim, n);
155142
}
156-
// _fft only operates on the last dim, so transpose the selected dim to the end
157-
const bool must_transpose = (dim != input_dim - 1);
158-
if (must_transpose) {
159-
input = at::transpose(input, -1, dim);
160-
}
143+
161144
const auto norm = norm_from_string(norm_str, forward);
162-
auto out = _fft(input, /*signal_ndim=*/1, /*complex_input=*/false,
163-
/*complex_output=*/true, /*inverse=*/false,
164-
/*signal_sizes=*/{n}, /*normalization=*/norm,
165-
/*onesided=*/onesided);
166-
out = at::view_as_complex(out);
167-
if (must_transpose) {
168-
out = at::transpose(out, -1, dim);
169-
}
145+
auto out = at::_fft_r2c(input, dim, static_cast<int64_t>(norm), onesided);
170146
if (!forward) {
171-
// FIXME: _fft does not support complex_input=false with inverse=true
147+
// FIXME: _fft_r2c doesn't support native r2c IFFT
172148
out = at::conj(out);
173149
}
174150
return out;
@@ -186,22 +162,8 @@ Tensor fft_c2c(Tensor input, c10::optional<int64_t> n_opt,
186162
if (n_opt) {
187163
input = resize_fft_input(input, dim, n);
188164
}
189-
// _fft only operates on the last dim, so transpose the selected dim to the end
190-
const bool must_transpose = (dim != input_dim - 1);
191-
if (must_transpose) {
192-
input = at::transpose(input, -1, dim);
193-
}
194165
const auto norm = norm_from_string(norm_str, forward);
195-
auto out = _fft(at::view_as_real(input),
196-
/*signal_ndim=*/1, /*complex_input=*/true,
197-
/*complex_output=*/true, /*inverse=*/!forward,
198-
/*signal_sizes=*/{}, /*normalization=*/norm,
199-
/*onesided=*/false);
200-
out = at::view_as_complex(out);
201-
if (must_transpose) {
202-
out = at::transpose(out, -1, dim);
203-
}
204-
return out;
166+
return at::_fft_c2c(input, dim, static_cast<int64_t>(norm), forward);
205167
}
206168

207169
// Dimensions to transform, and the signal shape in those dimensions
@@ -277,44 +239,12 @@ Tensor fftn_c2c(
277239
const Tensor& input, IntArrayRef shape, IntArrayRef dim,
278240
c10::optional<std::string> norm_str, bool forward) {
279241
TORCH_CHECK(input.is_complex(), "Expected a complex input tensor to FFT");
280-
const auto input_dim = input.dim();
281-
282242
Tensor x = resize_fft_input(input, dim, shape);
283-
x = at::view_as_real(x);
284-
285-
const int64_t transform_ndim = dim.size();
286243
const auto norm = norm_from_string(norm_str, forward);
287-
// _fft_with_size only supports 3 dimensions being transformed at a time.
288-
// This limit is inherited from cuFFT.
289-
constexpr int64_t max_signal_ndim = 3;
290-
291-
// Transform n dimensions, up to 3 at a time
292-
// TODO: rewrite _fft_with_size to transform more than 3 dimensions at once.
293-
for (int64_t i = 0; i < transform_ndim; i += max_signal_ndim) {
294-
const int64_t signal_ndim = std::min(transform_ndim - i, max_signal_ndim);
295-
DimVector source_dim(signal_ndim);
296-
DimVector dest_dim(signal_ndim);
297-
298-
for (int64_t j = 0; j < signal_ndim; ++j) {
299-
source_dim[j] = dim[i + j];
300-
dest_dim[j] = j + (input_dim - signal_ndim);
301-
}
302-
303-
// _fft operates on up-to the last 3 dims, so move selected dims to the end
304-
x = at::movedim(x, source_dim, dest_dim);
305-
306-
x = _fft(x, signal_ndim, /*complex_input=*/true, /*complex_output=*/true,
307-
/*inverse=*/!forward, /*signal_sizes=*/{}, /*normalization=*/norm,
308-
/*onesided=*/false);
309-
310-
// Move transform dims back to their original order
311-
x = at::movedim(x, dest_dim, source_dim);
312-
}
313-
314-
return at::view_as_complex(x);
244+
return at::_fft_c2c(x, dim, static_cast<int64_t>(norm), forward);
315245
}
316246

317-
}
247+
} // namespace (anonymous)
318248

319249
// torch.fft.fft, analogous to NumPy's numpy.fft.fft
320250
Tensor fft_fft(const Tensor& self, c10::optional<int64_t> n, int64_t dim,
@@ -370,44 +300,36 @@ Tensor fft_ifftn(const Tensor& self, c10::optional<IntArrayRef> s,
370300

371301
Tensor fft_rfftn(const Tensor& self, c10::optional<IntArrayRef> s,
372302
c10::optional<IntArrayRef> dim,
373-
c10::optional<std::string> norm) {
303+
c10::optional<std::string> norm_str) {
304+
TORCH_CHECK(!self.is_complex(), "rfftn expects a real-valued input tensor, but got ", self.scalar_type());
374305
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
375306
TORCH_CHECK(desc.shape.size() > 0, "rfftn must transform at least one axis");
376-
377-
const auto last_dim = desc.dim.back();
378-
const auto last_shape = desc.shape.back();
379-
desc.shape.pop_back();
380-
desc.dim.pop_back();
381-
382-
// rfft on last dim to get hermitian complex shape
383-
auto x = native::fft_rfft(self, last_shape, last_dim, norm);
384-
// Normal fft on remaining dims
385-
return fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/true);
307+
Tensor input = promote_tensor_fft(self, /*require_complex=*/false);
308+
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
309+
const auto norm = norm_from_string(norm_str, /*forward=*/true);
310+
return at::_fft_r2c(x, desc.dim, static_cast<int64_t>(norm), /*onesided=*/true);
386311
}
387312

388313
Tensor fft_irfftn(const Tensor& self, c10::optional<IntArrayRef> s,
389314
c10::optional<IntArrayRef> dim,
390-
c10::optional<std::string> norm) {
315+
c10::optional<std::string> norm_str) {
391316
auto desc = canonicalize_fft_shape_and_dim_args(self, s, dim);
392317
TORCH_CHECK(desc.shape.size() > 0, "irfftn must transform at least one axis");
393318

394-
const auto last_dim = desc.dim.back();
395-
const auto last_shape = [&]() -> c10::optional<int64_t> {
396-
// If shape is defaulted in the last dimension,
397-
// pass nullopt to irfft and let it calculate the default size
319+
const auto last_dim_size = [&] {
320+
// Fixup default shape handling in the last dimension,
398321
if (!s.has_value() || (s->back() == -1)) {
399-
return c10::nullopt;
322+
const auto last_dim = desc.dim.back();
323+
return 2 * (self.sizes()[last_dim] - 1);
400324
}
401325
return desc.shape.back();
402326
}();
403-
desc.shape.pop_back();
404-
desc.dim.pop_back();
405-
406-
// Normal ifft for all but last dim
407-
Tensor x = promote_tensor_fft(self, /*require_complex=*/true);
408-
x = fftn_c2c(x, desc.shape, desc.dim, norm, /*forward=*/false);
409-
// Then 1d irfft on last dim to get real output
410-
return native::fft_irfft(x, last_shape, last_dim, norm);
327+
desc.shape.back() = last_dim_size / 2 + 1;
328+
329+
Tensor input = promote_tensor_fft(self, /*require_complex=*/true);
330+
Tensor x = resize_fft_input(input, desc.dim, desc.shape);
331+
const auto norm = norm_from_string(norm_str, /*forward=*/false);
332+
return at::_fft_c2r(x, desc.dim, static_cast<int64_t>(norm), last_dim_size);
411333
}
412334

413335
Tensor fft_fft2(const Tensor& self, c10::optional<IntArrayRef> s,

0 commit comments

Comments
 (0)