
Commit a541569

Update on "[PyTorch] Add c10::hash<c10::ArrayRef>"
Just moved the vector implementation to ArrayRef and re-implemented the former using the latter.

Differential Revision: [D30647666](https://our.internmc.facebook.com/intern/diff/D30647666/)

[ghstack-poisoned]
2 parents e02cfb1 + a54ded9 commit a541569
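As a rough illustration of the pattern described in the commit message (hash the ArrayRef view by combining element hashes, then route the former std::vector hash through it), here is a standalone sketch. The type, the hash_combine formula, and the function names are illustrative stand-ins, not the actual c10::ArrayRef / c10::hash implementation:

// Standalone sketch only: all names here are hypothetical stand-ins for
// c10::ArrayRef / c10::hash; the combine step mirrors the common
// boost-style hash_combine and may differ from what c10 actually uses.
#include <cstddef>
#include <functional>
#include <vector>

template <typename T>
struct ArrayRefLike {        // stand-in for c10::ArrayRef<T>: pointer + length view
  const T* data;
  std::size_t size;
};

inline std::size_t hash_combine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

// Hash every element of the view and fold the results together.
template <typename T>
std::size_t hash_array_ref(ArrayRefLike<T> a) {
  std::size_t seed = 0;
  for (std::size_t i = 0; i < a.size; ++i) {
    seed = hash_combine(seed, std::hash<T>{}(a.data[i]));
  }
  return seed;
}

// The former vector hash, now just a view over the vector's storage.
template <typename T>
std::size_t hash_vector(const std::vector<T>& v) {
  return hash_array_ref(ArrayRefLike<T>{v.data(), v.size()});
}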

13 files changed: 279 additions & 194 deletions

aten/src/ATen/Dispatch.h

Lines changed: 61 additions & 9 deletions
@@ -80,9 +80,13 @@ inline constexpr bool should_include_kernel_dtype(
 #define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
 #endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
 
+#if defined __cpp_if_constexpr
 #define AT_QINT_PRIVATE_CASE_TYPE( \
-    enum_type, type, underlying_enum, underlying_type, ...) \
+    NAME, enum_type, type, underlying_enum, underlying_type, ...) \
   case enum_type: { \
+    if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
+      AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+    } \
     using scalar_t = type; \
     using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
         scalar_t::underlying; \
@@ -93,10 +97,57 @@ inline constexpr bool should_include_kernel_dtype(
     /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
     return __VA_ARGS__(); \
   }
+#else
+#define AT_QINT_PRIVATE_CASE_TYPE( \
+    NAME, enum_type, type, underlying_enum, underlying_type, ...) \
+  case enum_type: { \
+    at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+        [] { \
+          AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
+        } \
+    ); \
+    using scalar_t = type; \
+    using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+        scalar_t::underlying; \
+    const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
+    const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+        toUnderlying(enum_type); \
+    (void)SCALAR_TYPE; /* Suppress unused-var compiler warning */ \
+    /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
+    return __VA_ARGS__(); \
+  }
+#endif
 
+#if defined __cpp_if_constexpr
+#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
+    NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
+  case enum_type: { \
+    if constexpr (!at::should_include_kernel_dtype(NAME, enum_type)) { \
+      AT_ERROR("dtype '", toString(enum_type), "' not selected for kernel tag ", #NAME); \
+    } \
+    using scalar_t = type; \
+    using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+        scalar_t::underlying; \
+    const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
+    const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
+        toUnderlying(enum_type); \
+    int bit_width = bitwidth; \
+    int64_t quant_min = qmin; \
+    int64_t quant_max = qmax; \
+    (void)bit_width; /* Suppress unused variable warning */ \
+    (void)quant_min; /* Suppress unused variable warning */ \
+    (void)quant_max; /* Suppress unused variable warning */ \
+    return __VA_ARGS__(); \
+  }
+#else
 #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
-    enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
+    NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
   case enum_type: { \
+    at::guts::if_constexpr<(!at::should_include_kernel_dtype(NAME, enum_type))>( \
+        [] { \
+          AT_ERROR("dtype '" #enum_type "' not selected for kernel tag " #NAME); \
+        } \
+    ); \
     using scalar_t = type; \
     using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
         scalar_t::underlying; \
@@ -111,6 +162,7 @@ inline constexpr bool should_include_kernel_dtype(
     (void)quant_max; /* Suppress unused variable warning */ \
     return __VA_ARGS__(); \
   }
+#endif
 
 namespace detail {
 
@@ -449,11 +501,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
     switch (_st) { \
       AT_QINT_PRIVATE_CASE_TYPE( \
-          at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
+          NAME, at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
       AT_QINT_PRIVATE_CASE_TYPE( \
-          at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
+          NAME, at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
       AT_QINT_PRIVATE_CASE_TYPE( \
-          at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
+          NAME, at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
       default: \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
@@ -467,13 +519,13 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     RECORD_KERNEL_FUNCTION_DTYPE(NAME, _st); \
     switch (_st) { \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
-          at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
+          NAME, at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
-          at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
+          NAME, at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
       AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
-          at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
+          NAME, at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
      AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
-          at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
+          NAME, at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
       default: \
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     } \
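This diff threads the kernel NAME into the quantized-dtype case macros so that at::should_include_kernel_dtype can reject dtypes that were not selected for that kernel; the #else branches do the same thing through at::guts::if_constexpr for toolchains without C++17 if constexpr. A minimal standalone sketch of the C++17 shape of that check (the stub allowlist and all names below are assumptions for illustration, not the real ATen code):

// Sketch of the compile-time selective-build check; not the actual ATen code.
#include <stdexcept>

// Stub: in a selective build this would consult a generated allowlist
// of (kernel tag, dtype) pairs produced by the build system.
constexpr bool should_include_kernel_dtype_stub(const char* /*kernel_tag*/, int /*dtype*/) {
  return true;
}

template <int dtype>
void dispatch_case_stub() {
  // With C++17, the error branch is discarded at compile time whenever the
  // dtype is selected, so selected kernels pay nothing for the check.
  if constexpr (!should_include_kernel_dtype_stub("my_kernel", dtype)) {
    throw std::runtime_error("dtype not selected for kernel tag my_kernel");
  }
  // ... kernel body specialized for `dtype` ...
}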

aten/src/ATen/core/TensorBase.h

Lines changed: 0 additions & 6 deletions
@@ -755,12 +755,6 @@ class TORCH_API TensorBase {
   TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
 };
 
-// For "multiple ... operators specified" warnings, closing brace of class
-// declaration must be included between pragma push & pop
-#ifdef _MSC_VER
-#pragma warning( pop )
-#endif
-
 inline int64_t get_device(const TensorBase& self) {
   return self.get_device();
 }

aten/src/ATen/native/DistributionTemplates.h

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ namespace templates {
 //
 // If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it becomes 65504
 // and violates the requirement that random value must be less than `to`. To resolve this issue `update_from` and `update_to`
-// moves `from` to the left and `to` to the right to the next closest value that won't go outside [from, to) after casting to
+// moves `from` to the right and `to` to the left to the next closest value that won't go outside [from, to) after casting to
 // the target dtype. For `to` = 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
 // available number for torch::half dtype.
 template<typename scalar_t>
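The arithmetic of the example in the corrected comment can be checked directly; a standalone sketch of just that calculation (`update_to` itself is not reproduced here):

// Verify the worked example from the comment: for to = 65504 the step is
// 1 << (floor(log2(to)) - 11 + 1) = 1 << 5 = 32, so `to` moves left to
// 65472, the previous value representable in half precision.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t to = 65504;
  const int64_t step =
      int64_t{1} << (static_cast<int>(std::log2(static_cast<double>(to))) - 11 + 1);
  std::printf("step = %lld, updated to = %lld\n",
              static_cast<long long>(step),
              static_cast<long long>(to - step));  // prints 32 and 65472
  return 0;
}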

aten/src/ATen/native/Loss.cpp

Lines changed: 24 additions & 5 deletions
@@ -30,13 +30,32 @@ DEFINE_DISPATCH(mse_stub);
 DEFINE_DISPATCH(mse_backward_stub);
 
 Tensor cosine_embedding_loss(const Tensor& input1, const Tensor& input2, const Tensor& target, double margin, int64_t reduction) {
+  auto targ_dim = target.dim();
   TORCH_CHECK(
-      target.dim() == 1,
-      "1D target tensor expected, multi-target not supported");
+      targ_dim == 1 || targ_dim == 0,
+      "0D or 1D target tensor expected, multi-target not supported");
+
+  if (targ_dim == 1) {
+    TORCH_CHECK(
+        input1.dim() == 2,
+        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  } else {
+    TORCH_CHECK(
+        input1.dim() == 1,
+        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
+        input1.sizes(),
+        " and ",
+        input2.sizes(),
+        ".");
+  }
 
-  auto prod_sum = (input1 * input2).sum(1);
-  auto mag_square1 = (input1 * input1).sum(1) + EPSILON;
-  auto mag_square2 = (input2 * input2).sum(1) + EPSILON;
+  auto prod_sum = (input1 * input2).sum(targ_dim);
+  auto mag_square1 = (input1 * input1).sum(targ_dim) + EPSILON;
+  auto mag_square2 = (input2 * input2).sum(targ_dim) + EPSILON;
   auto denom = (mag_square1 * mag_square2).sqrt_();
   auto cos = prod_sum / denom;
 
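With the relaxed check above, both the batched and the unbatched call shapes are accepted. A usage sketch (assuming the usual ATen headers and the generated at::cosine_embedding_loss wrapper):

// Illustrative only: 2D inputs with a 1D target were already supported;
// 1D inputs with a 0D target are what this change enables.
#include <ATen/ATen.h>

void cosine_embedding_loss_examples() {
  // Batched form: 2D inputs, 1D target of +1 / -1 labels.
  auto x1 = at::randn({4, 8});
  auto x2 = at::randn({4, 8});
  auto y = at::ones({4});
  auto batched = at::cosine_embedding_loss(x1, x2, y, /*margin=*/0.0, at::Reduction::Mean);

  // Unbatched form: 1D inputs, 0D target.
  auto a = at::randn({8});
  auto b = at::randn({8});
  auto t = at::scalar_tensor(1.0);
  auto single = at::cosine_embedding_loss(a, b, t, /*margin=*/0.0, at::Reduction::Mean);
  (void)batched;
  (void)single;
}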

aten/src/ATen/native/quantized/affine_quantizer.cpp

Lines changed: 6 additions & 6 deletions
@@ -107,7 +107,7 @@ Tensor& quantize_tensor_per_tensor_affine(
     Tensor& qtensor,
     double scale,
     int64_t zero_point) {
-  static const std::string fn_name = "quantize_tensor_per_tensor_affine";
+  static constexpr auto fn_name = "quantize_tensor_per_tensor_affine";
 
   checkRoundingMode(fn_name);
   checkFloatTensor(fn_name, rtensor);
@@ -138,7 +138,7 @@ Tensor& quantize_tensor_per_channel_affine(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "quantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "quantize_tensor_per_channel_affine";
 
   checkRoundingMode(fn_name);
   checkFloatTensor(fn_name, rtensor);
@@ -178,7 +178,7 @@ Tensor& quantize_tensor_per_channel_float_qparams(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name =
+  static constexpr auto fn_name =
       "quantize_tensor_per_channel_float_qparams";
 
   checkRoundingMode(fn_name);
@@ -216,7 +216,7 @@ Tensor& dequantize_tensor_per_tensor_affine(
     Tensor& rtensor,
     double scale,
     int64_t zero_point) {
-  static const std::string fn_name = "dequantize_tensor_per_tensor_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_tensor_affine";
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
   checkSameSize(fn_name, qtensor, rtensor);
@@ -243,7 +243,7 @@ Tensor& dequantize_tensor_per_channel_affine(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";
 
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
@@ -282,7 +282,7 @@ Tensor& dequantize_tensor_per_channel_float_qparams(
     Tensor scales,
     Tensor zero_points,
     int64_t axis) {
-  static const std::string fn_name = "dequantize_tensor_per_channel_affine";
+  static constexpr auto fn_name = "dequantize_tensor_per_channel_affine";
 
   checkFloatTensor(fn_name, rtensor);
   checkSameDevice(fn_name, rtensor, qtensor);
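The only change in this file is the type of the fn_name constants. A minimal side-by-side sketch of the two forms (the rationale is an assumption not stated in the diff: a constexpr character pointer needs no runtime construction or thread-safe initialization guard, unlike a function-local static std::string):

// Both forms as they appear in this file, isolated for comparison.
#include <string>

void fn_name_forms() {
  // Old: constructed on first call, behind a thread-safe initialization guard.
  static const std::string fn_name_old = "quantize_tensor_per_tensor_affine";
  // New: a constexpr const char*, fully resolved at compile time.
  static constexpr auto fn_name_new = "quantize_tensor_per_tensor_affine";
  (void)fn_name_old;
  (void)fn_name_new;
}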
