@@ -281,7 +281,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator,
281281 @abstractmethod
282282 def __init__ (self , estimator , scoring = None ,
283283 fit_params = None , n_jobs = 1 , iid = True ,
284- refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ):
284+ refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ,
285+ scorer_params = None ):
285286
286287 self .scoring = scoring
287288 self .estimator = estimator
@@ -292,8 +293,9 @@ def __init__(self, estimator, scoring=None,
292293 self .cv = cv
293294 self .verbose = verbose
294295 self .pre_dispatch = pre_dispatch
296+ self .scorer_params = scorer_params
295297
296- def score (self , X , y = None , sample_weight = None ):
298+ def score (self , X , y = None , ** scorer_params ):
297299 """Returns the score on the given test data and labels, if the search
298300 estimator has been refit. The ``score`` function of the best estimator
299301 is used, or the ``scoring`` parameter where unavailable.
@@ -308,24 +310,18 @@ def score(self, X, y=None, sample_weight=None):
308310 Target relative to X for classification or regression;
309311 None for unsupervised learning.
310312
311- sample_weight : array-like, shape = [n_samples], optional
312- Sample weights.
313-
314313 Returns
315314 -------
316315 score : float
317316
318317 """
319- kwargs = {}
320- if sample_weight is not None :
321- kwargs ['sample_weight' ] = sample_weight
322318 if hasattr (self .best_estimator_ , 'score' ):
323- return self .best_estimator_ .score (X , y , ** kwargs )
319+ return self .best_estimator_ .score (X , y , ** scorer_params )
324320 if self .scorer_ is None :
325321 raise ValueError ("No score function explicitly defined, "
326322 "and the estimator doesn't provide one %s"
327323 % self .best_estimator_ )
328- return self .scorer_ (self .best_estimator_ , X , y , ** kwargs )
324+ return self .scorer_ (self .best_estimator_ , X , y , ** scorer_params )
329325
330326 @property
331327 def predict (self ):
@@ -343,15 +339,15 @@ def decision_function(self):
343339 def transform (self ):
344340 return self .best_estimator_ .transform
345341
346- def _fit (self , X , y , sample_weight , parameter_iterable ):
342+ def _fit (self , X , y , parameter_iterable ):
347343 """Actual fitting, performing the search over parameters."""
348344
349345 estimator = self .estimator
350346 cv = self .cv
351347 self .scorer_ = check_scoring (self .estimator , scoring = self .scoring )
352348
353349 n_samples = _num_samples (X )
354- X , y , sample_weight = indexable (X , y , sample_weight )
350+ X , y = indexable (X , y )
355351
356352 if y is not None :
357353 if len (y ) != n_samples :
@@ -376,10 +372,10 @@ def _fit(self, X, y, sample_weight, parameter_iterable):
376372 n_jobs = self .n_jobs , verbose = self .verbose ,
377373 pre_dispatch = pre_dispatch
378374 )(
379- delayed (_fit_and_score )(clone (base_estimator ), X , y , sample_weight ,
375+ delayed (_fit_and_score )(clone (base_estimator ), X , y ,
380376 self .scorer_ , train , test ,
381377 self .verbose , parameters , self .fit_params ,
382- return_parameters = True )
378+ self . scorer_params , return_parameters = True )
383379 for parameters in parameter_iterable
384380 for train , test in cv )
385381
@@ -422,9 +418,6 @@ def _fit(self, X, y, sample_weight, parameter_iterable):
422418
423419 if self .refit :
424420 fit_params = self .fit_params
425- if sample_weight is not None :
426- fit_params = fit_params .copy ()
427- fit_params ['sample_weight' ] = sample_weight
428421 # fit the best estimator using the entire dataset
429422 # clone first to work around broken estimators
430423 best_estimator = clone (base_estimator ).set_params (
@@ -580,14 +573,15 @@ class GridSearchCV(BaseSearchCV):
580573
581574 def __init__ (self , estimator , param_grid , scoring = None ,
582575 fit_params = None , n_jobs = 1 , iid = True ,
583- refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ):
576+ refit = True , cv = None , verbose = 0 , pre_dispatch = '2*n_jobs' ,
577+ scorer_params = None ):
584578 super (GridSearchCV , self ).__init__ (
585579 estimator , scoring , fit_params , n_jobs , iid ,
586- refit , cv , verbose , pre_dispatch )
580+ refit , cv , verbose , pre_dispatch , scorer_params )
587581 self .param_grid = param_grid
588582 _check_param_grid (param_grid )
589583
590- def fit (self , X , y = None , sample_weight = None ):
584+ def fit (self , X , y = None ):
591585 """Run fit with all sets of parameters.
592586
593587 Parameters
@@ -600,11 +594,8 @@ def fit(self, X, y=None, sample_weight=None):
600594 y : array-like, shape = [n_samples] or [n_samples, n_output], optional
601595 Target relative to X for classification or regression;
602596 None for unsupervised learning.
603-
604- sample_weight : array-like, shape = [n_samples], optional
605- Sample weights.
606597 """
607- return self ._fit (X , y , sample_weight , ParameterGrid (self .param_grid ))
598+ return self ._fit (X , y , ParameterGrid (self .param_grid ))
608599
609600
610601class RandomizedSearchCV (BaseSearchCV ):
@@ -730,17 +721,18 @@ class RandomizedSearchCV(BaseSearchCV):
730721
731722 def __init__ (self , estimator , param_distributions , n_iter = 10 , scoring = None ,
732723 fit_params = None , n_jobs = 1 , iid = True , refit = True , cv = None ,
733- verbose = 0 , pre_dispatch = '2*n_jobs' , random_state = None ):
724+ verbose = 0 , pre_dispatch = '2*n_jobs' , random_state = None ,
725+ scorer_params = None ):
734726
735727 self .param_distributions = param_distributions
736728 self .n_iter = n_iter
737729 self .random_state = random_state
738730 super (RandomizedSearchCV , self ).__init__ (
739731 estimator = estimator , scoring = scoring , fit_params = fit_params ,
740732 n_jobs = n_jobs , iid = iid , refit = refit , cv = cv , verbose = verbose ,
741- pre_dispatch = pre_dispatch )
733+ pre_dispatch = pre_dispatch , scorer_params = scorer_params )
742734
743- def fit (self , X , y = None , sample_weight = None ):
735+ def fit (self , X , y = None ):
744736 """Run fit on the estimator with randomly drawn parameters.
745737
746738 Parameters
@@ -752,12 +744,8 @@ def fit(self, X, y=None, sample_weight=None):
752744 y : array-like, shape = [n_samples] or [n_samples, n_output], optional
753745 Target relative to X for classification or regression;
754746 None for unsupervised learning.
755-
756- sample_weight : array-like, shape = [n_samples], optional
757- Sample weights.
758-
759747 """
760748 sampled_params = ParameterSampler (self .param_distributions ,
761749 self .n_iter ,
762750 random_state = self .random_state )
763- return self ._fit (X , y , sample_weight , sampled_params )
751+ return self ._fit (X , y , sampled_params )
0 commit comments