
v0.5.3 (PR #3440): SAVED_CHECKPOINT event missing from event_to_attr causes multiple Checkpoints to fail #3502

@lyhyl

Description


🐛 Bug description

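    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

    to_save = {'model': model, 'optimizer': optim, 'lr_scheduler': sched, 'trainer': trainer}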
    # This one will work
    best_ckpt = Checkpoint(to_save,
                           save_handler=DiskSaver(f"{log_dir}/ckpt", create_dir=True),
                           filename_prefix="R2_best",
                           score_name="R2",
                           global_step_transform=global_step_from_engine(trainer))
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, best_ckpt)
    # This one will fail
    last_ckpt = Checkpoint(to_save,
                           save_handler=DiskSaver(f"{log_dir}/ckpt", create_dir=True),
                           filename_prefix="last",
                           global_step_transform=global_step_from_engine(trainer))
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, last_ckpt)

    @trainer.on(Events.EPOCH_COMPLETED)
    def trainer_epoch_completed(engine: Engine):
        evaluator.run(train_dl)

This code works in Ignite 0.5.2, but with 0.5.3 it fails in the second checkpoint because of this line:

    global_step = self.global_step_transform(engine, engine.last_event_name)

Checkpoint.__call__ passes engine.last_event_name to global_step_transform.
For the first checkpoint this is fine, since engine.last_event_name == Events.EPOCH_COMPLETED.
For the second checkpoint it fails: because of #3440, engine.last_event_name is now CheckpointEvents.SAVED_CHECKPOINT, and this event is not registered (#1934) in:
    event_to_attr: Dict[Union[str, "Events", "CallableEventWithFilter"], str] = {
        Events.GET_BATCH_STARTED: "iteration",
        Events.GET_BATCH_COMPLETED: "iteration",
        Events.ITERATION_STARTED: "iteration",
        Events.ITERATION_COMPLETED: "iteration",
        Events.EPOCH_STARTED: "epoch",
        Events.EPOCH_COMPLETED: "epoch",
        Events.STARTED: "epoch",
        Events.COMPLETED: "epoch",
    }

This causes the following failure:
    def get_event_attrib_value(self, event_name: Union[str, Events, CallableEventWithFilter]) -> int:
        """Get the value of Event attribute with given `event_name`."""
        if event_name not in State.event_to_attr:
            raise RuntimeError(f"Unknown event name '{event_name}'")
        return getattr(self, State.event_to_attr[event_name])
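To make the failure sequence concrete, here is a minimal, dependency-free sketch of it. Every class and function below (ToyEngine, ToyState, toy_checkpoint, step_from_trainer and the two enums) is a hypothetical stand-in, not Ignite's code. It assumes, as described above, that firing an event overwrites engine.last_event_name, and that after #3440 each checkpoint fires SAVED_CHECKPOINT on the engine after saving.

```python
from enum import Enum


class Events(Enum):
    # Hypothetical stand-in for ignite's Events (only the member we need)
    EPOCH_COMPLETED = "epoch_completed"


class CheckpointEvents(Enum):
    # Hypothetical stand-in for the event associated with #3440
    SAVED_CHECKPOINT = "saved_checkpoint"


class ToyState:
    # Lookup table as quoted above: SAVED_CHECKPOINT is not registered
    event_to_attr = {Events.EPOCH_COMPLETED: "epoch"}

    def __init__(self, epoch: int) -> None:
        self.epoch = epoch

    def get_event_attrib_value(self, event_name) -> int:
        # Same logic as State.get_event_attrib_value quoted above
        if event_name not in ToyState.event_to_attr:
            raise RuntimeError(f"Unknown event name '{event_name}'")
        return getattr(self, ToyState.event_to_attr[event_name])


class ToyEngine:
    """Minimal engine: firing an event records it as last_event_name."""

    def __init__(self, epoch: int = 0) -> None:
        self.state = ToyState(epoch)
        self.last_event_name = None
        self._handlers = {}

    def add_event_handler(self, event_name, handler) -> None:
        self._handlers.setdefault(event_name, []).append(handler)

    def fire_event(self, event_name) -> None:
        # The name is overwritten on every fire and never restored
        self.last_event_name = event_name
        for handler in self._handlers.get(event_name, []):
            handler(self)


def toy_checkpoint(name, global_step_transform, saved):
    """Hypothetical stand-in for Checkpoint.__call__ after #3440."""

    def handler(engine) -> None:
        step = global_step_transform(engine, engine.last_event_name)
        saved.append((name, step))
        # Nested fire: last_event_name becomes SAVED_CHECKPOINT
        engine.fire_event(CheckpointEvents.SAVED_CHECKPOINT)

    return handler


trainer = ToyEngine(epoch=3)
evaluator = ToyEngine()
saved = []


def step_from_trainer(engine, event_name) -> int:
    # Mimics global_step_from_engine(trainer): event looked up on trainer.state
    return trainer.state.get_event_attrib_value(event_name)


evaluator.add_event_handler(
    Events.EPOCH_COMPLETED, toy_checkpoint("R2_best", step_from_trainer, saved)
)
evaluator.add_event_handler(
    Events.EPOCH_COMPLETED, toy_checkpoint("last", step_from_trainer, saved)
)

error = None
try:
    evaluator.fire_event(Events.EPOCH_COMPLETED)
except RuntimeError as exc:
    error = str(exc)

print(saved)  # [('R2_best', 3)]
print(error)
```

The second handler does not see the EPOCH_COMPLETED that triggered it; it sees the event name left behind by the first handler's nested fire, which the lookup table does not know about.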

A verified workaround:

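    from ignite.engine import State
    from ignite.handlers import Checkpoint

    # Register the event so get_event_attrib_value maps it to the epoch
    State.event_to_attr[Checkpoint.SAVED_CHECKPOINT] = "epoch"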
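An alternative that avoids mutating the class-level State.event_to_attr, assuming the trainer's current epoch is the desired global step, is to pass a global_step_transform that ignores the event name entirely. Checkpoint calls global_step_transform(engine, event_name) as in the line quoted above, so any callable with that signature should be accepted. The helper name epoch_of is hypothetical, and the SimpleNamespace object only stands in for a real trainer.

```python
from types import SimpleNamespace
from typing import Any, Callable


def epoch_of(trainer: Any) -> Callable[[Any, Any], int]:
    """Hypothetical helper returning a global_step_transform.

    The returned function has the (engine, event_name) signature that
    Checkpoint uses, but ignores event_name and reads trainer.state.epoch,
    so State.event_to_attr is never consulted.
    """

    def transform(engine: Any, event_name: Any) -> int:
        return trainer.state.epoch

    return transform


# Stand-in for a trainer whose state is at epoch 7
fake_trainer = SimpleNamespace(state=SimpleNamespace(epoch=7))
transform = epoch_of(fake_trainer)
print(transform(None, "saved_checkpoint"))  # 7
```

In the snippet from the bug description, it would replace global_step_from_engine(trainer) as global_step_transform=epoch_of(trainer). Because the epoch is read at call time, and the evaluator runs inside the trainer's EPOCH_COMPLETED handler, it should report the epoch that just finished.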

Environment

  • PyTorch Version (e.g., 1.4): 2.8.0
  • Ignite Version (e.g., 0.3.0): 0.5.3
  • OS (e.g., Linux): Ubuntu 24
  • How you installed Ignite (conda, pip, source): uv
  • Python version: 3.12
  • Any other relevant information: NA
