Skip to content

Commit fe87087

Browse files
committed
Workaround for fbgemm::FindMinMax
Signed-off-by: Yuanyuan Chen <[email protected]>
1 parent 88c52d0 commit fe87087

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

aten/src/ATen/native/QuantizedLinear.cpp

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
#include <ATen/Functions.h>
1212
#include <ATen/NativeFunctions.h>
1313
#else
14+
#include <ATen/ops/aminmax.h>
1415
#include <ATen/ops/empty.h>
1516
#include <ATen/ops/empty_like_native.h>
1617
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
@@ -81,11 +82,19 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
8182
// Calculate statistics for quantization of the input Tensor
8283
float x_min = std::numeric_limits<float>::quiet_NaN();
8384
float x_max = std::numeric_limits<float>::quiet_NaN();
85+
#if defined(__AVX__)
8486
fbgemm::FindMinMax(
8587
/*m=*/input_ptr,
8688
/*min=*/&x_min,
8789
/*max=*/&x_max,
8890
/*len=*/input.numel());
91+
#else
92+
if (input.numel() > 0) {
93+
auto [t_min, t_max] = at::aminmax(input);
94+
x_max = t_max.item<float>();
95+
x_min = t_min.item<float>();
96+
}
97+
#endif
8998

9099
// Input tensor is quantized as 8-bit unsigned values
91100
constexpr int kPrecision = 8;
@@ -237,11 +246,19 @@ std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
237246
// Calculate weight statistics
238247
float w_min = std::numeric_limits<float>::quiet_NaN();
239248
float w_max = std::numeric_limits<float>::quiet_NaN();
249+
#if defined(__AVX__)
240250
fbgemm::FindMinMax(
241251
/*m=*/weight_contig.data_ptr<float>(),
242252
/*min=*/&w_min,
243253
/*max=*/&w_max,
244254
/*len=*/weight_contig.numel());
255+
#else
256+
if (weight_contig.numel() > 0) {
257+
auto [t_min, t_max] = at::aminmax(weight_contig);
258+
w_max = t_max.item<float>();
259+
w_min = t_min.item<float>();
260+
}
261+
#endif
245262

246263
// Choose parameters for quantizing the weight as 8-bit signed integer
247264
constexpr bool kIsSigned = true;

aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#ifndef AT_PER_OPERATOR_HEADERS
1717
#include <ATen/Functions.h>
1818
#else
19+
#include <ATen/ops/aminmax.h>
1920
#include <ATen/ops/dequantize.h> // for dequantize
2021
#include <ATen/ops/quantize_per_tensor.h>
2122
#endif
@@ -29,12 +30,20 @@ at::Tensor PackedConvWeight<kSpatialDim>::apply_dynamic(
2930
TORCH_CHECK(
3031
fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
3132

32-
float x_min, x_max;
33+
float x_min = std::numeric_limits<float>::quiet_NaN(), x_max = std::numeric_limits<float>::quiet_NaN();
34+
#if defined(__AVX__)
3335
fbgemm::FindMinMax(
3436
/*m=*/input.data_ptr<float>(),
3537
/*min=*/&x_min,
3638
/*max=*/&x_max,
3739
/*len=*/input.numel());
40+
#else
41+
if (input.numel() > 0) {
42+
auto [t_min, t_max] = at::aminmax(input);
43+
x_max = t_max.item<float>();
44+
x_min = t_min.item<float>();
45+
}
46+
#endif
3847

3948
// Input tensor is quantized as 8-bit unsigned values
4049
static constexpr int precision = 8;

aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,11 +69,19 @@ at::Tensor PackedLinearWeight::apply_dynamic_impl(
6969

7070
// Calculate statistics for quantization of the input Tensor
7171
float x_min = std::numeric_limits<float>::quiet_NaN(), x_max = std::numeric_limits<float>::quiet_NaN();
72+
#if defined(__AVX__)
7273
fbgemm::FindMinMax(
7374
/*m=*/input_ptr,
7475
/*min=*/&x_min,
7576
/*max=*/&x_max,
7677
/*len=*/input.numel());
78+
#else
79+
if (input_contig.numel() > 0) {
80+
auto [t_min, t_max] = at::aminmax(input_contig);
81+
x_max = t_max.item<float>();
82+
x_min = t_min.item<float>();
83+
}
84+
#endif
7785

7886
// Input tensor is quantized as 8-bit unsigned values
7987
static constexpr int precision = 8;
@@ -512,7 +520,7 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
512520
x.init(input_desc, input_contig.data_ptr());
513521
// Find quantization parameters
514522
float x_max = 0, x_min = 0;
515-
#ifdef USE_FBGEMM
523+
#if defined(USE_FBGEMM) && defined(__AVX__)
516524
// Use FBGEMM's FindMinMax if available since it's faster
517525
fbgemm::FindMinMax(
518526
/*m=*/input_contig.data_ptr<float>(),
@@ -738,7 +746,7 @@ at::Tensor PackedLinearWeightsACL::apply_dynamic_impl(
738746
// Find quantization parameters
739747
float x_max = 0, x_min = 0;
740748

741-
#ifdef USE_FBGEMM
749+
#if defined(USE_FBGEMM) && defined(__AVX__)
742750
// Use FBGEMM's FindMinMax if available since it's faster
743751
fbgemm::FindMinMax(
744752
/*m=*/input_contig.data_ptr<float>(),

0 commit comments

Comments (0)