Pytorchのモデル学習の中断と再開の処理を実装したメモ【備忘録】

お疲れ様です。

今回はPytorchで学習の途中再開をするためのコードのメモです。 長期間モデル学習を実行する際に予期せぬトラブルで処理が止まってしまった場合などにも使えると思います。

ソースコード

以前作成した画像分類のプロジェクトに実装しています。 github.com

実装

  • モデルの保存時
    下記modules/trainer.pyのTrainerクラスのモデル保存用メソッドです。 modelのパラメータに加えてoptimizerのパラメータを保存するのがポイントになっています。
    ちなみに下記は最新epochのモデルを保存するためのメソッドですが、これを毎epoch行うようにしています。 こうすることで処理を止めたor止まったタイミングから再開することができます。
def save_weight_latest(
    self,
    epoch: int
) -> None:
    """モデルの重みを保存
    """        
    # 最終epochのモデル
    model_name = "model_latest.pth"
    checkpoint = {
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, self.output_path.joinpath(model_name))
    print(f"model saved: {model_name}")
  • 途中再開時
    学習の途中再開用にmain_train_resume.pyを作成しています。 torch.loadで読み込みmodelとoptimizerそれぞれで読み込みをする形になります。
# 途中保存した重みの読み込み
checkpoint_path = output_path.joinpath("model_latest.pth")
checkpoint = torch.load(checkpoint_path, map_location=device)

# モデルの定義
model, _ = get_model_train(
    model_name=model_name, 
    num_classes=train_dataset.num_classes, 
    use_pretrained=use_pretrained
)
# 途中保存の重みに更新
model.load_state_dict(checkpoint["model_state_dict"])
params = model.parameters()

# optimizerの定義
optimizer = RAdamScheduleFree(params, lr=lr)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

実行するとこんな感じになります。 以下、「①中断せずに20epoch回した結果」と「②10epochで中断して再開した結果」を並べてみました。
似たような曲線を描いているので問題なく途中再開できていそうです。


  • 1


  • 2

参考

参考にさせていただいた記事を載せておきます。
qiita.com

再現性も保持したい場合は乱数の状態も一緒に保存し、再開の際に読み込むことで実現できるようです。 今回作成のコードでは実装していませんが。 qiita.com