Skip to content

Commit 8ff81e5

Browse files
committed
Improve maintability by simplifying branching at the cost of ~10-15% performance in the fit method (might want to revert)
1 parent ab6d4ad commit 8ff81e5

File tree

1 file changed

+19
-29
lines changed

1 file changed

+19
-29
lines changed

sklearn/neighbors/binary_tree.pxi

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,21 +2152,16 @@ cdef class BinaryTree:
21522152
global_log_bound_spread =\
21532153
logsubexp(global_log_bound_spread,
21542154
node_log_bound_spreads[i_node])
2155-
if with_sample_weight:
2156-
for i in range(node_info.idx_start, node_info.idx_end):
2157-
dist_pt = self.dist(pt, data + n_features * idx_array[i],
2158-
n_features)
2159-
log_density = compute_log_kernel(dist_pt, h, kernel)
2155+
for i in range(node_info.idx_start, node_info.idx_end):
2156+
dist_pt = self.dist(pt, data + n_features * idx_array[i],
2157+
n_features)
2158+
log_density = compute_log_kernel(dist_pt, h, kernel)
2159+
if with_sample_weight:
21602160
log_weight = np.log(sample_weight[idx_array[i]])
2161-
global_log_min_bound = logaddexp(global_log_min_bound,
2162-
log_density + log_weight)
2163-
else:
2164-
for i in range(node_info.idx_start, node_info.idx_end):
2165-
dist_pt = self.dist(pt, data + n_features * idx_array[i],
2166-
n_features)
2167-
log_density = compute_log_kernel(dist_pt, h, kernel)
2168-
global_log_min_bound = logaddexp(global_log_min_bound,
2169-
log_density)
2161+
else:
2162+
log_weight = 0.
2163+
global_log_min_bound = logaddexp(global_log_min_bound,
2164+
log_density + log_weight)
21702165

21712166
#------------------------------------------------------------
21722167
# Case 4: split node and query subnodes
@@ -2294,22 +2289,17 @@ cdef class BinaryTree:
22942289
local_log_min_bound)
22952290
global_log_bound_spread[0] = logsubexp(global_log_bound_spread[0],
22962291
local_log_bound_spread)
2297-
if with_sample_weight:
2298-
for i in range(node_info.idx_start, node_info.idx_end):
2299-
dist_pt = self.dist(pt, (data + n_features * idx_array[i]),
2300-
n_features)
2301-
log_dens_contribution = compute_log_kernel(dist_pt, h, kernel)
2292+
for i in range(node_info.idx_start, node_info.idx_end):
2293+
dist_pt = self.dist(pt, (data + n_features * idx_array[i]),
2294+
n_features)
2295+
log_dens_contribution = compute_log_kernel(dist_pt, h, kernel)
2296+
if with_sample_weight:
23022297
log_weight = np.log(sample_weight[idx_array[i]])
2303-
global_log_min_bound[0] = logaddexp(global_log_min_bound[0],
2304-
log_dens_contribution +
2305-
log_weight)
2306-
else:
2307-
for i in range(node_info.idx_start, node_info.idx_end):
2308-
dist_pt = self.dist(pt, (data + n_features * idx_array[i]),
2309-
n_features)
2310-
log_dens_contribution = compute_log_kernel(dist_pt, h, kernel)
2311-
global_log_min_bound[0] = logaddexp(global_log_min_bound[0],
2312-
log_dens_contribution)
2298+
else:
2299+
log_weight = 0.
2300+
global_log_min_bound[0] = logaddexp(global_log_min_bound[0],
2301+
log_dens_contribution +
2302+
log_weight)
23132303

23142304
#------------------------------------------------------------
23152305
# Case 4: split node and query subnodes

0 commit comments

Comments
 (0)