Skip to content

Commit bab265a

Browse files
committed
TST test_rescale_data with sparse y
1 parent c9fcb47 commit bab265a

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

sklearn/linear_model/tests/test_base.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,8 @@ def test_dtype_preprocess_data(global_random_seed):
661661

662662

663663
@pytest.mark.parametrize("n_targets", [None, 2])
664-
def test_rescale_data_dense(n_targets, global_random_seed):
664+
@pytest.mark.parametrize("sparse_data", [True, False])
665+
def test_rescale_data(n_targets, sparse_data, global_random_seed):
665666
rng = np.random.RandomState(global_random_seed)
666667
n_samples = 200
667668
n_features = 2
@@ -672,14 +673,34 @@ def test_rescale_data_dense(n_targets, global_random_seed):
672673
y = rng.rand(n_samples)
673674
else:
674675
y = rng.rand(n_samples, n_targets)
675-
rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
676-
rescaled_X2 = X * sqrt_sw[:, np.newaxis]
676+
677+
expected_sqrt_sw = np.sqrt(sample_weight)
678+
expected_rescaled_X = X * expected_sqrt_sw[:, np.newaxis]
679+
677680
if n_targets is None:
678-
rescaled_y2 = y * sqrt_sw
681+
expected_rescaled_y = y * expected_sqrt_sw
679682
else:
680-
rescaled_y2 = y * sqrt_sw[:, np.newaxis]
681-
assert_array_almost_equal(rescaled_X, rescaled_X2)
682-
assert_array_almost_equal(rescaled_y, rescaled_y2)
683+
expected_rescaled_y = y * expected_sqrt_sw[:, np.newaxis]
684+
685+
if sparse_data:
686+
X = sparse.csr_matrix(X)
687+
if n_targets is None:
688+
y = sparse.csr_matrix(y.reshape(-1, 1))
689+
else:
690+
y = sparse.csr_matrix(y)
691+
692+
rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
693+
694+
assert_allclose(sqrt_sw, expected_sqrt_sw)
695+
696+
if sparse_data:
697+
rescaled_X = rescaled_X.toarray()
698+
rescaled_y = rescaled_y.toarray()
699+
if n_targets is None:
700+
rescaled_y = rescaled_y.ravel()
701+
702+
assert_allclose(rescaled_X, expected_rescaled_X)
703+
assert_allclose(rescaled_y, expected_rescaled_y)
683704

684705

685706
def test_fused_types_make_dataset():

0 commit comments

Comments
 (0)