@@ -33,54 +33,20 @@ def train_wrap(
3333 cdef model * model
3434 cdef char_const_ptr error_msg
3535 cdef int len_w
36-
37- # The implementation for float32 and float64 uses a single interface.
38- # This is done by accepting the data as a pointer to a buffer of bytes.
39- # In this regard, we define a pointer to pass the address of the first
40- # element of the buffer seen as raw bytes (hence the use of `char *`).
41- #
42- # We proceed in two steps using intermediate memory views to have Cython
43- # have sufficient typing information not to use PyObjects.
44- cdef cnp.float64_t[::1 ] X_data_64
45- cdef cnp.float32_t[::1 ] X_data_32
46- cdef char * X_data_as_bytes_ptr = NULL
47-
48- # The same is done for `indices` and `indptr` in the CSR case.
49- cdef cnp.int32_t[::1 ] X_indices
50- cdef char * X_indices_as_bytes_ptr = NULL
51-
52- cdef cnp.int32_t[::1 ] X_indptr
53- cdef char * X_indptr_as_bytes_ptr = NULL
54-
55- cdef bint X_stores_float64_data = X.dtype == np.float64
36+ cdef char * x_data_bytes_ptr
37+ cdef cnp.int32_t[::1 ] x_indices
38+ cdef cnp.int32_t[::1 ] x_indptr
39+ cdef bint x_has_type_float64 = X.dtype == np.float64
5640
5741 if is_sparse:
58- # X is a CSR matrix here, a format which stores the values
59- # as a contiguous buffer via a NumPy array in a `data` attribute.
60- # We get the address of the first element of the buffer which
61- # we reference using a pointer to bytes.
62- if X_stores_float64_data:
63- X_data_64 = X.data
64- X_data_as_bytes_ptr = < char * > & X_data_64[0 ]
65- else :
66- X_data_32 = X.data
67- X_data_as_bytes_ptr = < char * > & X_data_32[0 ]
68-
69- # Similar operations are to be performed for `indices` and `indptr`.
70- X_indices = X.indices
71- X_indices_as_bytes_ptr = < char * > & X_indices[0 ]
72-
73- X_indptr = X.indptr
74- X_indptr_as_bytes_ptr = < char * > & X_indptr[0 ]
75-
42+ x_data_bytes_ptr = _get_sparse_x_data_bytes(x = X, x_has_type_float64 = x_has_type_float64)
43+ x_indices = X.indices
44+ x_indptr = X.indptr
7645 problem = csr_set_problem(
77- # Underneath, the data will be statically re-interpreted as
78- # either float32 or float64 depending on the boolean passed as
79- # the second argument.
80- X_data_as_bytes_ptr,
81- X_stores_float64_data,
82- X_indices_as_bytes_ptr,
83- X_indptr_as_bytes_ptr,
46+ x_data_bytes_ptr,
47+ x_has_type_float64,
48+ < char * > & x_indices[0 ],
49+ < char * > & x_indptr[0 ],
8450 (< cnp.int32_t> X.shape[0 ]),
8551 (< cnp.int32_t> X.shape[1 ]),
8652 (< cnp.int32_t> X.nnz),
@@ -89,17 +55,9 @@ def train_wrap(
8955 < char * > & Y[0 ]
9056 )
9157 else :
92- # X simply is a 2D NumPy array in this case.
93- # This is reshapeable to a 1D NumPy array in O(1) (only strides are changed).
94- if X_stores_float64_data:
95- X_data_64 = X.reshape(- 1 )
96- X_data_as_bytes_ptr = < char * > & X_data_64[0 ]
97- else :
98- X_data_32 = X.reshape(- 1 )
99- X_data_as_bytes_ptr = < char * > & X_data_32[0 ]
100-
58+ x_data_bytes_ptr = _get_x_data_bytes(x = X, x_has_type_float64 = x_has_type_float64)
10159 problem = set_problem(
102- X_data_as_bytes_ptr ,
60+ x_data_bytes_ptr ,
10361 X.dtype == np.float64,
10462 (< cnp.int32_t> X.shape[0 ]),
10563 (< cnp.int32_t> X.shape[1 ]),
@@ -115,8 +73,8 @@ def train_wrap(
11573 eps,
11674 C,
11775 class_weight.shape[0 ],
118- < char * > & class_weight_label[0 ] if class_weight_label.size > 0 else NULL ,
119- < char * > & class_weight[0 ] if class_weight.size > 0 else NULL ,
76+ < char * > & class_weight_label[0 ] if class_weight_label.size > 0 else NULL ,
77+ < char * > & class_weight[0 ] if class_weight.size > 0 else NULL ,
12078 max_iter,
12179 random_seed,
12280 epsilon
@@ -168,6 +126,37 @@ def train_wrap(
168126 return w.base, n_iter.base
169127
170128
cdef char * _get_sparse_x_data_bytes(object x, bint x_has_type_float64):
    """Return the address of the first element of `x.data` as a byte pointer.

    `x.data` is bound to a contiguous 1-D typed memoryview (float64 when
    `x_has_type_float64` is true, float32 otherwise) and the address of its
    element 0 is returned reinterpreted as `char *`.

    The function does not keep any reference to the buffer: the returned
    pointer is only usable while the object behind `x.data` is kept alive
    by the caller.
    """
    cdef cnp.float64_t[::1] values_f64
    cdef cnp.float32_t[::1] values_f32

    # Handle the float32 case first and leave float64 as the fall-through.
    if not x_has_type_float64:
        values_f32 = x.data
        return <char *> &values_f32[0]

    values_f64 = x.data
    return <char *> &values_f64[0]
142+
143+
cdef char * _get_x_data_bytes(object x, bint x_has_type_float64):
    """Return the address of the first element of dense `x` as a byte pointer.

    `x` is flattened with `x.reshape(-1)`, bound to a contiguous 1-D typed
    memoryview (float64 when `x_has_type_float64` is true, float32
    otherwise) and the address of element 0 is returned as `char *`.

    The returned pointer must refer to memory owned by `x` itself, because
    every object created in this function is released when it returns.
    `reshape(-1)` only yields a view onto `x`'s buffer when `x` is
    C-contiguous; otherwise it allocates a copy that would be freed on
    return, leaving the caller with a dangling pointer. Non-C-contiguous
    input is therefore rejected with a `ValueError`.

    The caller must keep `x` alive for as long as the pointer is used.
    """
    cdef cnp.float64_t[::1] x_data_64
    cdef cnp.float32_t[::1] x_data_32
    cdef char * x_data_bytes_ptr

    # Guard against the silent-copy case described in the docstring.
    # NOTE(review): this function declares no explicit exception value;
    # confirm that the Cython version in use propagates the error to the
    # caller rather than reporting it as unraisable.
    if not x.flags.c_contiguous:
        raise ValueError(
            "x must be C-contiguous so that its data can be referenced "
            "without copying"
        )

    # For a C-contiguous array this is a view sharing x's buffer.
    x_as_1d_array = x.reshape(-1)
    if x_has_type_float64:
        x_data_64 = x_as_1d_array
        x_data_bytes_ptr = <char *> &x_data_64[0]
    else:
        x_data_32 = x_as_1d_array
        x_data_bytes_ptr = <char *> &x_data_32[0]

    return x_data_bytes_ptr
158+
159+
171160def set_verbosity_wrap (int verbosity ):
172161 """
173162 Control verbosity of libsvm library
0 commit comments