
Finetune new model by freezing all layers except FC #1298

@AHarouni


Is your feature request related to a problem? Please describe.
No.
I think an essential use case is being able to fine-tune a model. Currently the code can load a model and continue training all of its layers. I am looking for a way to freeze all layers except the last FC layer and train only a new FC layer with fewer classes.
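As a minimal sketch in plain PyTorch (the toy `backbone`/`fc` model and the layer sizes are hypothetical, not taken from this issue), freezing the pretrained weights and attaching a freshly initialised FC layer with fewer classes could look like this:

```python
import torch
from torch import nn

# Hypothetical pretrained network: a feature extractor followed by a 10-class FC layer.
model = nn.Sequential()
model.add_module("backbone", nn.Sequential(nn.Linear(16, 32), nn.ReLU()))
model.add_module("fc", nn.Linear(32, 10))

# Freeze every pretrained parameter.
for param in model.parameters():
    param.requires_grad = False

# Replace the FC layer with a new 3-class layer; new modules default to requires_grad=True.
model.fc = nn.Linear(32, 3)

# Only the new FC layer's weight and bias are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=1e-3)

logits = model(torch.randn(4, 16))
print(logits.shape)  # torch.Size([4, 3])
```

Assigning to `model.fc` replaces the existing child in place, so the `nn.Sequential` keeps its original layer order.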

Describe the solution you'd like
A simple way to:

  1. Pass in the model name or checkpoint.
  2. Specify the name of the FC layer to keep training while freezing the rest of the layers.
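The requested interface could look roughly like the helper below. `load_and_freeze` and its `trainable_prefix` argument are hypothetical names, not an existing MONAI API; this is a plain PyTorch sketch that assumes the checkpoint is either a state dict or a dict with a `"model"` entry, as the code in this issue assumes. Keys whose shapes differ (the resized FC layer) are skipped explicitly, because `load_state_dict` raises on size mismatches even with `strict=False`.

```python
import os
import tempfile

import torch
from torch import nn


def load_and_freeze(model: nn.Module, checkpoint_path: str, trainable_prefix: str) -> list:
    """Hypothetical helper: load pretrained weights, then train only `trainable_prefix.*`."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    src_state = checkpoint.get("model", checkpoint)
    dst_state = model.state_dict()
    # Keep only tensors whose name and shape match; the new FC layer keeps its fresh init.
    compatible = {
        k: v for k, v in src_state.items()
        if k in dst_state and v.shape == dst_state[k].shape
    }
    model.load_state_dict(compatible, strict=False)
    # Freeze everything except the named layer.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith(trainable_prefix + ".")
    return [p for p in model.parameters() if p.requires_grad]


def make_model(num_classes: int) -> nn.Module:
    # Hypothetical two-part model used only for this example.
    model = nn.Sequential()
    model.add_module("backbone", nn.Linear(8, 16))
    model.add_module("fc", nn.Linear(16, num_classes))
    return model


# Pretrain with 10 classes, save, then fine-tune only "fc" with 3 classes.
pretrained = make_model(10)
path = os.path.join(tempfile.mkdtemp(), "pretrained.pt")
torch.save(pretrained.state_dict(), path)

finetune = make_model(3)
params = load_and_freeze(finetune, path, "fc")
optimizer = torch.optim.Adam(params, lr=1e-4)
```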

Describe alternatives you've considered
I wrote the function below to copy weights and freeze layers for SegResNet. It keeps the last layer, named conv_0.conv_0, trainable. I load this model in the init of my app. However, training doesn't converge, so I think something is missing.

import torch
from monai.networks.utils import copy_model_state
from monai.optimizers import generate_param_groups


def pruneModelFCLayer(dst_model, src_model, checkptPath):
    # load the pretrained checkpoint into the source model
    checkpoint = torch.load(checkptPath, map_location="cpu")
    src_model_state_dict = checkpoint.get("model", checkpoint)
    src_model.load_state_dict(src_model_state_dict, strict=False)

    # copy every weight with a matching name and shape, except the last layer
    new_model_state_dict, updated_keys, unchanged_keys = copy_model_state(
        dst_model, src_model, exclude_vars="conv_0.conv_0", inplace=False
    )
    print(f"unchanged keys {unchanged_keys}")
    dst_model.load_state_dict(new_model_state_dict)

    # stop gradients for the pretrained (copied) weights
    for name, param in dst_model.named_parameters():
        if name in updated_keys:
            param.requires_grad = False

    # give the optimizer only the parameters that were NOT copied (the new layer);
    # filtering with `x[0] in updated_keys` would select the frozen weights instead
    params = generate_param_groups(
        network=dst_model,
        layer_matches=[lambda x: x[0] not in updated_keys],
        match_types=["filter"],
        lr_values=[1e-4],
        include_others=False,
    )

    return dst_model, params
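When a fine-tune does not converge, one thing worth checking is which parameters actually reach the optimizer. The sketch below uses plain PyTorch; `audit_optimizer` and the toy model are hypothetical. It lists parameters that are frozen yet passed to the optimizer, and trainable parameters the optimizer never sees. A non-empty second list means the new layer is never updated.

```python
import torch
from torch import nn


def audit_optimizer(model: nn.Module, optimizer: torch.optim.Optimizer):
    """Hypothetical helper: compare requires_grad flags with what the optimizer holds."""
    in_optimizer = {id(p) for group in optimizer.param_groups for p in group["params"]}
    frozen_but_optimized = []   # frozen, yet passed to the optimizer (never updated)
    trainable_but_missing = []  # trainable, yet never seen by the optimizer
    for name, param in model.named_parameters():
        if id(param) in in_optimizer and not param.requires_grad:
            frozen_but_optimized.append(name)
        elif id(param) not in in_optimizer and param.requires_grad:
            trainable_but_missing.append(name)
    return frozen_but_optimized, trainable_but_missing


# Toy model: frozen backbone, trainable FC layer.
model = nn.Sequential()
model.add_module("backbone", nn.Linear(8, 16))
model.add_module("fc", nn.Linear(16, 3))
model.backbone.requires_grad_(False)

# Mistake: the optimizer is built from the frozen parameters only.
bad = torch.optim.SGD(model.backbone.parameters(), lr=1e-4)
bad_report = audit_optimizer(model, bad)
print(bad_report)  # (['backbone.weight', 'backbone.bias'], ['fc.weight', 'fc.bias'])

good = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-4)
good_report = audit_optimizer(model, good)
print(good_report)  # ([], [])
```

Another thing to check: setting `requires_grad=False` does not stop BatchNorm running statistics from updating in `train()` mode. If the network uses BatchNorm, calling `.eval()` on the frozen submodules keeps those statistics fixed.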
