Conversation
datumbox left a comment
Adding some highlights to assist review:
try:
    from torchvision.prototype import models as PM
Try to import the prototype models, but do not fail if they are unavailable.
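The guarded import highlighted above can be sketched as follows. This is a minimal illustration, not the actual reference-script code; the helper `prototype_models_available` is hypothetical.

```python
# Guarded import: pull in the prototype models namespace if it exists,
# but degrade gracefully on builds where it is unavailable.
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None  # prototype models not available in this build


def prototype_models_available() -> bool:
    """Return True when the prototype models namespace was imported."""
    return PM is not None
```

Callers can then branch on `prototype_models_available()` instead of repeating the try/except.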
    print("Loading dataset_test from {}".format(cache_path))
    dataset_test, _ = torch.load(cache_path)
else:
    if not args.weights:
Which preprocessing we will use depends on whether weights are defined.
else:
    fn = PM.__dict__[args.model]
    weights = PM._api.get_weight(fn, args.weights)
    preprocessing = weights.transforms()
Having a definition of the weights means we will be accessing the prototype models. Those have the preprocessing attached to the weights, so we fetch them and construct the preprocessing class.
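The branching described above can be sketched in a self-contained way. Everything here is a stand-in for illustration (the `Weights` dataclass and preset strings are hypothetical); in the real prototype API the weights enum carries its own `transforms()` factory.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Weights:
    # Stand-in for a prototype weights entry: each carries its checkpoint
    # URL and a factory for its preprocessing pipeline.
    url: str
    transforms: Callable


def default_preprocessing() -> str:
    # Stand-in for the classic hardcoded presets path.
    return "classic-presets"


def build_preprocessing(weights: Optional[Weights]) -> str:
    if weights is None:
        # No weights requested: fall back to the classic presets.
        return default_preprocessing()
    # Weights defined: the preprocessing ships attached to the weights.
    return weights.transforms()


# Example usage with an illustrative weights entry:
resnet50_weights = Weights(url="...", transforms=lambda: "weights-presets")
```

The design point is that a checkpoint and its matching preprocessing travel together, so callers cannot pair a model with the wrong transforms.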
return super().__getattr__(name)
def get_weight(fn: Callable, weight_name: str) -> Weights:
For now I consider it a private method. We will eventually need to make it public, because getting the enum class from a string is useful, but it's unclear whether we should do it by passing the model builder and then the weight name, or construct it via the fully qualified name.
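For context, the annotation-based resolution being discussed can be sketched as below. All names here are illustrative stand-ins (the real enums carry URLs and metadata); the point is only the lookup mechanism: the weights enum class is recovered from the builder's `weights` parameter annotation.

```python
import enum
from typing import Callable


class ResNet50Weights(enum.Enum):
    # Illustrative entries; the real enum members carry URLs and metadata.
    ImageNet1K_V1 = "resnet50-v1"
    ImageNet1K_V2 = "resnet50-v2"


def resnet50(weights: ResNet50Weights = None):
    # Illustrative model builder whose 'weights' annotation names the enum.
    return ("resnet50", weights)


def get_weight(fn: Callable, weight_name: str):
    """Resolve a weights enum member from the builder's type annotation."""
    weights_enum = fn.__annotations__.get("weights")
    if weights_enum is None:
        raise ValueError(f"{fn.__name__} has no annotated 'weights' parameter")
    return weights_enum[weight_name]
```

For example, `get_weight(resnet50, "ImageNet1K_V2")` resolves to `ResNet50Weights.ImageNet1K_V2` without the caller ever importing the enum directly.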
Sorry, I only got a chance to look at it now.
Relying on the model builder's annotation seems like a pretty involved way of retrieving the weights.
Should we go simpler here and just register all the weights in some sort of private _AVAILABLE_WEIGHTS dict? get_weight() would then just be a lookup into that private dict.
(This is my only comment, the rest of the PR looks great!)
@NicolasHug Thanks for looking at it. FYI, I merged after Prabhat's review so that we pass this to the FBsync, but I plan to make changes in follow-up PRs.
I agree that this is involved, and that's why I haven't exposed it as public. I've added an entry at #4652 to review the mechanism and, more specifically, to sync with you on making it TorchHub friendly. One option, as you said, is to have a registration mechanism similar to the one proposed here to keep track of model/weight combos and also flag the "best/latest" weights. I have deliberately omitted all the versioning parts of the original RFC to allow discussions across Audio and Text to continue and see if we can adopt a common solution. But I think they are currently looking into moving in a different direction that has no model builders, so we might be able to bring this feature sooner.
Summary:
* Update model checkpoint for resnet50.
* Add get_weight method to retrieve weights from name.
* Update the references to support prototype weights.
* Fix mypy typing.
* Switch to a Python 3.6 supported equivalent.
* Add unit test.
* Add optional num_classes.

Reviewed By: NicolasHug
Differential Revision: D31916330
fbshipit-source-id: 2ac0f9202f62a78078b0917e6730d7fc0925acdf
Related to #3995
This PR does 2 things:
Concerning the new model weights, it was trained using the Batteries Included primitives and achieves the following accuracy:
The linked issue provides high-level details on the recipe, but I'll also follow up with a blog post on how it was trained.
cc @datumbox @vfdev-5 @pmeier @bjuncek