This repository was archived by the owner on Apr 8, 2025. It is now read-only.
Bug in example doc_classification_with_earlystopping.py. #422
Closed
Labels
bug
Description
There is a bug when executing doc_classification_with_earlystopping.py on Colab. See the full notebook here: https://gist.github.com/PhilipMay/8b042f713603e68deb5519fb7776d128
[...]
<ipython-input-3-66a7ddd46d8d> in doc_classification_with_earlystopping()
128
129 # 7. Let it grow
--> 130 trainer.train()
131
132 # 8. Hooray! You have a model.
/content/FARM/farm/train.py in train(self)
352 logger.info("Restoring best model so far from {}".format(self.early_stopping.save_dir))
353 lm_name = self.model.language_model.name
--> 354 model = AdaptiveModel.load(self.early_stopping.save_dir, self.device, lm_name=lm_name)
355 model.connect_heads_with_processor(self.data_silo.processor.tasks, require_labels=True)
356
/content/FARM/farm/modeling/adaptive_model.py in load(cls, load_dir, device, strict, lm_name, processor)
321 ph_output_type = []
322 for config_file in ph_config_files:
--> 323 head = PredictionHead.load(config_file, strict=strict)
324 prediction_heads.append(head)
325 ph_output_type.append(head.ph_output_type)
/content/FARM/farm/modeling/prediction_head.py in load(cls, config_file, strict, load_weights)
116 model_file = cls._get_model_file(config_file=config_file)
117 logger.info("Loading prediction head from {}".format(model_file))
--> 118 prediction_head.load_state_dict(torch.load(model_file, map_location=torch.device("cpu")), strict=strict)
119 return prediction_head
120
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
845 if len(error_msgs) > 0:
846 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847 self.__class__.__name__, "\n\t".join(error_msgs)))
848 return _IncompatibleKeys(missing_keys, unexpected_keys)
849
RuntimeError: Error(s) in loading state_dict for TextClassificationHead:
Unexpected key(s) in state_dict: "loss_fct.weight".
If you could provide a workaround to avoid the exception, I would be very happy. As always, I am willing to open a PR but will need some help with debugging.
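One possible workaround, sketched here under assumptions rather than taken from the FARM code base: the traceback says the saved checkpoint contains a key, "loss_fct.weight", that the freshly constructed TextClassificationHead does not expect. A plausible cause is that the head was trained with class weights in its loss function but is rebuilt without them when the best model is restored. Dropping the unexpected keys before calling load_state_dict (or loading with strict=False, a standard PyTorch parameter) should avoid the RuntimeError. The helper name drop_unexpected_keys and the parameter key names other than "loss_fct.weight" are hypothetical, and plain dicts stand in for tensors so the snippet runs without PyTorch.

```python
from collections import OrderedDict


def drop_unexpected_keys(state_dict, expected_keys):
    """Split a checkpoint into entries the target module expects and the rest.

    Hypothetical helper, not part of FARM or PyTorch. Returns a tuple
    (kept, dropped): an OrderedDict of the expected entries, and a list of
    the key names that were left out.
    """
    expected = set(expected_keys)
    kept = OrderedDict((k, v) for k, v in state_dict.items() if k in expected)
    dropped = [k for k in state_dict if k not in expected]
    return kept, dropped


# Stand-in for torch.load(model_file, map_location=...). Only "loss_fct.weight"
# comes from the error message; the other key names are made up.
saved = OrderedDict([
    ("feed_forward.0.weight", [[0.1, 0.2]]),
    ("feed_forward.0.bias", [0.0]),
    ("loss_fct.weight", [1.0, 3.0]),
])

# Stand-in for prediction_head.state_dict().keys() on the freshly built head.
expected = ["feed_forward.0.weight", "feed_forward.0.bias"]

filtered, dropped = drop_unexpected_keys(saved, expected)
print("dropped:", dropped)
```

With real objects, the filtered dict would be passed to prediction_head.load_state_dict(filtered). This only silences the error: the class weights stored under "loss_fct.weight" are discarded, so if training continues after the restore, the loss function would need to be given its weights again.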