Skip to content

Commit 8a720ed

Browse files
committed
Demonstrate pickling is not an issue anymore
1 parent 9fccfd4 commit 8a720ed

File tree

5 files changed

+22
-10
lines changed

5 files changed

+22
-10
lines changed

sklearn/tree/_oblique_splitter.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ from libc.stdio cimport printf
2020
from libcpp.vector cimport vector
2121

2222
from cython.operator cimport dereference as deref
23-
from cython.parallel import prange
2423

2524
from ._utils cimport log
2625
from ._utils cimport rand_int
@@ -55,7 +54,7 @@ cdef class ObliqueSplitter(Splitter):
5554

5655
def __cinit__(self, Criterion criterion, SIZE_t max_features,
5756
SIZE_t min_samples_leaf, double min_weight_leaf,
58-
double feature_combinations, object random_state):
57+
double feature_combinations, object random_state, *argv):
5958
"""
6059
Parameters
6160
----------

sklearn/tree/_oblique_tree.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ cdef class ObliqueTree(Tree):
3131
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, Node *node, SIZE_t node_id) nogil
3232
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1
3333

34+
cpdef DTYPE_t compute_feature_value(self, object X, SIZE_t node_id)
3435
cpdef np.ndarray get_projection_matrix(self)

sklearn/tree/_oblique_tree.pyx

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,12 @@ cdef class ObliqueTree(Tree):
244244
self.proj_vec_weights[node_id] = deref(deref(oblique_split_node).proj_vec_weights)
245245
self.proj_vec_indices[node_id] = deref(deref(oblique_split_node).proj_vec_indices)
246246
return 1
247+
248+
cpdef DTYPE_t compute_feature_value(self, object X, SIZE_t node_id):
249+
cdef const DTYPE_t[:] X_vector = X
250+
cdef Node* node = &self.nodes[node_id]
251+
feature_value = self._compute_feature(X_vector, node, node_id)
252+
return feature_value
247253

248254
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, Node *node, SIZE_t node_id) nogil:
249255
"""Compute feature from a given data matrix, X.
@@ -254,11 +260,19 @@ cdef class ObliqueTree(Tree):
254260
cdef vector[DTYPE_t] proj_vec_weights
255261
cdef vector[SIZE_t] proj_vec_indices
256262
cdef DTYPE_t proj_feat = 0.0
263+
cdef DTYPE_t weight = 0.0
264+
cdef SIZE_t j = 0
265+
cdef SIZE_t n_projections = proj_vec_indices.size()
257266

258267
# compute projection of the data based on trained tree
259268
proj_vec_weights = self.proj_vec_weights[node_id]
260269
proj_vec_indices = self.proj_vec_indices[node_id]
261-
for j in range(proj_vec_indices.size()):
262-
proj_feat += X_ndarray[proj_vec_indices[j]] * proj_vec_weights[j]
270+
for j in range(n_projections):
271+
weight = proj_vec_weights[j]
272+
273+
# a zero weight contributes nothing to the projection, so skip the multiply-add
274+
if weight == 0:
275+
continue
276+
proj_feat += X_ndarray[proj_vec_indices[j]] * weight
263277

264278
return proj_feat

sklearn/tree/_tree.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -766,15 +766,15 @@ cdef class Tree:
766766

767767
self.capacity = capacity
768768
return 0
769-
769+
770770
cdef int _set_node_values(self, SplitRecord* split_node,
771771
Node *node) nogil except -1:
772772
"""Set node data.
773773
"""
774774
node.feature = split_node.feature
775775
node.threshold = split_node.threshold
776776
return 1
777-
777+
778778
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
779779
Node *node, SIZE_t node_id) nogil:
780780
"""Compute feature from a given data matrix, X.

sklearn/tree/tests/test_tree.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ def test_pickle():
998998
else:
999999
X, y = diabetes.data, diabetes.target
10001000

1001-
est = TreeEstimator(random_state=0, max_depth=5)
1001+
est = TreeEstimator(random_state=0)
10021002
est.fit(X, y)
10031003
score = est.score(X, y)
10041004

@@ -1032,9 +1032,7 @@ def test_pickle():
10321032
est2_proj_mat = est2.tree_.get_projection_matrix()
10331033
assert_array_equal(est_proj_mat, est2_proj_mat)
10341034

1035-
# TODO: this works when `max_depth=5`, but not 6?
1036-
# Must be some machine rounding error occurring
1037-
# probably needs ``_compute_feature`` to do some rounding?
1035+
# score should match before/after pickling
10381036
score2 = est2.score(X, y)
10391037
assert (
10401038
score == score2

0 commit comments

Comments
 (0)