@@ -58,7 +58,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
5858from ..tree import (DecisionTreeClassifier , DecisionTreeRegressor ,
5959 ExtraTreeClassifier , ExtraTreeRegressor )
6060from ..tree ._tree import DTYPE , DOUBLE
61- from ..utils import check_random_state , check_array
61+ from ..utils import check_random_state , check_array , compute_class_weight
6262from ..utils .validation import DataConversionWarning
6363from .base import BaseEnsemble , _partition_estimators
6464
@@ -122,7 +122,8 @@ def __init__(self,
122122 n_jobs = 1 ,
123123 random_state = None ,
124124 verbose = 0 ,
125- warm_start = False ):
125+ warm_start = False ,
126+ class_weight = None ):
126127 super (BaseForest , self ).__init__ (
127128 base_estimator = base_estimator ,
128129 n_estimators = n_estimators ,
@@ -134,6 +135,7 @@ def __init__(self,
134135 self .random_state = random_state
135136 self .verbose = verbose
136137 self .warm_start = warm_start
138+ self .class_weight = class_weight
137139
138140 def apply (self , X ):
139141 """Apply trees in the forest to X, return leaf indices.
@@ -211,11 +213,17 @@ def fit(self, X, y, sample_weight=None):
211213
212214 self .n_outputs_ = y .shape [1 ]
213215
214- y = self ._validate_y (y )
216+ y , cw = self ._validate_y_cw (y )
215217
216218 if getattr (y , "dtype" , None ) != DOUBLE or not y .flags .contiguous :
217219 y = np .ascontiguousarray (y , dtype = DOUBLE )
218220
221+ if cw is not None :
222+ if sample_weight is not None :
223+ sample_weight *= cw
224+ else :
225+ sample_weight = cw
226+
219227 # Check parameters
220228 self ._validate_estimator ()
221229
@@ -279,9 +287,9 @@ def fit(self, X, y, sample_weight=None):
279287 def _set_oob_score (self , X , y ):
280288 """Calculate out of bag predictions and score."""
281289
282- def _validate_y (self , y ):
290+ def _validate_y_cw (self , y ):
283291 # Default implementation
284- return y
292+ return y , None
285293
286294 @property
287295 def feature_importances_ (self ):
@@ -320,7 +328,8 @@ def __init__(self,
320328 n_jobs = 1 ,
321329 random_state = None ,
322330 verbose = 0 ,
323- warm_start = False ):
331+ warm_start = False ,
332+ class_weight = None ):
324333
325334 super (ForestClassifier , self ).__init__ (
326335 base_estimator ,
@@ -331,7 +340,8 @@ def __init__(self,
331340 n_jobs = n_jobs ,
332341 random_state = random_state ,
333342 verbose = verbose ,
334- warm_start = warm_start )
343+ warm_start = warm_start ,
344+ class_weight = class_weight )
335345
336346 def _set_oob_score (self , X , y ):
337347 """Compute out-of-bag score"""
@@ -377,8 +387,9 @@ def _set_oob_score(self, X, y):
377387
378388 self .oob_score_ = oob_score / self .n_outputs_
379389
380- def _validate_y (self , y ):
381- y = np .copy (y )
390+ def _validate_y_cw (self , y_org ):
391+ y = np .copy (y_org )
392+ cw = None
382393
383394 self .classes_ = []
384395 self .n_classes_ = []
@@ -388,7 +399,19 @@ def _validate_y(self, y):
388399 self .classes_ .append (classes_k )
389400 self .n_classes_ .append (classes_k .shape [0 ])
390401
391- return y
402+ if self .class_weight is not None :
403+ if self .n_outputs_ == 1 :
404+ cw = compute_class_weight (self .class_weight ,
405+ self .classes_ [0 ],
406+ y_org [:, 0 ])
407+ cw = cw [np .searchsorted (self .classes_ [0 ], y_org [:, 0 ])]
408+ else :
409+ raise NotImplementedError ('class_weights are not supported '
410+ 'for multi-output. You may use '
411+ 'sample_weights in the fit method '
412+ 'to weight by sample.' )
413+
414+ return y , cw
392415
393416 def predict (self , X ):
394417 """Predict class for X.
@@ -707,6 +730,18 @@ class RandomForestClassifier(ForestClassifier):
707730 and add more estimators to the ensemble, otherwise, just fit a whole
708731 new forest.
709732
733+ class_weight : dict, {class_label: weight} or "auto" or None, optional
734+ Weights associated with classes. If not given, all classes
735+ are supposed to have weight one.
736+
737+ The "auto" mode uses the values of y to automatically adjust
738+ weights inversely proportional to class frequencies.
739+
740+ Note that this is only supported for single-output classification.
741+
742+ Note that these weights will be multiplied with class_weight (passed
743+ through the fit method) if sample_weight is specified
744+
710745 Attributes
711746 ----------
712747 estimators_ : list of DecisionTreeClassifier
@@ -755,7 +790,8 @@ def __init__(self,
755790 n_jobs = 1 ,
756791 random_state = None ,
757792 verbose = 0 ,
758- warm_start = False ):
793+ warm_start = False ,
794+ class_weight = None ):
759795 super (RandomForestClassifier , self ).__init__ (
760796 base_estimator = DecisionTreeClassifier (),
761797 n_estimators = n_estimators ,
@@ -768,7 +804,8 @@ def __init__(self,
768804 n_jobs = n_jobs ,
769805 random_state = random_state ,
770806 verbose = verbose ,
771- warm_start = warm_start )
807+ warm_start = warm_start ,
808+ class_weight = class_weight )
772809
773810 self .criterion = criterion
774811 self .max_depth = max_depth
@@ -1017,6 +1054,18 @@ class ExtraTreesClassifier(ForestClassifier):
10171054 and add more estimators to the ensemble, otherwise, just fit a whole
10181055 new forest.
10191056
1057+ class_weight : dict, {class_label: weight} or "auto" or None, optional
1058+ Weights associated with classes. If not given, all classes
1059+ are supposed to have weight one.
1060+
1061+ The "auto" mode uses the values of y to automatically adjust
1062+ weights inversely proportional to class frequencies.
1063+
1064+ Note that this is only supported for single-output classification.
1065+
1066+ Note that these weights will be multiplied with class_weight (passed
1067+ through the fit method) if sample_weight is specified
1068+
10201069 Attributes
10211070 ----------
10221071 estimators_ : list of DecisionTreeClassifier
@@ -1068,7 +1117,8 @@ def __init__(self,
10681117 n_jobs = 1 ,
10691118 random_state = None ,
10701119 verbose = 0 ,
1071- warm_start = False ):
1120+ warm_start = False ,
1121+ class_weight = None ):
10721122 super (ExtraTreesClassifier , self ).__init__ (
10731123 base_estimator = ExtraTreeClassifier (),
10741124 n_estimators = n_estimators ,
@@ -1080,7 +1130,8 @@ def __init__(self,
10801130 n_jobs = n_jobs ,
10811131 random_state = random_state ,
10821132 verbose = verbose ,
1083- warm_start = warm_start )
1133+ warm_start = warm_start ,
1134+ class_weight = class_weight )
10841135
10851136 self .criterion = criterion
10861137 self .max_depth = max_depth
0 commit comments