ModelCheckpoint score_function confusing to use #677

@felix-ht

Description

I use create_supervised_trainer in combination with ModelCheckpoint.

checkpoint = ModelCheckpoint(
    f"./models/{start_time}",
    "resnet_50",
    score_name="loss",
    score_function=lambda engine: 1 - engine.state.output, # output contains the minibatch output
    n_saved=10,
    create_dir=True,
)

With the supervised trainer it is not straightforward to add a metric, so users only have access to engine.state.output, which contains just the last minibatch's output. As a result, models get saved based on the model's performance on the last minibatch, not on the whole epoch.

It might be a good idea to add a warning to the documentation of ModelCheckpoint that state.output contains the minibatch output only, and/or to add an example that shows how to use it in combination with a metric.
