TST use global_dtype in sklearn/cluster/tests/test_birch.py #22671
Similarly to #22672, we need at least one test that checks the impact of changing the dtype of X on the fitted attribute subcluster_centers_. I would have expected float32, but that is not the case. I am not sure why; maybe this reveals a suboptimal operation in the Birch class.
Since Birch is also a transformer, we should also check the dtype of Birch().fit_transform(X.astype(np.float32)) which I would also have expected to be float32 but it's not the case either, probably because subcluster_centers_ is always float64.
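The two dtype checks described above can be sketched as a small script. This is only an illustration, not the test added by this PR: the helper name `birch_output_dtypes` is hypothetical, and the script deliberately reports the dtypes instead of asserting float32, since the outcome depends on whether the installed scikit-learn version already preserves the dtype in Birch.

```python
import numpy as np
from sklearn.cluster import Birch


def birch_output_dtypes(dtype, seed=0):
    """Return the dtypes of subcluster_centers_ and of the fit_transform output.

    Hypothetical helper for illustration; it is not part of scikit-learn.
    """
    rng = np.random.RandomState(seed)
    X = rng.randn(100, 5).astype(dtype)
    centers_dtype = Birch().fit(X).subcluster_centers_.dtype
    transformed_dtype = Birch().fit_transform(X).dtype
    return centers_dtype, transformed_dtype


# Report what comes back for each input dtype.
for dtype in (np.float32, np.float64):
    centers_dtype, transformed_dtype = birch_output_dtypes(dtype)
    print(
        f"input={np.dtype(dtype).name}: "
        f"subcluster_centers_={centers_dtype}, fit_transform={transformed_dtype}"
    )
```

A dtype-preservation test would replace the print with an equality assertion against the input dtype once Birch is expected to preserve it.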
Furthermore, while trying interactively for myself, I observed that fitting a Birch() model now raises a warning:
>>> import numpy as np
>>> from sklearn.cluster import Birch
>>> Birch().fit(np.random.randn(100, 5).astype(np.float32)).subcluster_centers_.dtype
/Users/ogrisel/code/scikit-learn/sklearn/cluster/_birch.py:760: UserWarning: Some metric_kwargs have been passed ({'Y_norm_squared': array([ 3.86024464, 11.683084 , 12.99578754, 6.910336 , 2.55765101,
5.03749571, 5.06213863, 10.26015348, 0.90800463, 1.26334124,
3.4959797 , 6.45426856, 13.48504291, 3.90932701, 6.88367298,
3.43889466, 4.35950137, 6.41143755, 1.44802403, 0.48286628,
6.90164792, 3.67385496, 6.00607854, 6.95285525, 6.90960791,
3.7674752 , 3.52569363, 3.47566689, 5.2203662 , 2.10915227,
1.57974275, 1.7641581 , 3.3768002 , 4.25468386, 3.47676319,
4.86241842, 2.3451047 , 2.17838236, 6.30420973, 3.64096226,
8.83660004, 4.07342638, 2.08893818, 1.51923725, 6.9491891 ,
4.84401576, 7.78082366, 3.14570647, 3.43566494, 6.79652774,
5.56645993, 11.18789605, 4.60155353, 6.88679755, 0.88999525,
4.95820257, 2.69660257, 0.75948625, 4.14714094, 10.64599847,
3.07223409, 9.7867565 , 8.27261454, 2.71387669, 14.36863042,
10.79094887, 5.11707374, 2.67162805, 2.10645627, 2.35485549,
3.02763474, 6.7502932 , 2.70514103, 7.39961664, 3.18970976,
5.23735055, 1.956462 , 7.20984384, 12.08628175, 5.37923063])}) but aren'tusable for this case (FastEuclideanPairwiseDistancesArgKmin) and will be ignored.
self.labels_ = self._predict(X)
np.float64
Ideally our tests should fail when such unexpected UserWarnings are raised by scikit-learn code, but unfortunately this is not the case at the moment. We only do it for FutureWarning on some dedicated CI runs.
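One way to make a test fail on an unexpected UserWarning is to promote that category to an error around the call under test, using only the standard `warnings` module. The sketch below assumes nothing about scikit-learn's CI configuration; `call_failing_on_user_warning` and `noisy` are hypothetical names used for illustration.

```python
import warnings


def call_failing_on_user_warning(func, *args, **kwargs):
    """Call func, turning any UserWarning it emits into an exception."""
    with warnings.catch_warnings():
        # Within this block, UserWarning (and its subclasses) is raised
        # as an exception instead of being printed.
        warnings.simplefilter("error", category=UserWarning)
        return func(*args, **kwargs)


def noisy():
    """Stand-in for library code that emits an unexpected UserWarning."""
    warnings.warn("some keyword arguments will be ignored", UserWarning)
    return 1


try:
    call_failing_on_user_warning(noisy)
except UserWarning as exc:
    print(f"caught: {exc}")
```

In a pytest suite the same effect can be obtained per test with `@pytest.mark.filterwarnings("error::UserWarning")`.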
For the However it does not run on Birch: while running for either: or finds plenty of common tests. This is probably because |
dtype preservation for transformers is tracked in this issue: |
The second phrase tells you why it should not have the tag :) There's the long-term issue #11000 to track which transformers properly preserve float32. Birch is not one of them yet.
I don't understand. I think it should have the tag and the Birch code should be fixed to make sure the common tests pass, no?
Of course it should, but it takes time :)
test_birch.py to test implementations on 32bit datasets|
Actually BIRCH doesn't preserve the dtype yet so I think these changes should be delayed.
According to some irl discussions, such tests should only be added after Let's keep this PR open in the mean time. |
Co-authored-by: Jérémie du Boisberranger <[email protected]>
Now that #22968 has been merged, I've updated this PR taking @jeremiedbb's last comments into account.
ogrisel
left a comment
Assuming the following works as expected, LGTM.
Co-authored-by: Olivier Grisel <[email protected]>
jeremiedbb
left a comment
Now that Birch does preserve dtype we can merge this one. LGTM.
Reference Issues/PRs
Partially addresses #22881
Precedes #22590
What does this implement/fix? Explain your changes.
This parametrizes tests from test_birch.py to run on 32bit datasets.
Any other comments?
We could introduce a mechanism to skip running the tests on 32bit datasets if they take too much time to complete.
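A minimal sketch of the idea, for illustration only: a Birch test parametrized over dtypes, with an opt-out switch for the 32bit runs. It does not use the `global_dtype` fixture named in the PR title. The environment variable `RUN_FLOAT32_TESTS` and the test name are hypothetical stand-ins, and the assertions only check output shapes so they hold whether or not Birch preserves the dtype.

```python
import os

import numpy as np
import pytest
from sklearn.cluster import Birch

# Hypothetical opt-out switch: set RUN_FLOAT32_TESTS=0 to skip the 32bit runs.
RUN_FLOAT32 = os.environ.get("RUN_FLOAT32_TESTS", "1") != "0"
DTYPES = [np.float64] + ([np.float32] if RUN_FLOAT32 else [])


@pytest.mark.parametrize("dtype", DTYPES)
def test_birch_output_shapes(dtype):
    rng = np.random.RandomState(0)
    X = rng.randn(100, 5).astype(dtype)

    # n_clusters=None skips the global clustering step, so the labels are
    # the indices of the subclusters themselves.
    brc = Birch(n_clusters=None).fit(X)

    assert brc.labels_.shape == (X.shape[0],)
    assert brc.subcluster_centers_.shape[1] == X.shape[1]
```

Keeping the dtype list in one place means a single switch controls every parametrized test, which is the kind of mechanism the comment above asks for.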