MNT Add option to raise when all sample weights are 0 in _check_sample_weight#32212
Conversation
|
#31775 to be merged before ready for review |
lucyleeow
left a comment
There was a problem hiding this comment.
Thanks for the PR. Small nits only.
Just a note, please make sure lines, even in rst files, are <88 characters in length
10b0788 to
eb35131
Compare
|
@lucyleeow thanks for the recommendations! I agree with all of them and have updated the PR accordingly :) |
|
Just noticed that the default is to error, i.e. In which case we have changed the behaviour everywhere that We will need to add a whats new entry to advise of this change in behaviour. It will be many estimators and metrics affected. @ogrisel do we want to list all of them in maybe |
|
+1 for extending a common test to check for invalid sample weight related error messages and global level changed model changelog entry. |
|
I added a common test for all zero sample weights, but came across some edge cases that I need to investigate further. If you run
I'm not sure what's going on with |
1e80aba to
a2cd6c0
Compare
I figured out the issues I raised yesterday:
I've solved each of these with some minor tweaks. Thanks for your patience while I figured this out! |
d391eac to
b104e63
Compare
|
@lucyleeow I think we're ready for final reviews before merging, but before we do that I'm going to add you as a co-author given you were the one who outlined the approach I implemented. |
lucyleeow
left a comment
There was a problem hiding this comment.
Small comments but pretty much LGTM
| ineffective sampling during fitting. This change applies to all estimators that | ||
| support the parameter `sample_weight`. This change also affects metrics that validate | ||
| sample weights. |
There was a problem hiding this comment.
Are there any metrics that support sample_weight but do not validate sample weights?
There was a problem hiding this comment.
Short answer is yes, assuming your definition of "validate" means sample_weight goes through _check_sample_weight. The one example I found was r2_score in the metrics package. Although it does apply utils.validation.column_or_1d on sample_weight, it does not apply _check_sample_weight.
Given the above example, I think the statement "This change also affects metrics that validate sample weights" still applies.
|
@ogrisel may be interested to take a look? |
_check_sample_weight
|
@j-hendricks just checking if you are still interested in working on this? |
@lucyleeow Yup! Working on it right now |
b037bba to
212606e
Compare
lucyleeow
left a comment
There was a problem hiding this comment.
The CI failure is at check_classifiers_one_label_sample_weights for RandomForestClassifier(n_estimators=5)
The data looks like:
X_train = rnd.uniform(size=(10, 10))
X_test = rnd.uniform(size=(10, 10))
y = np.arange(10) % 2
sample_weight = y.copy() # select a single classWith sample size of 10, and half of those being 0 sample weight, just by chance we have a case where for one tree, all subsampled samples have sample weight of 0. Thus we end up raising this new error we added, instead of the one we are checking for.
CI test output
../sklearn/ensemble/_forest.py:188: in _parallel_build_trees
tree._fit(
X = array([[0.5488135 , 0.71518934, 0.60276335, 0.5448832 , 0.4236548 ,
0.6458941 , 0.4375872 , 0.891773 , 0.9636...787, 0.7163272 , 0.2894061 ,
0.18319136, 0.5865129 , 0.02010755, 0.82894003, 0.00469548]],
dtype=float32)
bootstrap = True
class_weight = None
curr_sample_weight = array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
This test is run on all classifiers so we should be careful in amending - to not make the test suite too computationally expensive or affect tests on other estimators. We could:
- reduce the number of 0 sample weights. This would also make imbalanced classes, as we want the 0 sample weights to all correspond to one class, but as we are only testing the error message, this should be okay (?)
- increase the sample size so all 0 sample weights is less likely
cc @ogrisel
sklearn/utils/estimator_checks.py
Outdated
| """The following estimators have custom error messages: | ||
|
|
||
| NuSVC: Invalid input - all samples have zero or negative weights. | ||
|
|
||
| Perceptron: The sample weights for validation set are all zero, consider using a | ||
| different random state. | ||
|
|
||
| SGDClassifier: The sample weights for validation set are all zero, consider using a | ||
| different random state. | ||
| """ |
There was a problem hiding this comment.
And let's specify that; all others will output "Sample weights must contain at least one non-zero number." message from _check_sample_weights.
| # Skip check that validation weights are not all zero when `early_stopping` is | ||
| # set to True as `_make_validation_split` will raise a more informative error. | ||
| sample_weight = _check_sample_weight( | ||
| sample_weight, | ||
| X, | ||
| dtype=X.dtype, | ||
| allow_all_zero_weights=self.early_stopping, | ||
| ) |
There was a problem hiding this comment.
For second reviewer: I don't love this but I can't think of a better way to force raise of the more informative error message.
We could do that. Or we could just add this particular common test to |
212606e to
e32b3c5
Compare
ogrisel
left a comment
There was a problem hiding this comment.
I did another pass with the latest changes and this looks good to merge. No need to wait for the concurrent RF fix.
…le_weight` (scikit-learn#32212) Co-authored-by: John Hendricks <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
…le_weight` (scikit-learn#32212) Co-authored-by: John Hendricks <[email protected]> Co-authored-by: Olivier Grisel <[email protected]>
Reference Issues/PRs
Fixes #31032.
What does this implement/fix? Explain your changes.
Make
_weighted_percentilereturn nan and_check_sample_weightraise error when all sample weights are 0.Previously,
_weighted_percentilewould return the last element in the array, which was unexpected behavior and unintuitive to the user. To ensure this issue is caught further upstream,_check_sample_weightnow raises a ValueError when sample weights are all 0.Additionally, parameter
allow_zero_weightswas added to_check_sample_weightfor additional flexibility regarding the raising of the ValueError.Any other comments?
Modified tests in
utils/tests/test_stats.pyandutils/tests/test_validation.pyto check for these changes.