Consider enhancing get_train_stats() #4563

@holgerroth

Description

Is your feature request related to a problem? Please describe.
Currently, the Trainer's get_train_stats() function returns a fixed set of stats:

    def get_train_stats(self) -> dict[str, float]:
        return {"total_epochs": self.state.max_epochs, "total_iterations": self.state.epoch_length}

Describe the solution you'd like
Make it configurable through input arguments.
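
One possible shape for this, as a minimal sketch rather than a proposed implementation: the caller passes the names of extra state attributes, and the function returns them alongside the current defaults. The `State` and `Trainer` classes below are hypothetical stand-ins so the example runs on its own, and the `*names` parameter and the `None` fallback for unknown attributes are assumptions, not part of the existing API.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class State:
    """Hypothetical stand-in for the trainer's state object."""

    max_epochs: int = 0
    epoch_length: int = 0
    epoch: int = 0
    iteration: int = 0


class Trainer:
    """Hypothetical minimal trainer, used only to illustrate the idea."""

    def __init__(self, state: State) -> None:
        self.state = state

    def get_train_stats(self, *names: str) -> dict[str, Any]:
        # Default stats, same keys and values as the current fixed set.
        stats: dict[str, Any] = {
            "total_epochs": self.state.max_epochs,
            "total_iterations": self.state.epoch_length,
        }
        # Extra stats requested by the caller (assumed behaviour):
        # attributes missing from the state are reported as None.
        for name in names:
            stats[name] = getattr(self.state, name, None)
        return stats


trainer = Trainer(State(max_epochs=10, epoch_length=100, epoch=3, iteration=250))
default_stats = trainer.get_train_stats()
extended_stats = trainer.get_train_stats("epoch", "iteration", "best_metric")
```

Calling the function without arguments keeps the existing behaviour, so current callers would not need to change.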

Describe alternatives you've considered
Alternatively, anyone needing train stats could access trainer.state directly. However, a general get_train_stats() call would likely be more broadly useful.

Additional context
Training stats could be useful in contexts such as federated learning.

Metadata

Labels

enhancement: New feature or request
