Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Callable, cast, Mapping, Optional
from typing import Callable, cast, Mapping, Optional, Literal

from ignite.base import Serializable
from ignite.engine import Engine
Expand All @@ -17,9 +17,15 @@ class EarlyStopping(Serializable):
object, and return a score `float`. An improvement is considered if the score is higher.
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
i.e. an increase of less than or equal to the minimum delta threshold (as determined by min_delta and min_delta_mode), will count as no improvement.
cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event. Default value is False.
min_delta_mode: Determine whether `min_delta` is an absolute increase or a relative increase.
Copy link
Copy Markdown
Collaborator

@vfdev-5 vfdev-5 Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aaishwarymishra I forgot, we should update the docstring with .. versionchanged:: directlve for the new arg. For example:

.. versionchanged:: 0.4.5
``y_pred`` and ``y`` can be torch tensors or list of tensors/numbers

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can add all the changes in docs together in next pr if it sounds good?

In 'abs' mode, the threshold is min_delta,
i.e. an increase of less than or equal to min_delta, will count as no improvement.
In 'rel' mode, the threshold is abs(best_score) * min_delta,
i.e. an increase of less than or equal to abs(best_score) * min_delta, will count as no improvement.
Possible values are "abs" and "rel". Default value is "abs".

Examples:
.. code-block:: python
Expand Down Expand Up @@ -49,6 +55,7 @@ def __init__(
trainer: Engine,
min_delta: float = 0.0,
cumulative_delta: bool = False,
min_delta_mode: Literal["abs", "rel"] = "abs",
):
if not callable(score_function):
raise TypeError("Argument score_function should be a function.")
Expand All @@ -62,6 +69,9 @@ def __init__(
if not isinstance(trainer, Engine):
raise TypeError("Argument trainer should be an instance of Engine.")

if min_delta_mode not in ("abs", "rel"):
raise ValueError("Argument min_delta_mode should be either 'abs' or 'rel'.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
Expand All @@ -70,13 +80,20 @@ def __init__(
self.counter = 0
self.best_score: Optional[float] = None
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
self.min_delta_mode = min_delta_mode

def __call__(self, engine: Engine) -> None:
score = self.score_function(engine)

if self.best_score is None:
self.best_score = score
elif score <= self.best_score + self.min_delta:
return
upper_bound = (
self.best_score + self.min_delta
if self.min_delta_mode == "abs"
else self.best_score + abs(self.best_score) * self.min_delta
)
if score <= upper_bound:
if not self.cumulative_delta and score > self.best_score:
self.best_score = score
self.counter += 1
Expand Down
56 changes: 56 additions & 0 deletions tests/ignite/handlers/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def test_args_validation():
with pytest.raises(TypeError, match=r"Argument trainer should be an instance of Engine."):
EarlyStopping(patience=2, score_function=lambda engine: 0, trainer=None)

with pytest.raises(ValueError, match=r"Argument min_delta_mode should be either 'abs' or 'rel'."):
EarlyStopping(patience=2, min_delta_mode="invalid_mode", score_function=lambda engine: 0, trainer=trainer)


def test_simple_early_stopping():
scores = iter([1.0, 0.8, 0.88])
Expand Down Expand Up @@ -71,6 +74,34 @@ def score_function(engine):
assert trainer.should_terminate


def test_state_dict_with_mode():
scores = iter([1.0, 2.0, 2.1, 2.2])

def score_function(engine):
return next(scores)

trainer = Engine(do_nothing_update_fn)

# Use "rel" mode
h = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, min_delta=0.1, min_delta_mode="rel")
h(None) # best_score=1.0
h(None) # score=2.0 (improvement)

state = h.state_dict()

# New handler with "rel" mode
h2 = EarlyStopping(patience=2, score_function=score_function, trainer=trainer, min_delta=0.1, min_delta_mode="rel")
h2.load_state_dict(state)

assert h2.min_delta_mode == "rel"
h2(None) # score=2.1 (no improvement: 2.1 <= 2.0 * 1.1 = 2.2)
assert h2.counter == 1
assert not trainer.should_terminate
h2(None) # score=2.2 (no improvement: 2.2 <= 2.2)
assert h2.counter == 2
assert trainer.should_terminate


def test_early_stopping_on_delta():
scores = iter([1.0, 2.0, 2.01, 3.0, 3.01, 3.02])

Expand All @@ -93,6 +124,31 @@ def test_early_stopping_on_delta():
assert trainer.should_terminate


def test_early_stopping_on_rel_delta():
scores = iter([1.0, 2.0, 2.1, 3.0, 3.2, 3.25])

trainer = Engine(do_nothing_update_fn)

# upper_bound = best_score * (1 + min_delta)
h = EarlyStopping(
patience=2, min_delta=0.1, min_delta_mode="rel", score_function=lambda _: next(scores), trainer=trainer
)

assert not trainer.should_terminate
h(None) # best_score = 1.0; counter == 0
assert not trainer.should_terminate
h(None) # score = 2.0; upper_bound = 1.0 * (1.1) = 1.1; 2.0 > 1.1; best_score = 2.0; counter == 0
assert not trainer.should_terminate
h(None) # score = 2.1; upper_bound = 2.0 * (1.1) = 2.2; 2.1 <= 2.2; counter == 1
assert not trainer.should_terminate
h(None) # score = 3.0; upper_bound = 2.0 * (1.1) = 2.2; 3.0 > 2.2; best_score = 3.0; counter == 0
assert not trainer.should_terminate
h(None) # score = 3.2; upper_bound = 3.0 * (1.1) = 3.3; 3.2 <= 3.3; counter == 1
assert not trainer.should_terminate
h(None) # score = 3.25; upper_bound = 3.0 * (1.1) = 3.3; 3.25 <= 3.3; counter == 2
assert trainer.should_terminate


def test_early_stopping_on_last_event_delta():
scores = iter([0.0, 0.3, 0.6])

Expand Down
Loading