Skip to content

Commit a75af59

Browse files
committed
not for land, just testing on "quantized tensor: add support for advanced indexing"
Summary: Implements support for the indexing of quantized tensors with lists of dims, such as ``` xq_slice = xq[:, [0], :, :] ``` If helpful for reviewers, the things originally broken were, in order: 1. `computeDeviceType` did not handle `DispatchKey::QuantizedCPU` (fix: added) 2. quantization params were not present in `TensorIterator::set_output`, so they could not be used to properly create the quantized tensor (fix: created `TensorQuantizationOptions` and threaded it through the relevant places) 3. `index` kernel was not enabled for quantized dtypes (fix: enable it) Note: this PR only handles per-Tensor qparams. We don't expect to need this for per-channel qparams any time soon; ideally we can pay the eng cost for enabling that when it is needed. Test Plan: ``` python test/test_quantization.py TestQuantizedOps.test_advanced_indexing ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D25451651](https://our.internmc.facebook.com/intern/diff/D25451651) [ghstack-poisoned]
2 parents bc20313 + 8397a62 commit a75af59

File tree

34 files changed

+881
-685
lines changed

34 files changed

+881
-685
lines changed

aten/src/ATen/TensorIterator.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,12 @@ void TensorIteratorBase::allocate_or_resize_outputs() {
485485
// At the moment, quantized kernels mostly handle output Tensor
486486
// construction manually, this path is an edge case. So, only support
487487
// the single input case for now.
488+
/*
488489
TORCH_INTERNAL_ASSERT(
489490
operands_.size() == num_outputs_ + 1,
490491
"Advanced indexing of quantized Tensors with multiple inputs is not "
491492
"supported yet.");
493+
*/
492494
// get the first input and copy its quantization parameters
493495
const auto& first_input = operands_[num_outputs_];
494496
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(first_input.tensor.is_quantized());

aten/src/ATen/cpu/vec256/vec256_bfloat16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) {
2525
static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) {
2626
__m256i lo = _mm256_castps_si256(a);
2727
__m256i hi = _mm256_castps_si256(b);
28-
__m256i nan = _mm256_set1_epi32(0x7fc0);
28+
__m256i nan = _mm256_set1_epi32(0xffff);
2929
__m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q));
3030
__m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q));
3131
__m256i ones = _mm256_set1_epi32(0x1);

aten/src/ATen/native/SpectralOps.cpp

Lines changed: 0 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@
1919

2020
namespace at { namespace native {
2121

22-
// Common code for all FFT functions
23-
static inline Tensor _fft(
24-
const Tensor &self, int64_t signal_ndim, bool complex_input,
25-
const bool complex_output, bool inverse, IntArrayRef signal_sizes,
26-
fft_norm_mode normalization, bool onesided);
27-
2822
namespace {
2923

3024
// Promote inputs to FFT functions
@@ -416,139 +410,6 @@ Tensor fft_ifftshift(const Tensor& x, c10::optional<IntArrayRef> dim_opt) {
416410
}
417411

418412

419-
// This is a pass-through wrapper function that does the size check and
420-
// inferences. The actual forward implementation function is called
421-
// at::_fft_with_size which dispatches to _fft_cufft (CUDA) or _fft_mkl (CPU).
422-
static inline Tensor _fft(const Tensor &self, const int64_t signal_ndim,
423-
const bool complex_input, const bool complex_output,
424-
const bool inverse, IntArrayRef signal_sizes,
425-
const fft_norm_mode normalization, const bool onesided) {
426-
427-
TORCH_CHECK(signal_ndim >= 1 && signal_ndim <= 3,
428-
"Expected signal_ndim to be 1, 2, or 3, but got signal_ndim=",
429-
signal_ndim);
430-
TORCH_CHECK(at::isFloatingType(self.scalar_type()),
431-
"Expected an input tensor of floating types, but got input=",
432-
self.toString(), self.sizes());
433-
434-
auto signal_tensor_ndim = signal_ndim + static_cast<int64_t>(complex_input); // add complex dim
435-
if (self.dim() < signal_tensor_ndim) {
436-
std::ostringstream ss;
437-
ss << "Given signal_ndim=" << signal_ndim << ", expected an input tensor "
438-
<< "of at least " << signal_tensor_ndim << "D";
439-
if (complex_input) {
440-
ss << " (complex input adds an extra dimension)";
441-
}
442-
ss << ", but got input=" << self.toString() << self.sizes();
443-
AT_ERROR(ss.str());
444-
}
445-
446-
auto self_shape = self.sizes();
447-
auto batch_ndim = self.dim() - signal_tensor_ndim;
448-
449-
Tensor input = self;
450-
// flatten the batch dims
451-
if (batch_ndim == 0) {
452-
// slightly faster path for non-batch mode
453-
input = input.unsqueeze(0);
454-
} else if (batch_ndim > 1) {
455-
std::vector<int64_t> flatten_input_shape(signal_tensor_ndim + 1);
456-
std::copy(self_shape.begin() + batch_ndim, self_shape.end(), flatten_input_shape.begin() + 1);
457-
flatten_input_shape[0] = -1;
458-
input = input.reshape(flatten_input_shape);
459-
460-
}
461-
462-
// now we assume that input is batched as [ B x signal_dims... ]
463-
464-
if (complex_input) {
465-
TORCH_CHECK(input.size(signal_ndim + 1) == 2,
466-
"Expected an input tensor with a last dimension of size 2 "
467-
"representing real + imaginary components, but got input ",
468-
self.toString(), self.sizes());
469-
}
470-
471-
// build signal_sizes and output_size
472-
TORCH_CHECK(signal_sizes.size() == 0 || static_cast<int64_t>(signal_sizes.size()) == signal_ndim,
473-
"Expected signal_sizes to be empty (default) or of signal_ndim=",
474-
signal_ndim, "D, but got signal_sizes=", signal_sizes);
475-
std::vector<int64_t> output_sizes(signal_ndim + 1 + static_cast<int64_t>(complex_output));
476-
output_sizes[0] = input.size(0); // batch size
477-
std::vector<int64_t> checked_signal_sizes(signal_ndim);
478-
for (int64_t i = 0; i < signal_ndim; i++) {
479-
int64_t input_size = input.size(i + 1);
480-
if (i == signal_ndim - 1 && onesided && complex_input && !complex_output) {
481-
// If last dim and complex-to-real onesided, input is only half of
482-
// signal, and we need to infer basing on signal_sizes, if given
483-
// See native/SpectralOpsUtils.h for detailed description.
484-
int64_t inferred_size;
485-
if (signal_sizes.size() > 0) {
486-
inferred_size = infer_ft_complex_to_real_onesided_size(input_size, signal_sizes[i]);
487-
} else {
488-
inferred_size = infer_ft_complex_to_real_onesided_size(input_size);
489-
}
490-
checked_signal_sizes[i] = inferred_size;
491-
output_sizes[i + 1] = inferred_size;
492-
} else {
493-
if (i == signal_ndim - 1 && onesided && !complex_input && complex_output) {
494-
// if last dim and real-to-complex onesided, output should be only
495-
// half of the signal, and we need to infer using input_size
496-
output_sizes[i + 1] = infer_ft_real_to_complex_onesided_size(input_size);
497-
} else {
498-
output_sizes[i + 1] = input_size;
499-
}
500-
checked_signal_sizes[i] = input_size;
501-
TORCH_CHECK(signal_sizes.size() == 0 || signal_sizes[i] == checked_signal_sizes[i],
502-
"Expected given signal_sizes=", signal_sizes," to have same "
503-
"shape with input at signal dimension ", i, ", but got "
504-
"signal_sizes=", signal_sizes, " and input=", self.toString(),
505-
self.sizes());
506-
}
507-
}
508-
if (complex_output) {
509-
output_sizes[signal_ndim + 1] = 2;
510-
}
511-
512-
Tensor output = at::_fft_with_size(input, signal_ndim, complex_input,
513-
complex_output, inverse,
514-
checked_signal_sizes,
515-
static_cast<int64_t>(normalization),
516-
onesided,
517-
output_sizes);
518-
519-
// unflatten the batch dims
520-
if (batch_ndim == 0) {
521-
// slightly faster path for non-batch mode
522-
output = output.squeeze(0);
523-
} else if (batch_ndim > 1) {
524-
auto output_ndim = self.dim() + static_cast<int64_t>(complex_output) - static_cast<int64_t>(complex_input);
525-
std::vector<int64_t> unflatten_output_shape(output_ndim);
526-
std::copy(self_shape.begin(), self_shape.begin() + batch_ndim, unflatten_output_shape.begin());
527-
std::copy(output_sizes.begin() + 1, output_sizes.end(), unflatten_output_shape.begin() + batch_ndim);
528-
output = output.reshape(unflatten_output_shape);
529-
}
530-
return output;
531-
}
532-
533-
// Wrapper to preserve the historic signature of _fft_with_size
534-
// NOTE: This is only used for torchscript backwards compatibility and the new
535-
// signature with normalization modes should be used in all other cases
536-
Tensor _fft_with_size(const Tensor& input, int64_t signal_ndim,
537-
bool complex_input, bool complex_output,
538-
bool inverse, IntArrayRef checked_signal_sizes,
539-
bool normalized, bool onesided,
540-
IntArrayRef output_sizes) {
541-
fft_norm_mode norm;
542-
if (normalized) {
543-
norm = fft_norm_mode::by_root_n;
544-
} else {
545-
norm = inverse ? fft_norm_mode::by_n : fft_norm_mode::none;
546-
}
547-
return at::_fft_with_size(
548-
input, signal_ndim, complex_input, complex_output, inverse,
549-
checked_signal_sizes, static_cast<int64_t>(norm), onesided, output_sizes);
550-
}
551-
552413
// We call the following methods via CUDA hooks because they are really only
553414
// valid when CUDA is available. See native/cuda/CuFFTPlanCache.h for more details.
554415
int64_t _cufft_get_plan_cache_max_size(int64_t device_index) {
@@ -567,36 +428,6 @@ void _cufft_clear_plan_cache(int64_t device_index) {
567428
detail::getCUDAHooks().cuFFTClearPlanCache(device_index);
568429
}
569430

570-
static Tensor fft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
571-
return _fft(self, signal_ndim, /* complex_input */ true,
572-
/* complex_output */ true, /* inverse */ false, {},
573-
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
574-
/* onesided */ false);
575-
}
576-
577-
static Tensor ifft(const Tensor& self, const int64_t signal_ndim, const bool normalized) {
578-
return _fft(self, signal_ndim, /* complex_input */ true,
579-
/* complex_output */ true, /* inverse */ true, {},
580-
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
581-
/* onesided */ false);
582-
}
583-
584-
static Tensor rfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
585-
const bool onesided) {
586-
return _fft(self, signal_ndim, /* complex_input */ false,
587-
/* complex_output */ true, /* inverse */ false, {},
588-
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::none,
589-
onesided);
590-
}
591-
592-
static Tensor irfft(const Tensor& self, const int64_t signal_ndim, const bool normalized,
593-
const bool onesided, IntArrayRef signal_sizes) {
594-
return _fft(self, signal_ndim, /* complex_input */ true,
595-
/* complex_output */ false, /* inverse */ true, signal_sizes,
596-
normalized ? fft_norm_mode::by_root_n : fft_norm_mode::by_n,
597-
onesided);
598-
}
599-
600431
template <typename Stream, typename T>
601432
static Stream& write_opt(Stream& SS, const optional<T>& value) {
602433
if (value) {

aten/src/ATen/native/TensorProperties.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <ATen/ATen.h>
22
#include <ATen/NativeFunctions.h>
3-
#include <ATen/WrapDimUtils.h>
43
#include <ATen/detail/CUDAHooksInterface.h>
54
#include <ATen/NamedTensorUtils.h>
65
#include <torch/library.h>
@@ -14,15 +13,11 @@ bool is_same_size(const Tensor& self, const Tensor& other) {
1413
}
1514

1615
int64_t size(const Tensor& self, int64_t dim) {
17-
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
18-
dim = maybe_wrap_dim(dim, self.dim(), false);
19-
return self.sizes()[dim];
16+
return self.size(dim);
2017
}
2118

2219
int64_t stride(const Tensor& self, int64_t dim) {
23-
// false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
24-
dim = maybe_wrap_dim(dim, self.dim(), false);
25-
return self.strides()[dim];
20+
return self.stride(dim);
2621
}
2722

2823
int64_t size(const Tensor& self, Dimname dim) {

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ static void sign_kernel(TensorIterator& iter){
277277
[=](scalar_t a) -> scalar_t { return (0 < a) - (a < 0); },
278278
[=](Vec256<scalar_t> self_vec){
279279

280-
// Comparision operators returns bitmask.
280+
// Comparison operators returns bitmask.
281281
auto left = Vec256<scalar_t>::blendv(zero_vec, one_vec, zero_vec < self_vec);
282282
auto right = Vec256<scalar_t>::blendv(zero_vec, one_vec, self_vec < zero_vec);
283283

aten/src/ATen/native/cuda/SpectralOps.cu

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -589,112 +589,5 @@ Tensor _fft_c2c_cufft(const Tensor& self, IntArrayRef dim, int64_t normalization
589589
return output;
590590
}
591591
592-
// cuFFT
593-
// Currently not utilizing multi GPUs so this can be potentially sped up.
594-
Tensor _fft_cufft(const Tensor& self, int64_t signal_ndim,
595-
bool complex_input, bool complex_output, bool inverse,
596-
IntArrayRef checked_signal_sizes, int64_t normalization, bool onesided,
597-
IntArrayRef output_sizes) {
598-
599-
CuFFTParamsLRUCache& plan_cache = cufft_get_plan_cache(self.device().index());
600-
601-
Tensor input = self;
602-
const auto fft_type = GetCuFFTTransformType(complex_input, complex_output);
603-
604-
if (complex_input) {
605-
TORCH_CHECK(input.size(-1) == 2, "Expected a complex (size 2) last dimension");
606-
}
607-
608-
609-
// Slice when twosided complex-to-real. This is not always needed because we
610-
// calculate the inembed. But it will benefit us in certain cases where we
611-
// clone the input tensor.
612-
//
613-
// See NOTE [ cuFFT Embedded Strides ].
614-
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
615-
if (fft_type == CuFFTTransformType::C2R && !onesided) {
616-
auto onesided_size = infer_ft_real_to_complex_onesided_size(checked_signal_sizes[signal_ndim - 1]);
617-
input = input.narrow(signal_ndim, 0, onesided_size);
618-
}
619-
620-
// cuFFT requires input and output data pointers to complex type aligned.
621-
// Our newly allocated output tensor is always 512 bytes aligned so it is fine
622-
// (see kRoundSmall and kRoundLarge in THCCachingAllocator.cpp), but we do
623-
// need to check input tensor to make sure that it is not unaligned, e.g.,
624-
// from a slicing.
625-
bool must_clone = false;
626-
auto complex_size_bytes = 2 * input.element_size();
627-
if (reinterpret_cast<std::uintptr_t>(input.data_ptr()) % complex_size_bytes != 0) {
628-
must_clone = true;
629-
}
630-
631-
if (complex_input) {
632-
auto strides = input.strides();
633-
// Real/imag dimension must be like complex type.
634-
must_clone |= strides.back() != 1;
635-
// Strides of other dimensions needs to be aligned when viewed as complex
636-
// type, i.e., multiples of 2.
637-
must_clone |= std::any_of(strides.begin(), strides.end() - 1,
638-
[&](int64_t stride) { return stride % 2 != 0; });
639-
640-
// Complex to real FFTs may overwrite the input buffer (gh-34551)
641-
must_clone |= !complex_output;
642-
}
643-
644-
if (must_clone) {
645-
input = input.clone(MemoryFormat::Contiguous);
646-
}
647-
648-
// Now that we have done error check and data_ptr checks, we delegate all
649-
// further cuFFT parameter computation and plan creation to the helper class
650-
// CuFFTConfig in CuFFTPlanCache.h.
651-
652-
// If plan caching is enabled, we check the cache. Note that this accesses
653-
// plan_cache.max_size() and thus makes this function less functional.
654-
// However, integrating additional arguments into the "public" level c++ APIs,
655-
// e.g., irfft, is difficult as we have a long call sequence looking like
656-
// irfft --> _fft --> _fft_with_size --dispatching-to-> _fft_cufft
657-
658-
DimVector in_strides(signal_ndim + 1);
659-
auto input_strides = input.strides();
660-
for (int64_t i = signal_ndim; i >= 0; --i) {
661-
in_strides[i] = complex_input ? input_strides[i] / 2 : input_strides[i];
662-
}
663-
664-
DimVector out_strides(signal_ndim + 1);
665-
out_strides[signal_ndim] = 1;
666-
if (fft_type == CuFFTTransformType::R2C && onesided) {
667-
out_strides[signal_ndim - 1] = checked_signal_sizes[signal_ndim - 1] / 2 + 1;
668-
} else {
669-
out_strides[signal_ndim - 1] = checked_signal_sizes[signal_ndim - 1];
670-
}
671-
for (int64_t i = signal_ndim - 2; i >= 0; --i) {
672-
out_strides[i] = out_strides[i + 1] * checked_signal_sizes[i];
673-
}
674-
675-
DimVector full_sizes(signal_ndim + 1);
676-
full_sizes[0] = self.size(0);
677-
std::copy(checked_signal_sizes.begin(), checked_signal_sizes.end(), full_sizes.begin() + 1);
678-
CuFFTParams Params(in_strides, out_strides, full_sizes, fft_type,
679-
c10::toValueType(input.scalar_type()));
680-
681-
// This read is not locked for perf reason. Shouldn't matter too much because
682-
// we check again after acquiring the lock.
683-
if (plan_cache.max_size() > 0) {
684-
std::lock_guard<std::mutex> guard(plan_cache.mutex);
685-
if (plan_cache.max_size() > 0) { // check again after acquiring the lock
686-
const CuFFTConfig &config = plan_cache.lookup(Params);
687-
return _run_cufft(config, input, signal_ndim, complex_input,
688-
complex_output, inverse, checked_signal_sizes,
689-
static_cast<fft_norm_mode>(normalization),
690-
onesided, output_sizes, must_clone);
691-
}
692-
}
693-
CuFFTConfig config(Params);
694-
return _run_cufft(config, input, signal_ndim, complex_input,
695-
complex_output, inverse, checked_signal_sizes,
696-
static_cast<fft_norm_mode>(normalization),
697-
onesided, output_sizes, must_clone);
698-
}
699592
700593
}} // at::native

0 commit comments

Comments (0)