Finetune new model by freezing all layers except FC #1298
Description
Is your feature request related to a problem? Please describe.
No
I think an essential use case is being able to fine-tune a pretrained model. Currently the code can load a model and then continue training all of its layers. I am looking for a way to freeze all layers except the last FC layer and train only a new FC layer with a smaller number of classes.
Describe the solution you'd like
A simple way to:
- pass in the model name/checkpoint;
- specify the FC layer name to keep training while freezing the rest of the layers.
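For reference, a minimal sketch of that pattern in plain PyTorch (the helper name `freeze_except_fc` and the attribute-based layer lookup are hypothetical, not an existing API):

```python
import torch.nn as nn


def freeze_except_fc(model: nn.Module, fc_name: str, num_classes: int) -> nn.Module:
    """Hypothetical helper: freeze every parameter, then replace the named
    Linear head with a fresh, trainable one for `num_classes` outputs."""
    # Stop gradients for all pretrained weights
    for param in model.parameters():
        param.requires_grad = False
    # Swap in a new FC head; fresh nn.Linear parameters require grad by default
    old_fc = getattr(model, fc_name)
    setattr(model, fc_name, nn.Linear(old_fc.in_features, num_classes))
    return model
```

Only the new head's parameters would then be passed to the optimizer, e.g. `torch.optim.Adam(model.fc.parameters())`.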
Describe alternatives you've considered
I wrote the function below to copy weights and freeze layers for SegResNet; it keeps the last layer, named `conv_0.conv_0`, trainable. I load this model in the `init` of my app. However, training doesn't converge, so I think something is missing:
```python
import torch
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups


def pruneModelFCLayer(dst_model, src_model, checkptPath):
    # Load the pretrained checkpoint into the source model
    checkpoint = torch.load(checkptPath)
    src_model_state_dict = checkpoint.get("model", checkpoint)
    src_model.load_state_dict(src_model_state_dict, strict=False)

    # Copy all weights except the last layer (conv_0.conv_0) into the destination model
    new_model_state_dict, updated_keys, unchanged_keys = copy_model_state(
        dst_model, src_model, exclude_vars="conv_0.conv_0", inplace=False
    )
    print(f"unchanged keys {unchanged_keys}")
    dst_model.load_state_dict(new_model_state_dict)

    # Stop gradients for the pretrained (copied) weights
    for name, param in dst_model.named_parameters():
        if name in updated_keys:
            param.requires_grad = False

    params = generate_param_groups(
        network=dst_model,
        layer_matches=[lambda x: x[0] in updated_keys],
        match_types=["filter"],
        lr_values=[1e-4],
        include_others=False,
    )
    return dst_model, params
```