@@ -175,6 +175,8 @@ cdef class GEMMTermComputer{{name_suffix}}:
175175 ITYPE_t thread_num,
176176 ) nogil:
177177 cdef:
178+ const {{INPUT_DTYPE_t}}[:, ::1] X_c = self.X[X_start:X_end, :]
179+ const {{INPUT_DTYPE_t}}[:, ::1] Y_c = self.Y[Y_start:Y_end, :]
178180 DTYPE_t *dist_middle_terms = self.dist_middle_terms_chunks[thread_num].data()
179181
180182 # Careful: LDA, LDB and LDC are given for F-ordered arrays
@@ -185,9 +187,9 @@ cdef class GEMMTermComputer{{name_suffix}}:
185187 BLAS_Order order = RowMajor
186188 BLAS_Trans ta = NoTrans
187189 BLAS_Trans tb = Trans
188- ITYPE_t m = X_end - X_start
189- ITYPE_t n = Y_end - Y_start
190- ITYPE_t K = self.n_features
190+ ITYPE_t m = X_c.shape[0]
191+ ITYPE_t n = Y_c.shape[0]
192+ ITYPE_t K = X_c.shape[1]
191193 DTYPE_t alpha = - 2.
192194{{if upcast_to_float64}}
193195 DTYPE_t * A = self.X_c_upcast[thread_num].data()
@@ -196,15 +198,15 @@ cdef class GEMMTermComputer{{name_suffix}}:
196198 # Casting for A and B to remove the const is needed because APIs exposed via
197199 # scipy.linalg.cython_blas aren't reflecting the arguments' const qualifier.
198200 # See: https://github.com/scipy/scipy/issues/14262
199- DTYPE_t * A = <DTYPE_t *> &self.X[X_start , 0]
200- DTYPE_t * B = <DTYPE_t *> &self.Y[Y_start , 0]
201+ DTYPE_t * A = <DTYPE_t *> &X_c[0 , 0]
202+ DTYPE_t * B = <DTYPE_t *> &Y_c[0 , 0]
201203{{endif}}
202- ITYPE_t lda = self.n_features
203- ITYPE_t ldb = self.n_features
204+ ITYPE_t lda = X_c.shape[1]
205+ ITYPE_t ldb = X_c.shape[1]
204206 DTYPE_t beta = 0.
205- ITYPE_t ldc = Y_end - Y_start
207+ ITYPE_t ldc = Y_c.shape[0]
206208
207- # dist_middle_terms = `-2 * X[X_start:X_end] @ Y[Y_start:Y_end] .T`
209+ # dist_middle_terms = `-2 * X_c @ Y_c .T`
208210 _gemm(order, ta, tb, m, n, K, alpha, A, lda, B, ldb, beta, dist_middle_terms, ldc)
209211
210212 return dist_middle_terms
0 commit comments