@@ -404,6 +404,18 @@ def test_copy_param_extras(self):
404404 self .assertEqual (tp ._paramMap , copied_no_extra )
405405 self .assertEqual (tp ._defaultParamMap , tp_copy ._defaultParamMap )
406406
407+ def test_logistic_regression_check_thresholds (self ):
408+ self .assertIsInstance (
409+ LogisticRegression (threshold = 0.5 , thresholds = [0.5 , 0.5 ]),
410+ LogisticRegression
411+ )
412+
413+ self .assertRaisesRegex (
414+ ValueError ,
415+ "Logistic Regression getThreshold found inconsistent.*$" ,
416+ LogisticRegression , threshold = 0.42 , thresholds = [0.5 , 0.5 ]
417+ )
418+
407419
408420class EvaluatorTests (SparkSessionTestCase ):
409421
@@ -807,18 +819,6 @@ def test_logistic_regression(self):
807819 except OSError :
808820 pass
809821
810- def test_logistic_regression_check_thresholds (self ):
811- self .assertIsInstance (
812- LogisticRegression (threshold = 0.5 , thresholds = [0.5 , 0.5 ]),
813- LogisticRegression
814- )
815-
816- self .assertRaisesRegexp (
817- ValueError ,
818- "Logistic Regression getThreshold found inconsistent.*$" ,
819- LogisticRegression , threshold = 0.42 , thresholds = [0.5 , 0.5 ]
820- )
821-
822822 def _compare_params (self , m1 , m2 , param ):
823823 """
824824 Compare 2 ML Params instances for the given param, and assert both have the same param value
0 commit comments