@@ -87,24 +87,27 @@ def _parallel_build_trees(tree, forest, X, y, sample_weight, tree_idx, n_trees,
         sample_counts = np.bincount(indices, minlength=n_samples)
         curr_sample_weight *= sample_counts

-        if class_weight == 'bootstrap':
+        if class_weight == 'subsample':
+
             expanded_class_weight = [curr_sample_weight]
+
             for k in range(y.shape[1]):
                 y_full = y[:, k]
                 classes_full = np.unique(y_full)
-                y_boot = y_full[indices]
+                y_boot = y[indices, k]
                 classes_boot = np.unique(y_boot)
-                # Get class weights for the bootstrap sample
-                weight_k = compute_class_weight('auto', classes_boot, y_boot)
-                # Expand class weights to cover all classes in original y
-                # (in case some were missing from the bootstrap sample)
-                weight_k = np.array([weight_k[np.where(classes_boot == c)][0]
-                                     if c in classes_boot
-                                     else 0.
-                                     for c in classes_full])
+
+                # Get class weights for the bootstrap sample, covering all
+                # classes in case some were missing from the bootstrap sample
+                weight_k = np.choose(
+                    np.searchsorted(classes_boot, classes_full),
+                    compute_class_weight('auto', classes_boot, y_boot),
+                    mode='clip')
+
                 # Expand weights over the original y for this output
                 weight_k = weight_k[np.searchsorted(classes_full, y_full)]
                 expanded_class_weight.append(weight_k)
+
             # Multiply all weights by sample & bootstrap weights
             curr_sample_weight = np.prod(expanded_class_weight,
                                          axis=0,
@@ -243,7 +246,7 @@ def fit(self, X, y, sample_weight=None):

         if expanded_class_weight is not None:
             if sample_weight is not None:
-                sample_weight = np.copy(sample_weight) * expanded_class_weight
+                sample_weight = sample_weight * expanded_class_weight
             else:
                 sample_weight = expanded_class_weight

@@ -428,14 +431,14 @@ def _validate_y_class_weight(self, y):
             self.n_classes_.append(classes_k.shape[0])

         if self.class_weight is not None:
-            valid_presets = ['auto', 'bootstrap']
+            valid_presets = ('auto', 'subsample')
             if isinstance(self.class_weight, six.string_types):
                 if self.class_weight not in valid_presets:
                     raise ValueError('Valid presets for class_weight include '
-                                     '"auto" and "bootstrap". Given "%s".'
+                                     '"auto" and "subsample". Given "%s".'
                                      % self.class_weight)
                 if self.warm_start:
-                    warn('class_weight presets "auto" or "bootstrap" are '
+                    warn('class_weight presets "auto" or "subsample" are '
                          'not recommended for warm_start if the fitted data '
                          'differs from the full dataset. In order to use '
                          '"auto" weights, use compute_class_weight("auto", '
@@ -453,7 +456,7 @@ def _validate_y_class_weight(self, y):
                                      "in class_weight should match number of "
                                      "outputs.")

-        if self.class_weight != 'bootstrap' or not self.bootstrap:
+        if self.class_weight != 'subsample' or not self.bootstrap:
             expanded_class_weight = []
             for k in range(self.n_outputs_):
                 if self.class_weight in valid_presets:
@@ -797,7 +800,7 @@ class RandomForestClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.

-    class_weight : dict, list of dicts, "auto", "bootstrap" or None, optional
+    class_weight : dict, list of dicts, "auto", "subsample" or None, optional

         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
@@ -807,7 +810,7 @@ class RandomForestClassifier(ForestClassifier):
         The "auto" mode uses the values of y to automatically adjust
         weights inversely proportional to class frequencies in the input data.

-        The "bootstrap" mode is the same as "auto" except that weights are
+        The "subsample" mode is the same as "auto" except that weights are
         computed based on the bootstrap sample for every tree grown.

         For multi-output, the weights of each column of y will be multiplied.
@@ -1127,7 +1130,7 @@ class ExtraTreesClassifier(ForestClassifier):
         and add more estimators to the ensemble, otherwise, just fit a whole
         new forest.

-    class_weight : dict, list of dicts, "auto", "bootstrap" or None, optional
+    class_weight : dict, list of dicts, "auto", "subsample" or None, optional

         Weights associated with classes in the form ``{class_label: weight}``.
         If not given, all classes are supposed to have weight one. For
@@ -1137,7 +1140,7 @@ class ExtraTreesClassifier(ForestClassifier):
         The "auto" mode uses the values of y to automatically adjust
         weights inversely proportional to class frequencies in the input data.

-        The "bootstrap" mode is the same as "auto" except that weights are
+        The "subsample" mode is the same as "auto" except that weights are
         computed based on the bootstrap sample for every tree grown.

         For multi-output, the weights of each column of y will be multiplied.
0 commit comments