@@ -622,11 +622,13 @@ cdef class GEMMTermComputer{{bitness}}:
622622 ITYPE_t chunks_n_threads
623623 ITYPE_t dist_middle_terms_chunks_size
624624 ITYPE_t n_features
625+ ITYPE_t chunk_size
625626
626627 # Buffers for the `-2 * X_c @ Y_c.T` term computed via GEMM
627628 vector[vector[DTYPE_t]] dist_middle_terms_chunks
628629
629630{{if need_upcast}}
631+ # Buffers for upcasting chunks of X and Y from 32bit to 64bit
630632 vector[vector[DTYPE_t]] X_c_upcast
631633 vector[vector[DTYPE_t]] Y_c_upcast
632634{{endif}}
@@ -638,24 +640,28 @@ cdef class GEMMTermComputer{{bitness}}:
638640 ITYPE_t chunks_n_threads,
639641 ITYPE_t dist_middle_terms_chunks_size,
640642 ITYPE_t n_features,
643+ ITYPE_t chunk_size,
641644 ):
642645 self.X = X
643646 self.Y = Y
644647 self.effective_n_threads = effective_n_threads
645648 self.chunks_n_threads = chunks_n_threads
646649 self.dist_middle_terms_chunks_size = dist_middle_terms_chunks_size
647650 self.n_features = n_features
651+ self.chunk_size = chunk_size
648652
649653 self.dist_middle_terms_chunks = vector[vector[DTYPE_t]](self.effective_n_threads)
650654
651655{{if need_upcast}}
656+ # We populate the buffer for upcasting chunks of X and Y from 32bit to 64bit.
652657 self.X_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
653658 self.Y_c_upcast = vector[vector[DTYPE_t]](self.effective_n_threads)
654659
655- # Buffers for upcasting chunks of X and Y from 32bit to 64bit.
660+ upcast_buffer_n_elements = self.chunk_size * n_features
661+
656662 for thread_num in range(self.effective_n_threads):
657- self.X_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size )
658- self.Y_c_upcast[thread_num].resize(self.dist_middle_terms_chunks_size )
663+ self.X_c_upcast[thread_num].resize(upcast_buffer_n_elements )
664+ self.Y_c_upcast[thread_num].resize(upcast_buffer_n_elements )
659665{{endif}}
660666
661667
@@ -1556,7 +1562,8 @@ cdef class FastEuclideanPairwiseDistancesArgKmin{{bitness}}(PairwiseDistancesArg
15561562 self.effective_n_threads,
15571563 self.chunks_n_threads,
15581564 dist_middle_terms_chunks_size,
1559- n_features=datasets_pair.X.shape[1]
1565+ n_features=datasets_pair.X.shape[1],
1566+ chunk_size=self.chunk_size,
15601567 )
15611568
15621569 if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
@@ -2171,7 +2178,8 @@ cdef class FastEuclideanPairwiseDistancesRadiusNeighborhood{{bitness}}(PairwiseD
21712178 self.effective_n_threads,
21722179 self.chunks_n_threads,
21732180 dist_middle_terms_chunks_size,
2174- n_features=datasets_pair.X.shape[1]
2181+ n_features=datasets_pair.X.shape[1],
2182+ chunk_size=self.chunk_size,
21752183 )
21762184
21772185 if metric_kwargs is not None and "Y_norm_squared" in metric_kwargs:
0 commit comments