@@ -661,7 +661,8 @@ def test_dtype_preprocess_data(global_random_seed):
661661
662662
663663@pytest .mark .parametrize ("n_targets" , [None , 2 ])
664- def test_rescale_data_dense (n_targets , global_random_seed ):
664+ @pytest .mark .parametrize ("sparse_data" , [True , False ])
665+ def test_rescale_data (n_targets , sparse_data , global_random_seed ):
665666 rng = np .random .RandomState (global_random_seed )
666667 n_samples = 200
667668 n_features = 2
@@ -672,14 +673,34 @@ def test_rescale_data_dense(n_targets, global_random_seed):
672673 y = rng .rand (n_samples )
673674 else :
674675 y = rng .rand (n_samples , n_targets )
675- rescaled_X , rescaled_y , sqrt_sw = _rescale_data (X , y , sample_weight )
676- rescaled_X2 = X * sqrt_sw [:, np .newaxis ]
676+
677+ expected_sqrt_sw = np .sqrt (sample_weight )
678+ expected_rescaled_X = X * expected_sqrt_sw [:, np .newaxis ]
679+
677680 if n_targets is None :
678- rescaled_y2 = y * sqrt_sw
681+ expected_rescaled_y = y * expected_sqrt_sw
679682 else :
680- rescaled_y2 = y * sqrt_sw [:, np .newaxis ]
681- assert_array_almost_equal (rescaled_X , rescaled_X2 )
682- assert_array_almost_equal (rescaled_y , rescaled_y2 )
683+ expected_rescaled_y = y * expected_sqrt_sw [:, np .newaxis ]
684+
685+ if sparse_data :
686+ X = sparse .csr_matrix (X )
687+ if n_targets is None :
688+ y = sparse .csr_matrix (y .reshape (- 1 , 1 ))
689+ else :
690+ y = sparse .csr_matrix (y )
691+
692+ rescaled_X , rescaled_y , sqrt_sw = _rescale_data (X , y , sample_weight )
693+
694+ assert_allclose (sqrt_sw , expected_sqrt_sw )
695+
696+ if sparse_data :
697+ rescaled_X = rescaled_X .toarray ()
698+ rescaled_y = rescaled_y .toarray ()
699+ if n_targets is None :
700+ rescaled_y = rescaled_y .ravel ()
701+
702+ assert_allclose (rescaled_X , expected_rescaled_X )
703+ assert_allclose (rescaled_y , expected_rescaled_y )
683704
684705
685706def test_fused_types_make_dataset ():
0 commit comments