This repository was archived by the owner on Apr 8, 2025. It is now read-only.

Bug in example doc_classification_with_earlystopping.py. #422

@PhilipMay


Running doc_classification_with_earlystopping.py on Colab fails with a RuntimeError when the trainer restores the best model after early stopping. The full notebook is here: https://gist.github.com/PhilipMay/8b042f713603e68deb5519fb7776d128

[...]
<ipython-input-3-66a7ddd46d8d> in doc_classification_with_earlystopping()
    128 
    129     # 7. Let it grow
--> 130     trainer.train()
    131 
    132     # 8. Hooray! You have a model.

/content/FARM/farm/train.py in train(self)
    352             logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
    353             lm_name = self.model.language_model.name
--> 354             model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name)
    355             model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
    356 

/content/FARM/farm/modeling/adaptive_model.py in load(cls, load_dir, device, strict, lm_name, processor)
    321         ph_output_type = []
    322         for config_file in ph_config_files:
--> 323             head = PredictionHead.load(config_file, strict=strict)
    324             prediction_heads.append(head)
    325             ph_output_type.append(head.ph_output_type)

/content/FARM/farm/modeling/prediction_head.py in load(cls, config_file, strict, load_weights)
    116             model_file = cls._get_model_file(config_file=config_file)
    117             logger.info("Loading prediction head from {}".format(model_file))
--> 118             prediction_head.load_state_dict(torch.load(model_file, map_location=torch.device("cpu")), strict=strict)
    119         return prediction_head
    120 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    845         if len(error_msgs) > 0:
    846             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847                                self.__class__.__name__, "\n\t".join(error_msgs)))
    848         return _IncompatibleKeys(missing_keys, unexpected_keys)
    849 

RuntimeError: Error(s) in loading state_dict for TextClassificationHead:
	Unexpected key(s) in state_dict: "loss_fct.weight". 

If you could provide a workaround to avoid the exception, I would be very happy. As always, I am willing to open a PR but will need some help with debugging.
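One possible stopgap, sketched here under assumptions: the key name suggests the checkpoint contains a class-weight tensor for the loss function that the freshly constructed head does not expect (this cause is a guess, not confirmed in this issue). If so, filtering the saved state dict down to the keys the target module expects would avoid the strict-loading error. The helper `drop_unexpected_keys` and the `feed_forward.*` key names below are hypothetical, and plain lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import Any, Dict, Iterable


def drop_unexpected_keys(
    state_dict: Dict[str, Any], expected_keys: Iterable[str]
) -> Dict[str, Any]:
    """Return a copy of state_dict restricted to the keys the target expects."""
    expected = set(expected_keys)
    return {key: value for key, value in state_dict.items() if key in expected}


# Stand-in for the saved checkpoint: plain lists instead of tensors.
# Only "loss_fct.weight" is taken from the traceback; the other key
# names are hypothetical.
saved = {
    "feed_forward.layers.0.weight": [[0.1, 0.2]],
    "feed_forward.layers.0.bias": [0.0],
    "loss_fct.weight": [1.0, 3.0],
}

# Stand-in for the keys reported by the freshly built prediction head.
expected = ["feed_forward.layers.0.weight", "feed_forward.layers.0.bias"]

filtered = drop_unexpected_keys(saved, expected)
dropped = sorted(set(saved) - set(filtered))
print(dropped)  # ['loss_fct.weight']
```

With PyTorch, the expected keys could be taken from `prediction_head.state_dict().keys()` and the filtered dict passed to `prediction_head.load_state_dict(...)`. Passing `strict=False` to `load_state_dict` relaxes the same check. Either way the saved class weights are discarded, so the restored head would use an unweighted loss unless the weights are set again separately.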

Labels: bug (Something isn't working)
