Skip to content

Commit 7243a5d

Browse files
Implemented group level static quantization for s8s8s32of32|bf16 APIs
Details: - Group quantization is a technique to improve accuracy in which the scale factors used to quantize inputs and weights vary at the group level instead of at the per-channel or per-tensor level. - Added new bench files to test GEMM with symmetric static quantization. - Added new get_size and reorder functions to account for storing the sum of column values separately per group. - Added a new framework and kernels to support the same. - The scale factors can be of type float or bf16. AMD-Internal:[SWLCSG-3274] Change-Id: I3e69ecd56faa2679a4f084031d35ffb76556230f
1 parent 9977055 commit 7243a5d

25 files changed

+41770
-52
lines changed

addon/aocl_gemm/aocl_gemm_interface_apis.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(s8s8s32os32);
5656
AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s4s32os32);
5757
AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16s4f32of32);
5858

59+
// Declares aocl_get_reorder_buf_size_<LP_SFX>() for the symmetric static
// quantization APIs.
// Returns the size of buffer in bytes required for the reordered matrix.
// The generated prototype takes order, trans, mat_type, k and n, plus an
// AOCL_SYMM_STAT_QUANT* meta_data argument, and returns a siz_t.
// NOTE(review): meta_data presumably supplies the quantization group size
// used when sizing the buffer -- confirm against the implementation.
60+
#define AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(LP_SFX) \
61+
BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \
62+
( \
63+
const char order, \
64+
const char trans, \
65+
const char mat_type, \
66+
const dim_t k, \
67+
const dim_t n, \
68+
AOCL_SYMM_STAT_QUANT* meta_data \
69+
) \
70+
71+
AOCL_GEMM_GET_REORDER_BUF_SIZE_SYM_QUANT(s8s8s32os32_sym_quant);
72+
5973
// Performs reordering of input matrix. Reordering is the process of packing
6074
// the entire matrix upfront, so that the benefits of packed matrix is obtained
6175
// without incurring the packing costs during matmul computation.
@@ -80,6 +94,22 @@ AOCL_GEMM_REORDER(int8_t,s8s8s32os32);
8094
AOCL_GEMM_REORDER(int8_t,u8s4s32os32);
8195
AOCL_GEMM_REORDER(int8_t, bf16s4f32of32);
8296

97+
// Declares aocl_reorder_<LP_SFX>() for the symmetric static quantization
// APIs. The generated prototype returns void and takes order, trans and
// mat_type selectors, the source matrix (input_buf_addr, const B_type*),
// the destination buffer (reorder_buf_addr, B_type*), the dimensions k and
// n, the leading dimension ldb, and an AOCL_SYMM_STAT_QUANT* meta_data
// argument.
// NOTE(review): reorder_buf_addr is presumably expected to be at least as
// large as the value reported by the matching
// aocl_get_reorder_buf_size_<LP_SFX>() call -- confirm.
#define AOCL_GEMM_REORDER_SYM_QUANT(B_type,LP_SFX) \
98+
BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
99+
( \
100+
const char order, \
101+
const char trans, \
102+
const char mat_type, \
103+
const B_type* input_buf_addr, \
104+
B_type* reorder_buf_addr, \
105+
const dim_t k, \
106+
const dim_t n, \
107+
const dim_t ldb, \
108+
AOCL_SYMM_STAT_QUANT* meta_data \
109+
) \
110+
111+
AOCL_GEMM_REORDER_SYM_QUANT(int8_t,s8s8s32os32_sym_quant);
112+
83113
#define AOCL_GEMM_REORDER_MXP(A_type,B_type,LP_SFX) \
84114
BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \
85115
( \
@@ -145,6 +175,10 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16);
145175
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32);
146176
AOCL_GEMM_MATMUL(int8_t,int8_t,uint8_t,int32_t,s8s8s32ou8);
147177

178+
// Symmetric static quantization GEMM API
179+
AOCL_GEMM_MATMUL(int8_t,int8_t,float,int32_t,s8s8s32of32_sym_quant);
180+
AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16_sym_quant);
181+
148182
AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16);
149183
AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32);
150184
AOCL_GEMM_MATMUL(bfloat16, int8_t, float, float, bf16s4f32of32);

addon/aocl_gemm/aocl_gemm_post_ops.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,16 @@ typedef struct
125125
{
126126
void* zero_point;
127127
//len should be one which is one or n i.e., one zp
128-
//per tensor or one zp per channel respectively
128+
//per tensor or one zp per channel respectively
129129
dim_t zero_point_len;
130+
AOCL_PARAMS_STORAGE_TYPES zero_point_type;
130131
} aocl_pre_op_zp;
131132

132133
typedef struct
133134
{
134135
void* scale_factor;
135136
//len should be one which is one or n i.e., one sf
136-
//per tensor or one sf per channel respectively
137+
//per tensor or one sf per channel respectively
137138
dim_t scale_factor_len;
138139
AOCL_PARAMS_STORAGE_TYPES scale_factor_type;
139140
} aocl_pre_op_sf;
@@ -146,6 +147,21 @@ typedef struct
146147
dim_t group_size;
147148
} aocl_pre_op;
148149

150+
typedef struct
151+
{
152+
dim_t group_size;
153+
dim_t seq_length;
154+
aocl_pre_op_sf *a_scl;
155+
aocl_pre_op_sf *b_scl;
156+
aocl_pre_op_zp *a_zp;
157+
aocl_pre_op_zp *b_zp;
158+
} aocl_group_post_op;
159+
160+
typedef struct
161+
{
162+
dim_t group_size;
163+
} AOCL_SYMM_STAT_QUANT;
164+
149165
typedef struct
150166
{
151167
aocl_post_op_sum* sum; // Multiple scale/sum allowed.
@@ -164,6 +180,7 @@ typedef struct
164180
//Pass pre-op structure also through post-ops
165181
aocl_pre_op *pre_ops;
166182

183+
aocl_group_post_op *post_op_grp;
167184
// To keep track of eltwise operations.
168185
dim_t num_eltwise;
169186

addon/aocl_gemm/aocl_gemm_s8s8s32obf16.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,4 @@ AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16)
269269
err_hndl:;
270270
LPGEMM_STOP_LOGGER();
271271
}
272+
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
3+
BLIS
4+
An object-based framework for developing high-performance BLAS-like
5+
libraries.
6+
7+
Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved.
8+
9+
Redistribution and use in source and binary forms, with or without
10+
modification, are permitted provided that the following conditions are
11+
met:
12+
- Redistributions of source code must retain the above copyright
13+
notice, this list of conditions and the following disclaimer.
14+
- Redistributions in binary form must reproduce the above copyright
15+
notice, this list of conditions and the following disclaimer in the
16+
documentation and/or other materials provided with the distribution.
17+
- Neither the name(s) of the copyright holder(s) nor the names of its
18+
contributors may be used to endorse or promote products derived
19+
from this software without specific prior written permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22+
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23+
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24+
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32+
33+
*/
34+
35+
#include "blis.h"
36+
#include "aocl_gemm_interface_apis.h"
37+
#include "aocl_gemm_check.h"
38+
#include "lpgemm_types.h"
39+
#include "lpgemm_post_ops.h"
40+
#include "lpgemm_thread_decor_openmp.h"
41+
#include "lpgemm_5loop_interface_apis.h"
42+
#include "lpgemm_config.h"
43+
#include "lpgemm_utils_s8.h"
44+
#include "lpgemm_logger.h"
45+
46+
AOCL_GEMM_MATMUL(int8_t,int8_t,bfloat16,int32_t,s8s8s32obf16_sym_quant)
47+
{
48+
LPGEMM_START_LOGGER();
49+
LPGEMM_WRITE_LOGGER \
50+
(
51+
"s8s8s32obf16_sym_quant", \
52+
order, transa, transb, \
53+
m, n, k, \
54+
( ( float ) alpha ), \
55+
lda, mem_format_a, \
56+
ldb, mem_format_b, \
57+
( ( float ) beta ), \
58+
ldc, post_op_unparsed \
59+
);
60+
61+
trans_t blis_transa;
62+
trans_t blis_transb;
63+
64+
// Check if avx512_vnni ISA is supported, lpgemm matmul only works with it.
65+
if ( bli_cpuid_is_avx512vnni_supported() == FALSE )
66+
{
67+
bli_print_msg(" AVX512_VNNI ISA not supported by processor, "
68+
"cannot perform s8s8s32 gemm.", __FILE__, __LINE__ );
69+
goto err_hndl;
70+
}
71+
72+
/* Initialize BLIS. */
73+
bli_init_auto();
74+
75+
// Set MC, NC, KC, NR, MR.
76+
aocl_lpgemm_init_global_cntx();
77+
78+
// check for validity of params.
79+
int err_no = 0;
80+
AOCL_GEMM_CHECK
81+
(
82+
"s8s8s32obf16_sym_quant",
83+
order, transa, transb,
84+
m, n, k,
85+
a, lda, mem_format_a,
86+
b, ldb, mem_format_b,
87+
c, ldc,
88+
err_no
89+
);
90+
if ( err_no != 0 )
91+
{
92+
goto err_hndl;
93+
}
94+
95+
/* Map BLAS chars to their corresponding BLIS enumerated type value. */
96+
bli_param_map_netlib_to_blis_trans( transa, &blis_transa );
97+
bli_param_map_netlib_to_blis_trans( transb, &blis_transb );
98+
99+
bool is_row_major = ((order == 'r') || (order == 'R'));
100+
bool is_column_major = ((order == 'c') || (order == 'C'));
101+
102+
// Column major support disabled for int API's till micro-kernel
103+
// post-ops are updated to account for column major.
104+
if ( (is_column_major == TRUE) && (post_op_unparsed != NULL) )
105+
{
106+
bli_print_msg("Column major inputs not supported with Post-ops.",
107+
__FILE__, __LINE__);
108+
goto err_hndl;
109+
}
110+
111+
// The strides are set assuming a row major kernel.
112+
inc_t rs_a = lda;
113+
inc_t cs_a = 1;
114+
115+
if (bli_is_trans(blis_transa))
116+
{
117+
rs_a = 1;
118+
cs_a = lda;
119+
}
120+
121+
inc_t rs_b = ldb;
122+
inc_t cs_b = 1;
123+
124+
if (bli_is_trans(blis_transb))
125+
{
126+
rs_b = 1;
127+
cs_b = ldb;
128+
}
129+
const inc_t rs_c = ldc;
130+
const inc_t cs_c = 1;
131+
132+
AOCL_MEMORY_TAG mtag_a;
133+
AOCL_MEMORY_TAG mtag_b;
134+
135+
bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a);
136+
bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b);
137+
138+
// Reorder is not supported for A matrix
139+
if ((is_row_major == TRUE) && (mtag_a == REORDERED))
140+
{
141+
bli_print_msg(" Reordering of A matrix is not supported in "
142+
" row major case.", __FILE__, __LINE__);
143+
goto err_hndl;
144+
}
145+
// Inputs swapped in column major, A becomes B from kernel point of view.
146+
// Reorder is not supported for column major matrices.
147+
else if ((is_column_major == TRUE) &&
148+
((mtag_b == REORDERED) || (mtag_a == REORDERED)))
149+
{
150+
bli_print_msg(" Reordering of column major matrices is "
151+
" not supported.", __FILE__, __LINE__);
152+
goto err_hndl;
153+
}
154+
155+
// From 5-loop function point of view
156+
// B matrix needs to be packed in a certain format in order to be loaded
157+
// and used in bf16 instrution. As such the mtag_b always needs to be either
158+
// packed or reordered. B matrix as it is (unpacked) cannot be used, and
159+
// the mtag_b is set to packed to enable runtime packing.
160+
if ((is_row_major == TRUE) && (mtag_b == UNPACKED))
161+
{
162+
mtag_b = PACK;
163+
}
164+
// Inputs swapped in column major, A becomes B from kernel point of view.
165+
else if ((is_column_major == TRUE) && (mtag_a == UNPACKED))
166+
{
167+
mtag_a = PACK;
168+
}
169+
170+
// From 5-loop function point of view,
171+
// A matrix when in column major storage needs to be packed to row-major
172+
// storage as kernel expects A matrix to be in row-major format.
173+
if ((is_row_major == TRUE) && (bli_is_trans(blis_transa)))
174+
{
175+
mtag_a = PACK;
176+
}
177+
// Inputs swapped in column major, A becomes B from kernel point of view.
178+
else if ((is_column_major == TRUE) && (bli_is_trans(blis_transb)))
179+
{
180+
mtag_b = PACK;
181+
}
182+
183+
// convert group-level post-op struct to linked list format.
184+
lpgemm_group_post_op grp_post_op_list[AOCL_MAX_POST_OPS];
185+
err_t err = lpgemm_translate_to_group_postops_list
186+
(
187+
post_op_unparsed->post_op_grp, grp_post_op_list,
188+
m, n, k
189+
);
190+
191+
if( err != BLIS_SUCCESS )
192+
{
193+
goto err_hndl;
194+
}
195+
196+
// Convert post op struct to post op linked list format.
197+
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS];
198+
err = lpgemm_translate_to_post_ops_list
199+
(
200+
post_op_unparsed, post_op_list,
201+
( void* )c, ( void* )( &order ),
202+
m, n
203+
);
204+
205+
if( err != BLIS_SUCCESS )
206+
{
207+
goto err_hndl;
208+
}
209+
210+
// Initialize a local runtime with global settings if necessary. Note
211+
// that in the case that a runtime is passed in, we make a local copy.
212+
rntm_t rntm_g;
213+
bli_rntm_init_from_global( &rntm_g );
214+
bli_pba_rntm_set_pba( &rntm_g );
215+
216+
lpgemm_cntx_t* lcntx_g = lpgemm_get_global_cntx_obj( S8S8S32OS32 );
217+
218+
#ifdef BLIS_ENABLE_OPENMP
219+
// Swapping inputs to induce row major computation for column major inputs.
220+
if (is_column_major == TRUE)
221+
{
222+
lpgemm_s8s8s32o32_sym_quant_openmp_thread_decorator
223+
(
224+
n, m, k,
225+
b, rs_b, cs_b, mtag_b,
226+
a, rs_a, cs_a, mtag_a,
227+
(float *)c, rs_c, cs_c,
228+
alpha, beta,
229+
&rntm_g, lcntx_g, grp_post_op_list,
230+
post_op_list, BF16
231+
);
232+
}
233+
else
234+
{
235+
lpgemm_s8s8s32o32_sym_quant_openmp_thread_decorator
236+
(
237+
m, n, k,
238+
a, rs_a, cs_a, mtag_a,
239+
b, rs_b, cs_b, mtag_b,
240+
(float *)c, rs_c, cs_c,
241+
alpha, beta,
242+
&rntm_g, lcntx_g, grp_post_op_list,
243+
post_op_list, BF16
244+
);
245+
}
246+
#else
247+
// Swapping inputs to induce row major computation for column major inputs.
248+
if (is_column_major == TRUE)
249+
{
250+
lpgemm_s8s8s32o32_sym_quant_thread_decorator
251+
(
252+
n, m, k,
253+
b, rs_b, cs_b, mtag_b,
254+
a, rs_a, cs_a, mtag_a,
255+
(float *)c, rs_c, cs_c,
256+
alpha, beta,
257+
&rntm_g, lcntx_g, grp_post_op_list,
258+
post_op_list, BF16);
259+
}
260+
else
261+
{
262+
lpgemm_s8s8s32o32_sym_quant_thread_decorator
263+
(
264+
m, n, k,
265+
a, rs_a, cs_a, mtag_a,
266+
b, rs_b, cs_b, mtag_b,
267+
(float *)c, rs_c, cs_c,
268+
alpha, beta,
269+
&rntm_g, lcntx_g, grp_post_op_list,
270+
post_op_list, BF16
271+
);
272+
}
273+
#endif
274+
275+
err_hndl:;
276+
LPGEMM_STOP_LOGGER();
277+
}
278+

0 commit comments

Comments
 (0)