-
-
Notifications
You must be signed in to change notification settings - Fork 692
Add min_delta_mode parameter to EarlyStopping and update logic for sc… #3516
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
eff89c9
ce70f0a
15d98ec
9105cd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||
| from collections import OrderedDict | ||||||
| from typing import Callable, cast, Mapping, Optional | ||||||
| from typing import Callable, cast, Mapping, Optional, Literal | ||||||
|
|
||||||
| from ignite.base import Serializable | ||||||
| from ignite.engine import Engine | ||||||
|
|
@@ -17,9 +17,15 @@ class EarlyStopping(Serializable): | |||||
| object, and return a score `float`. An improvement is considered if the score is higher. | ||||||
| trainer: Trainer engine to stop the run if no improvement. | ||||||
| min_delta: A minimum increase in the score to qualify as an improvement, | ||||||
| i.e. an increase of less than or equal to `min_delta`, will count as no improvement. | ||||||
| i.e. an increase of less than or equal to the minimum delta threshold (as determined by min_delta and min_delta_mode), will count as no improvement. | ||||||
| cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise, | ||||||
| it defines an increase after the last event. Default value is False. | ||||||
| min_delta_mode: Determine whether `min_delta` is an absolute increase or a relative increase. | ||||||
vfdev-5 marked this conversation as resolved.
Show resolved
Hide resolved
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aaishwarymishra I forgot, we should update the docstring with ignite/ignite/metrics/metric.py Lines 281 to 282 in d2020e4
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add all the changes in docs together in next pr if it sounds good? |
||||||
| In 'abs' mode, the threshold is min_delta, | ||||||
| i.e. an increase of less than or equal to min_delta, will count as no improvement. | ||||||
| In 'rel' mode, the threshold is abs(best_score) * min_delta, | ||||||
| i.e. an increase of less than or equal to abs(best_score) * min_delta, will count as no improvement. | ||||||
| Possible values are "abs" and "rel". Default value is "abs". | ||||||
|
|
||||||
| Examples: | ||||||
| .. code-block:: python | ||||||
|
|
@@ -49,6 +55,7 @@ def __init__( | |||||
| trainer: Engine, | ||||||
| min_delta: float = 0.0, | ||||||
| cumulative_delta: bool = False, | ||||||
| min_delta_mode: Literal["abs", "rel"] = "abs", | ||||||
| ): | ||||||
| if not callable(score_function): | ||||||
| raise TypeError("Argument score_function should be a function.") | ||||||
|
|
@@ -62,6 +69,9 @@ def __init__( | |||||
| if not isinstance(trainer, Engine): | ||||||
| raise TypeError("Argument trainer should be an instance of Engine.") | ||||||
|
|
||||||
| if min_delta_mode not in ("abs", "rel"): | ||||||
| raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.") | ||||||
|
|
||||||
| self.score_function = score_function | ||||||
| self.patience = patience | ||||||
| self.min_delta = min_delta | ||||||
|
|
@@ -70,13 +80,20 @@ def __init__( | |||||
| self.counter = 0 | ||||||
| self.best_score: Optional[float] = None | ||||||
| self.logger = setup_logger(__name__ + "." + self.__class__.__name__) | ||||||
| self.min_delta_mode = min_delta_mode | ||||||
|
|
||||||
| def __call__(self, engine: Engine) -> None: | ||||||
| score = self.score_function(engine) | ||||||
|
|
||||||
| if self.best_score is None: | ||||||
| self.best_score = score | ||||||
| elif score <= self.best_score + self.min_delta: | ||||||
| return | ||||||
| upper_bound = ( | ||||||
| self.best_score + self.min_delta | ||||||
| if self.min_delta_mode == "abs" | ||||||
| else self.best_score + abs(self.best_score) * self.min_delta | ||||||
| ) | ||||||
| if score <= upper_bound: | ||||||
| if not self.cumulative_delta and score > self.best_score: | ||||||
| self.best_score = score | ||||||
| self.counter += 1 | ||||||
|
|
||||||
Uh oh!
There was an error while loading. Please reload this page.