@@ -12,8 +12,11 @@ from libc.math cimport fabs, sqrt
1212cimport numpy as cnp
1313import numpy as np
1414from cython cimport floating
15+ from cython.parallel cimport prange
1516from numpy.math cimport isnan
1617
18+ from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
19+
1720cnp.import_array()
1821
1922ctypedef fused integral:
@@ -27,13 +30,14 @@ def csr_row_norms(X):
2730 """ Squared L2 norm of each row in CSR matrix X."""
2831 if X.dtype not in [np.float32, np.float64]:
2932 X = X.astype(np.float64)
30- return _csr_row_norms(X.data, X.indices, X.indptr)
33+ n_threads = _openmp_effective_n_threads()
34+ return _sqeuclidean_row_norms_sparse(X.data, X.indptr, n_threads)
3135
3236
33- def _csr_row_norms (
37+ def _sqeuclidean_row_norms_sparse (
3438 const floating[::1] X_data ,
35- const integral[::1] X_indices ,
3639 const integral[::1] X_indptr ,
40+ int n_threads ,
3741):
3842 cdef:
3943 integral n_samples = X_indptr.shape[0 ] - 1
@@ -42,14 +46,13 @@ def _csr_row_norms(
4246
4347 dtype = np.float32 if floating is float else np.float64
4448
45- cdef floating[::1 ] norms = np.zeros(n_samples, dtype = dtype)
49+ cdef floating[::1 ] squared_row_norms = np.zeros(n_samples, dtype = dtype)
4650
47- with nogil:
48- for i in range (n_samples):
49- for j in range (X_indptr[i], X_indptr[i + 1 ]):
50- norms[i] += X_data[j] * X_data[j]
51+ for i in prange(n_samples, schedule = ' static' , nogil = True , num_threads = n_threads):
52+ for j in range (X_indptr[i], X_indptr[i + 1 ]):
53+ squared_row_norms[i] += X_data[j] * X_data[j]
5154
52- return np.asarray(norms )
55+ return np.asarray(squared_row_norms )
5356
5457
5558def csr_mean_variance_axis0 (X , weights = None , return_sum_weights = False ):
0 commit comments