@@ -80,9 +80,13 @@ inline constexpr bool should_include_kernel_dtype(
8080#define C10_UNUSED_DISPATCH_CUDA_WORKAROUND C10_UNUSED
8181#endif // defined(__CUDACC__) && CUDA_VERSION <= 10100
8282
83+ #if defined __cpp_if_constexpr
8384#define AT_QINT_PRIVATE_CASE_TYPE ( \
84- enum_type, type, underlying_enum, underlying_type, ...) \
85+ NAME, enum_type, type, underlying_enum, underlying_type, ...) \
8586 case enum_type: { \
87+ if constexpr (!at::should_include_kernel_dtype (NAME, enum_type)) { \
88+ AT_ERROR (" dtype '" , toString (enum_type), " ' not selected for kernel tag " , #NAME); \
89+ } \
8690 using scalar_t = type; \
8791 using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
8892 scalar_t ::underlying; \
@@ -93,10 +97,57 @@ inline constexpr bool should_include_kernel_dtype(
9397 /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
9498 return __VA_ARGS__ (); \
9599 }
100+ #else
101+ #define AT_QINT_PRIVATE_CASE_TYPE ( \
102+ NAME, enum_type, type, underlying_enum, underlying_type, ...) \
103+ case enum_type: { \
104+ at::guts::if_constexpr<(!at::should_include_kernel_dtype (NAME, enum_type))>( \
105+ [] { \
106+ AT_ERROR (" dtype '" #enum_type " ' not selected for kernel tag " #NAME); \
107+ } \
108+ ); \
109+ using scalar_t = type; \
110+ using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
111+ scalar_t ::underlying; \
112+ const auto & SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
113+ const auto & UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
114+ toUnderlying (enum_type); \
115+ (void )SCALAR_TYPE; /* Suppress unused-var compiler warning */ \
116+ /* TODO: Use [[maybe-unused]] when C++17 becomes the standard */ \
117+ return __VA_ARGS__ (); \
118+ }
119+ #endif
96120
121+ #if defined __cpp_if_constexpr
122+ #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
123+ NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
124+ case enum_type: { \
125+ if constexpr (!at::should_include_kernel_dtype (NAME, enum_type)) { \
126+ AT_ERROR (" dtype '" , toString (enum_type), " ' not selected for kernel tag " , #NAME); \
127+ } \
128+ using scalar_t = type; \
129+ using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
130+ scalar_t ::underlying; \
131+ const auto & SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
132+ const auto & UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
133+ toUnderlying (enum_type); \
134+ int bit_width = bitwidth; \
135+ int64_t quant_min = qmin; \
136+ int64_t quant_max = qmax; \
137+ (void )bit_width; /* Suppress unused variable warning */ \
138+ (void )quant_min; /* Suppress unused variable warning */ \
139+ (void )quant_max; /* Suppress unused variable warning */ \
140+ return __VA_ARGS__ (); \
141+ }
142+ #else
97143#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
98- enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
144+ NAME, enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
99145 case enum_type: { \
146+ at::guts::if_constexpr<(!at::should_include_kernel_dtype (NAME, enum_type))>( \
147+ [] { \
148+ AT_ERROR (" dtype '" #enum_type " ' not selected for kernel tag " #NAME); \
149+ } \
150+ ); \
100151 using scalar_t = type; \
101152 using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
102153 scalar_t ::underlying; \
@@ -111,6 +162,7 @@ inline constexpr bool should_include_kernel_dtype(
111162 (void )quant_max; /* Suppress unused variable warning */ \
112163 return __VA_ARGS__ (); \
113164 }
165+ #endif
114166
115167namespace detail {
116168
@@ -449,11 +501,11 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
449501 RECORD_KERNEL_FUNCTION_DTYPE (NAME, _st); \
450502 switch (_st) { \
451503 AT_QINT_PRIVATE_CASE_TYPE ( \
452- at::kQInt8 , at::qint8, at::kChar , int8_t , __VA_ARGS__) \
504+ NAME, at::kQInt8 , at::qint8, at::kChar , int8_t , __VA_ARGS__) \
453505 AT_QINT_PRIVATE_CASE_TYPE ( \
454- at::kQUInt8 , at::quint8, at::kByte , uint8_t , __VA_ARGS__) \
506+ NAME, at::kQUInt8 , at::quint8, at::kByte , uint8_t , __VA_ARGS__) \
455507 AT_QINT_PRIVATE_CASE_TYPE ( \
456- at::kQInt32 , at::qint32, at::kInt , int , __VA_ARGS__) \
508+ NAME, at::kQInt32 , at::qint32, at::kInt , int , __VA_ARGS__) \
457509 default : \
458510 AT_ERROR (#NAME, " not implemented for '" , toString (TYPE), " '" ); \
459511 } \
@@ -467,13 +519,13 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
467519 RECORD_KERNEL_FUNCTION_DTYPE (NAME, _st); \
468520 switch (_st) { \
469521 AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
470- at::kQInt8 , at::qint8, int8_t , CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
522+ NAME, at::kQInt8 , at::qint8, int8_t , CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
471523 AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
472- at::kQUInt8 , at::quint8, uint8_t , CHAR_BIT, 0 , UCHAR_MAX, __VA_ARGS__) \
524+ NAME, at::kQUInt8 , at::quint8, uint8_t , CHAR_BIT, 0 , UCHAR_MAX, __VA_ARGS__) \
473525 AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
474- at::kQInt32 , at::qint32, int , CHAR_BIT * sizeof (int ), INT_MIN, INT_MAX, __VA_ARGS__) \
526+ NAME, at::kQInt32 , at::qint32, int , CHAR_BIT * sizeof (int ), INT_MIN, INT_MAX, __VA_ARGS__) \
475527 AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE ( \
476- at::kQUInt4x2 , at::quint4x2, uint8_t , 4 , 0 , 15 , __VA_ARGS__) \
528+ NAME, at::kQUInt4x2 , at::quint4x2, uint8_t , 4 , 0 , 15 , __VA_ARGS__) \
477529 default : \
478530 AT_ERROR (#NAME, " not implemented for '" , toString (TYPE), " '" ); \
479531 } \
0 commit comments