Update Reference scripts to support the prototype models#4837
Update Reference scripts to support the prototype models#4837datumbox merged 4 commits intopytorch:mainfrom
Conversation
💊 CI failures summary and remediationsAs of commit 55ddb93 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
datumbox
left a comment
There was a problem hiding this comment.
Some clarifications below:
| else: | ||
| fn = PM.segmentation.__dict__[args.model] | ||
| weights = PM._api.get_weight(fn, args.weights) | ||
| return weights.transforms() |
There was a problem hiding this comment.
If we are in train mode, we always initialize the SegmentationPresetTrain. For validation if the weights are not defined (aka not a prototype model) then use the old preprocessing method for evaluation. Else use the one attached to the weights.
| model = torchvision.models.segmentation.__dict__[args.model]( | ||
| num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained | ||
| ) | ||
| if not args.weights: |
There was a problem hiding this comment.
If the weights are not defined, we use the standard way. Else it's a prototype run which means we will use the prototype model mechanism.
| transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
| else: | ||
| fn = PM.video.__dict__[args.model] | ||
| weights = PM._api.get_weight(fn, args.weights) |
There was a problem hiding this comment.
nit: using a private API here. We probably don't want to advertise private APIs in the references
| cache_path = _get_cache_path(valdir) | ||
|
|
||
| transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
| if not args.weights: |
There was a problem hiding this comment.
nit: pretrained and weights are overlapping and can be confusing. This ideally should be cleaned up in the future
There was a problem hiding this comment.
That's exactly the plan. --pretrained will go away and --weights is going to be the right parameter. Right now we support both temporarily so that we can switch between the two completely different APIs. The --weights acts as a feature switch here.
| if not args.weights: | ||
| transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
| else: | ||
| fn = PM.video.__dict__[args.model] |
There was a problem hiding this comment.
are we providing some sort of registration API to get the models without having to resort to __dict__ manipulations?
There was a problem hiding this comment.
Yeap, that's the plan. There will be a proper registration mechanism, possibly something similar to what was discussed here. There are still pending discussions with other domains, so I didn't want to adopt something before those discussions take place.
* Adding prototype preprocessing on segmentation references. * Adding prototype preprocessing on video references.
Fixes #4671
This PR adds a similar mechanism as in
classificationforsegmentationandvideo. The target is to enable us to test the new model weights API (+ presets) and confirm it returns the same results as the old one. The co-existence of--pretrainedand--weightsis temporary and allows us to test that all models we introduce produce the expected results.The approach is not perfect as it exposes the
prototypestuff in the example reference scripts but the alternative would be to duplicate the reference scripts or keep a separate branch with their modifications which makes the work cumbersome. These will be cleaned up prior to adopting the new API, see #4652 and #4679.cc @datumbox @bjuncek