@@ -111,6 +111,74 @@ gemv(int typenum, enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE trans,
111111}
112112
113113
114+ /*
115+ * Helper: dispatch to appropriate cblas_?syrk for typenum.
116+ */
117+ static void
118+ syrk (int typenum , enum CBLAS_ORDER order , enum CBLAS_TRANSPOSE trans ,
119+ int n , int k ,
120+ PyArrayObject * A , int lda , PyArrayObject * R )
121+ {
122+ const void * Adata = PyArray_DATA (A );
123+ void * Rdata = PyArray_DATA (R );
124+ int ldc = PyArray_DIM (R , 1 ) > 1 ? PyArray_DIM (R , 1 ) : 1 ;
125+
126+ npy_intp i ;
127+ npy_intp j ;
128+
129+ switch (typenum ) {
130+ case NPY_DOUBLE :
131+ cblas_dsyrk (order , CblasUpper , trans , n , k , 1. ,
132+ Adata , lda , 0. , Rdata , ldc );
133+
134+ for (i = 0 ; i < n ; i ++ )
135+ {
136+ for (j = i + 1 ; j < n ; j ++ )
137+ {
138+ * ((npy_double * )PyArray_GETPTR2 (R , j , i )) = * ((npy_double * )PyArray_GETPTR2 (R , i , j ));
139+ }
140+ }
141+ break ;
142+ case NPY_FLOAT :
143+ cblas_ssyrk (order , CblasUpper , trans , n , k , 1.f ,
144+ Adata , lda , 0.f , Rdata , ldc );
145+
146+ for (i = 0 ; i < n ; i ++ )
147+ {
148+ for (j = i + 1 ; j < n ; j ++ )
149+ {
150+ * ((npy_float * )PyArray_GETPTR2 (R , j , i )) = * ((npy_float * )PyArray_GETPTR2 (R , i , j ));
151+ }
152+ }
153+ break ;
154+ case NPY_CDOUBLE :
155+ cblas_zsyrk (order , CblasUpper , trans , n , k , oneD ,
156+ Adata , lda , zeroD , Rdata , ldc );
157+
158+ for (i = 0 ; i < n ; i ++ )
159+ {
160+ for (j = i + 1 ; j < n ; j ++ )
161+ {
162+ * ((npy_cdouble * )PyArray_GETPTR2 (R , j , i )) = * ((npy_cdouble * )PyArray_GETPTR2 (R , i , j ));
163+ }
164+ }
165+ break ;
166+ case NPY_CFLOAT :
167+ cblas_csyrk (order , CblasUpper , trans , n , k , oneF ,
168+ Adata , lda , zeroF , Rdata , ldc );
169+
170+ for (i = 0 ; i < n ; i ++ )
171+ {
172+ for (j = i + 1 ; j < n ; j ++ )
173+ {
174+ * ((npy_cfloat * )PyArray_GETPTR2 (R , j , i )) = * ((npy_cfloat * )PyArray_GETPTR2 (R , i , j ));
175+ }
176+ }
177+ break ;
178+ }
179+ }
180+
181+
/*
 * Four shape categories: _scalar, _column, _row and _matrix.
 * NOTE(review): presumably used to classify the operands of the matrix
 * product routines in this file -- confirm against the callers.
 */
114182typedef enum {_scalar , _column , _row , _matrix } MatrixShape ;
115183
116184
@@ -647,7 +715,34 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
647715 Trans2 = CblasTrans ;
648716 ldb = (PyArray_DIM (ap2 , 0 ) > 1 ? PyArray_DIM (ap2 , 0 ) : 1 );
649717 }
650- gemm (typenum , Order , Trans1 , Trans2 , L , N , M , ap1 , lda , ap2 , ldb , ret );
718+
719+ /*
720+ * Use syrk if we have a case of a matrix times its transpose.
721+ * Otherwise, use gemm for all other cases.
722+ */
723+ if (
724+ (PyArray_BYTES (ap1 ) == PyArray_BYTES (ap2 )) &&
725+ (PyArray_DIM (ap1 , 0 ) == PyArray_DIM (ap2 , 1 )) &&
726+ (PyArray_DIM (ap1 , 1 ) == PyArray_DIM (ap2 , 0 )) &&
727+ (PyArray_STRIDE (ap1 , 0 ) == PyArray_STRIDE (ap2 , 1 )) &&
728+ (PyArray_STRIDE (ap1 , 1 ) == PyArray_STRIDE (ap2 , 0 )) &&
729+ ((Trans1 == CblasTrans ) ^ (Trans2 == CblasTrans )) &&
730+ ((Trans1 == CblasNoTrans ) ^ (Trans2 == CblasNoTrans ))
731+ )
732+ {
733+ if (Trans1 == CblasNoTrans )
734+ {
735+ syrk (typenum , Order , Trans1 , N , M , ap1 , lda , ret );
736+ }
737+ else
738+ {
739+ syrk (typenum , Order , Trans1 , N , M , ap2 , ldb , ret );
740+ }
741+ }
742+ else
743+ {
744+ gemm (typenum , Order , Trans1 , Trans2 , L , N , M , ap1 , lda , ap2 , ldb , ret );
745+ }
651746 NPY_END_ALLOW_THREADS ;
652747 }
653748
0 commit comments