Multi-pretrained weight support - FasterRCNN ResNet50#4613
Multi-pretrained weight support - FasterRCNN ResNet50#4613datumbox merged 6 commits intopytorch:mainfrom
Conversation
datumbox
left a comment
There was a problem hiding this comment.
Some clarification comments below:
| returned_layers=None, | ||
| extra_blocks=None, | ||
| ): | ||
| backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) |
There was a problem hiding this comment.
Unfortunately I'm forced to copy the whole function just to change the pretrained to weights param. I refactored to minimize copy-pasted code.
|
|
||
|
|
||
| # Allows handling of both PIL and Tensor images | ||
| class ConvertImageDtype(nn.Module): |
There was a problem hiding this comment.
Removed the standalone transform to avoid introducing a new class here.
| import warnings | ||
| from typing import Any, Optional | ||
|
|
||
| from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers |
There was a problem hiding this comment.
Inherit as much as possible. The changes below will be moved on the existing files once we move to torchvision.
|
|
||
| def fasterrcnn_resnet50_fpn( | ||
| weights: Optional[FasterRCNNResNet50FPNWeights] = None, | ||
| weights_backbone: Optional[ResNet50Weights] = ResNet50Weights.ImageNet1K_RefV1, # TODO: Should we default to None? |
There was a problem hiding this comment.
This default value is to align with the old logic, NEVERTHELESS it's not necessary for BC (due to the way I handled the pretrained param below).
Personally I don't like that we have hardcoded an OLD set of weights here. If we set the default value to None and force the user to choose, we will eliminate future BC considerations if these weights become too old and not-optimal.
This can be addressed on a follow up PR.
cc @fmassa
There was a problem hiding this comment.
I'm going to set this to None, happy to review if someone has a strong opinion.
prabhat00155
left a comment
There was a problem hiding this comment.
Thanks @datumbox! Few minor comments.
|
|
||
| if weights is not None: | ||
| weights_backbone = None | ||
| num_classes = len(weights.meta["categories"]) |
There was a problem hiding this comment.
We should probably raise an error / warning if the user modifies the num_classes and passes a weights argument. Otherwise they might silently think that we are doing magic inside
There was a problem hiding this comment.
Sounds good. I'll add this check to resnet as well.
There was a problem hiding this comment.
I thought about this and it's a bit problematic. The num_classes parameter has a default value in all of our model builders. So to see i it was modified, we need to see if the default value was changed which can lead to messy code. An alternative approach could be to throw a warning if the num_classes != len(weights.meta["categories"]) but still overwrite it to make the life of users easier.
Because it's not clear how this should be handled, I'm going to merge the PR to unblock the work but I'm happy to discuss the policy here and update everywhere in a follow up PR.
|
Hey @datumbox! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
* Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights.
Summary: * Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights. Reviewed By: NicolasHug Differential Revision: D31758312 fbshipit-source-id: 714a714d897bb4b4d9da1298ad5e2606998898b9
* Adding FasterRCNN ResNet50. * Refactoring to remove duplicate code. * Adding typing info. * Setting weights_backbone=None as default value. * Overwrite eps only for specific weights.
Resolves #4671
Example usage:
cc @datumbox @pmeier @bjuncek