Add float cast in GeneralizedRCNN normalize#3238
Add float cast in GeneralizedRCNN normalize#3238Wadaboa wants to merge 3 commits intopytorch:masterfrom Wadaboa:fix/rcnn-int-normalize
Conversation
|
Hi @Wadaboa! Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at [email protected]. Thanks! |
| return image_list, targets | ||
|
|
||
| def normalize(self, image): | ||
| if image.dtype not in (torch.float, torch.double, torch.half): |
There was a problem hiding this comment.
Your code is valid but at TorchVision we use the following idiom:
| if image.dtype not in (torch.float, torch.double, torch.half): | |
| if image.is_floating_point(): |
|
|
||
| def normalize(self, image): | ||
| if image.dtype not in (torch.float, torch.double, torch.half): | ||
| image = image.float() |
There was a problem hiding this comment.
Again, your code is correct but same as above:
| image = image.float() | |
| image = image.to(torch.float32) |
|
I don't know if this is the right place to ask, but I would like to know why the |
It's because the other normalize can't handle the targets. |
| # check that the resulting images have float32 dtype | ||
| self.assertTrue(image_list.tensors.dtype == torch.float32) | ||
| # check that no NaN values are produced | ||
| self.assertFalse(torch.any(torch.isnan(image_list.tensors))) |
There was a problem hiding this comment.
@Wadaboa I'm still unable to see this test failing on master without your mitigation. We need to reproduce the problem, so that we are certain that we provide the right fix. Let me know your thoughts.
There was a problem hiding this comment.
I just went through the code that initially produced the issue for me and I realized the following: the mentioned error can only happen when the input image has uint8 dtype (so, in [0, 255] range) and the mean/std parameters are instead in [0, 1]. In this way, the output given by normalization will be all inf (not NaN as shown in the tests).
But in this case image and mean/std would have different ranges. So, is it something that should still be checked? Or maybe I reported an error which should alert the user to check its input and this PR is not needed?
There was a problem hiding this comment.
Thanks for investigating. I agree, there are too many corner-cases to cover for uint8 inputs. This specific piece of code generally expects the input image to be floating type, so if it's not it might be worth throwing an error.
I think what we should do is close this PR to keep things clean and open a new one where you put a type check on top of normalize and throw an exception if the input is not floating point. What you think?
There was a problem hiding this comment.
I agree. I'll close this PR and open a new one for the type check. Thank you.
This PR fixes issue #3228 by casting to float32 the input image of the
normalizemethod in theGeneralizedRCNNTransformclass.Unit tests have been added, to ensure that passing
float32oruint8image/mean/std variables does not lead to failure.