Skip to content

Commit 88fb0b1

Browse files
authored
Added callable options for iteration_log and epoch_log in TensorBoard and MLFlow (#5976)
Follow-up PR on extending `iteration_log`/`epoch_log` functionality.

### Description

> Could you please help also enable this feature for other very similar handlers? Like:
> https://github.com/Project-MONAI/MONAI/blob/dev/monai/handlers/tensorboard_handlers.py#L94
> And:
> https://github.com/Project-MONAI/MONAI/blob/dev/monai/handlers/mlflow_handler.py#L101

- #5965 (review)

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.
1 parent 2a8c8cd commit 88fb0b1

File tree

6 files changed

+229
-16
lines changed

6 files changed

+229
-16
lines changed

monai/handlers/mlflow_handler.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ class MLFlowHandler:
6060
to log data to a directory. The URI defaults to path `mlruns`.
6161
for more details: https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.set_tracking_uri.
6262
iteration_log: whether to log data to MLFlow when iteration completed, default to `True`.
63+
``iteration_log`` can also be a function, in which case it is interpreted as an event filter
64+
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
65+
The event filter function accepts the engine and the event value (iteration) as input and should return True/False.
6366
epoch_log: whether to log data to MLFlow when epoch completed, default to `True`.
67+
``epoch_log`` can also be a function, in which case it is interpreted as an event filter.
68+
See ``iteration_log`` argument for more details.
6469
epoch_logger: customized callable logger for epoch level logging with MLFlow.
6570
Must accept parameter "engine", use default logger if None.
6671
iteration_logger: customized callable logger for iteration level logging with MLFlow.
@@ -98,8 +103,8 @@ class MLFlowHandler:
98103
def __init__(
99104
self,
100105
tracking_uri: str | None = None,
101-
iteration_log: bool = True,
102-
epoch_log: bool = True,
106+
iteration_log: bool | Callable[[Engine, int], bool] = True,
107+
epoch_log: bool | Callable[[Engine, int], bool] = True,
103108
epoch_logger: Callable[[Engine], Any] | None = None,
104109
iteration_logger: Callable[[Engine], Any] | None = None,
105110
output_transform: Callable = lambda x: x[0],
@@ -159,9 +164,15 @@ def attach(self, engine: Engine) -> None:
159164
if not engine.has_event_handler(self.start, Events.STARTED):
160165
engine.add_event_handler(Events.STARTED, self.start)
161166
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
162-
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
167+
event = Events.ITERATION_COMPLETED
168+
if callable(self.iteration_log): # substitute event with new one using filter callable
169+
event = event(event_filter=self.iteration_log)
170+
engine.add_event_handler(event, self.iteration_completed)
163171
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
164-
engine.add_event_handler(Events.EPOCH_COMPLETED, self.epoch_completed)
172+
event = Events.EPOCH_COMPLETED
173+
if callable(self.epoch_log): # substitute event with new one using filter callable
174+
event = event(event_filter=self.epoch_log)
175+
engine.add_event_handler(event, self.epoch_completed)
165176
if not engine.has_event_handler(self.complete, Events.COMPLETED):
166177
engine.add_event_handler(Events.COMPLETED, self.complete)
167178
if self.close_on_complete and (not engine.has_event_handler(self.close, Events.COMPLETED)):

monai/handlers/tensorboard_handlers.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch
2020

2121
from monai.config import IgniteInfo
22-
from monai.utils import is_scalar, min_version, optional_import
22+
from monai.utils import deprecated_arg, is_scalar, min_version, optional_import
2323
from monai.visualize import plot_2d_or_3d_image
2424

2525
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
@@ -87,12 +87,14 @@ class TensorBoardStatsHandler(TensorBoardHandler):
8787
8888
"""
8989

90+
@deprecated_arg("epoch_interval", since="1.1", removed="1.3")
91+
@deprecated_arg("iteration_interval", since="1.1", removed="1.3")
9092
def __init__(
9193
self,
9294
summary_writer: SummaryWriter | SummaryWriterX | None = None,
9395
log_dir: str = "./runs",
94-
iteration_log: bool = True,
95-
epoch_log: bool = True,
96+
iteration_log: bool | Callable[[Engine, int], bool] = True,
97+
epoch_log: bool | Callable[[Engine, int], bool] = True,
9698
epoch_event_writer: Callable[[Engine, Any], Any] | None = None,
9799
epoch_interval: int = 1,
98100
iteration_event_writer: Callable[[Engine, Any], Any] | None = None,
@@ -108,13 +110,20 @@ def __init__(
108110
default to create a new TensorBoard writer.
109111
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
110112
iteration_log: whether to write data to TensorBoard when iteration completed, default to `True`.
113+
``iteration_log`` can also be a function, in which case it is interpreted as an event filter
114+
(see https://pytorch.org/ignite/generated/ignite.engine.events.Events.html for details).
115+
The event filter function accepts the engine and the event value (iteration) as input and should return True/False.
111116
epoch_log: whether to write data to TensorBoard when epoch completed, default to `True`.
117+
``epoch_log`` can also be a function, in which case it is interpreted as an event filter.
118+
See ``iteration_log`` argument for more details.
112119
epoch_event_writer: customized callable TensorBoard writer for epoch level.
113120
Must accept parameter "engine" and "summary_writer", use default event writer if None.
114121
epoch_interval: the epoch interval at which the epoch_event_writer is called. Defaults to 1.
122+
``epoch_interval`` must be 1 if ``epoch_log`` is callable.
115123
iteration_event_writer: customized callable TensorBoard writer for iteration level.
116124
Must accept parameter "engine" and "summary_writer", use default event writer if None.
117125
iteration_interval: the iteration interval at which the iteration_event_writer is called. Defaults to 1.
126+
``iteration_interval`` must be 1 if ``iteration_log`` is callable.
118127
output_transform: a callable that is used to transform the
119128
``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}.
120129
In the latter case, the output string will be formatted as key: value.
@@ -131,6 +140,12 @@ def __init__(
131140
when epoch completed.
132141
tag_name: when iteration output is a scalar, tag_name is used to plot, defaults to ``'Loss'``.
133142
"""
143+
if callable(iteration_log) and iteration_interval > 1:
144+
raise ValueError("If iteration_log is callable, then iteration_interval should be 1")
145+
146+
if callable(epoch_log) and epoch_interval > 1:
147+
raise ValueError("If epoch_log is callable, then epoch_interval should be 1")
148+
134149
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
135150
self.iteration_log = iteration_log
136151
self.epoch_log = epoch_log
@@ -152,11 +167,19 @@ def attach(self, engine: Engine) -> None:
152167
153168
"""
154169
if self.iteration_log and not engine.has_event_handler(self.iteration_completed, Events.ITERATION_COMPLETED):
155-
engine.add_event_handler(
156-
Events.ITERATION_COMPLETED(every=self.iteration_interval), self.iteration_completed
157-
)
170+
event = Events.ITERATION_COMPLETED
171+
if callable(self.iteration_log): # substitute event with new one using filter callable
172+
event = event(event_filter=self.iteration_log)
173+
elif self.iteration_interval > 1:
174+
event = event(every=self.iteration_interval)
175+
engine.add_event_handler(event, self.iteration_completed)
158176
if self.epoch_log and not engine.has_event_handler(self.epoch_completed, Events.EPOCH_COMPLETED):
159-
engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.epoch_interval), self.epoch_completed)
177+
event = Events.EPOCH_COMPLETED
178+
if callable(self.epoch_log): # substitute event with new one using filter callable
179+
event = event(event_filter=self.epoch_log)
180+
elif self.epoch_log > 1:
181+
event = event(every=self.epoch_interval)
182+
engine.add_event_handler(event, self.epoch_completed)
160183

161184
def epoch_completed(self, engine: Engine) -> None:
162185
"""

monai/utils/deprecate_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def deprecated_arg(
174174
else:
175175
# compare the numbers
176176
is_deprecated = since is not None and version_leq(since, version_val)
177-
is_removed = removed is not None and version_leq(removed, version_val)
177+
is_removed = removed is not None and version_val != f"{sys.maxsize}" and version_leq(removed, version_val)
178178

179179
def _decorator(func):
180180
argname = f"{func.__module__} {func.__qualname__}:{name}"
@@ -284,7 +284,7 @@ def deprecated_arg_default(
284284
else:
285285
# compare the numbers
286286
is_deprecated = since is not None and version_leq(since, version_val)
287-
is_replaced = replaced is not None and version_leq(replaced, version_val)
287+
is_replaced = replaced is not None and version_val != f"{sys.maxsize}" and version_leq(replaced, version_val)
288288

289289
def _decorator(func):
290290
argname = f"{func.__module__} {func.__qualname__}:{name}"

tests/test_deprecated.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def test_arg_except2_unknown(self):
234234
def afoo4(a, b=None):
235235
pass
236236

237-
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
237+
afoo4(1, b=2)
238238

239239
def test_arg_except3_unknown(self):
240240
"""
@@ -246,8 +246,8 @@ def test_arg_except3_unknown(self):
246246
def afoo4(a, b=None, **kwargs):
247247
pass
248248

249-
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2))
250-
self.assertRaises(DeprecatedError, lambda: afoo4(1, b=2, c=3))
249+
afoo4(1, b=2)
250+
afoo4(1, b=2, c=3)
251251

252252
def test_replacement_arg(self):
253253
"""

tests/test_handler_mlflow.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@
1717
import tempfile
1818
import unittest
1919
from concurrent.futures import ThreadPoolExecutor
20+
from unittest.mock import MagicMock
2021

2122
import numpy as np
2223
from ignite.engine import Engine, Events
24+
from parameterized import parameterized
2325

2426
from monai.handlers import MLFlowHandler
2527
from monai.utils import path_to_uri
2628

2729

30+
def get_event_filter(e):
31+
def event_filter(_, event):
32+
if event in e:
33+
return True
34+
return False
35+
36+
return event_filter
37+
38+
2839
def dummy_train(tracking_folder):
2940
tempdir = tempfile.mkdtemp()
3041

@@ -95,6 +106,85 @@ def _update_metric(engine):
95106
# check logging output
96107
self.assertTrue(len(glob.glob(test_path)) > 0)
97108

109+
@parameterized.expand([[True], [get_event_filter([1, 2])]])
110+
def test_metrics_track_mock(self, epoch_log):
111+
experiment_param = {"backbone": "efficientnet_b0"}
112+
with tempfile.TemporaryDirectory() as tempdir:
113+
# set up engine
114+
def _train_func(engine, batch):
115+
return [batch + 1.0]
116+
117+
engine = Engine(_train_func)
118+
119+
# set up dummy metric
120+
@engine.on(Events.EPOCH_COMPLETED)
121+
def _update_metric(engine):
122+
current_metric = engine.state.metrics.get("acc", 0.1)
123+
engine.state.metrics["acc"] = current_metric + 0.1
124+
engine.state.test = current_metric
125+
126+
# set up testing handler
127+
test_path = os.path.join(tempdir, "mlflow_test")
128+
handler = MLFlowHandler(
129+
iteration_log=False,
130+
epoch_log=epoch_log,
131+
tracking_uri=path_to_uri(test_path),
132+
state_attributes=["test"],
133+
experiment_param=experiment_param,
134+
close_on_complete=True,
135+
)
136+
handler._default_epoch_log = MagicMock()
137+
handler.attach(engine)
138+
139+
max_epochs = 4
140+
engine.run(range(3), max_epochs=max_epochs)
141+
handler.close()
142+
# check logging output
143+
if epoch_log is True:
144+
self.assertEqual(handler._default_epoch_log.call_count, max_epochs)
145+
else:
146+
self.assertEqual(handler._default_epoch_log.call_count, 2) # 2 = len([1, 2]) from event_filter
147+
148+
@parameterized.expand([[True], [get_event_filter([1, 3])]])
149+
def test_metrics_track_iters_mock(self, iteration_log):
150+
experiment_param = {"backbone": "efficientnet_b0"}
151+
with tempfile.TemporaryDirectory() as tempdir:
152+
# set up engine
153+
def _train_func(engine, batch):
154+
return [batch + 1.0]
155+
156+
engine = Engine(_train_func)
157+
158+
# set up dummy metric
159+
@engine.on(Events.EPOCH_COMPLETED)
160+
def _update_metric(engine):
161+
current_metric = engine.state.metrics.get("acc", 0.1)
162+
engine.state.metrics["acc"] = current_metric + 0.1
163+
engine.state.test = current_metric
164+
165+
# set up testing handler
166+
test_path = os.path.join(tempdir, "mlflow_test")
167+
handler = MLFlowHandler(
168+
iteration_log=iteration_log,
169+
epoch_log=False,
170+
tracking_uri=path_to_uri(test_path),
171+
state_attributes=["test"],
172+
experiment_param=experiment_param,
173+
close_on_complete=True,
174+
)
175+
handler._default_iteration_log = MagicMock()
176+
handler.attach(engine)
177+
178+
num_iters = 3
179+
max_epochs = 2
180+
engine.run(range(num_iters), max_epochs=max_epochs)
181+
handler.close()
182+
# check logging output
183+
if iteration_log is True:
184+
self.assertEqual(handler._default_iteration_log.call_count, num_iters * max_epochs)
185+
else:
186+
self.assertEqual(handler._default_iteration_log.call_count, 2) # 2 = len([1, 3]) from event_filter
187+
98188
def test_multi_thread(self):
99189
test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"]
100190
with ThreadPoolExecutor(2, "Training") as executor:

tests/test_handler_tb_stats.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,36 @@
1414
import glob
1515
import tempfile
1616
import unittest
17+
from unittest.mock import MagicMock
1718

1819
from ignite.engine import Engine, Events
20+
from parameterized import parameterized
1921

2022
from monai.handlers import TensorBoardStatsHandler
2123
from monai.utils import optional_import
2224

2325
SummaryWriter, has_tb = optional_import("torch.utils.tensorboard", name="SummaryWriter")
2426

2527

28+
def get_event_filter(e):
29+
def event_filter(_, event):
30+
if event in e:
31+
return True
32+
return False
33+
34+
return event_filter
35+
36+
2637
@unittest.skipUnless(has_tb, "Requires SummaryWriter installation")
2738
class TestHandlerTBStats(unittest.TestCase):
39+
def test_args_validation(self):
40+
with self.assertWarns(FutureWarning):
41+
with self.assertRaisesRegex(ValueError, expected_regex="iteration_interval should be 1"):
42+
TensorBoardStatsHandler(log_dir=".", iteration_log=get_event_filter([1, 2]), iteration_interval=2)
43+
44+
with self.assertRaisesRegex(ValueError, expected_regex="epoch_interval should be 1"):
45+
TensorBoardStatsHandler(log_dir=".", epoch_log=get_event_filter([1, 2]), epoch_interval=2)
46+
2847
def test_metrics_print(self):
2948
with tempfile.TemporaryDirectory() as tempdir:
3049
# set up engine
@@ -47,6 +66,35 @@ def _update_metric(engine):
4766
# check logging output
4867
self.assertTrue(len(glob.glob(tempdir)) > 0)
4968

69+
@parameterized.expand([[True], [get_event_filter([1, 2])]])
70+
def test_metrics_print_mock(self, epoch_log):
71+
with tempfile.TemporaryDirectory() as tempdir:
72+
# set up engine
73+
def _train_func(engine, batch):
74+
return [batch + 1.0]
75+
76+
engine = Engine(_train_func)
77+
78+
# set up dummy metric
79+
@engine.on(Events.EPOCH_COMPLETED)
80+
def _update_metric(engine):
81+
current_metric = engine.state.metrics.get("acc", 0.1)
82+
engine.state.metrics["acc"] = current_metric + 0.1
83+
84+
# set up testing handler
85+
stats_handler = TensorBoardStatsHandler(log_dir=tempdir, iteration_log=False, epoch_log=epoch_log)
86+
stats_handler._default_epoch_writer = MagicMock()
87+
stats_handler.attach(engine)
88+
89+
max_epochs = 4
90+
engine.run(range(3), max_epochs=max_epochs)
91+
stats_handler.close()
92+
# check logging output
93+
if epoch_log is True:
94+
self.assertEqual(stats_handler._default_epoch_writer.call_count, max_epochs)
95+
else:
96+
self.assertEqual(stats_handler._default_epoch_writer.call_count, 2) # 2 = len([1, 2]) from event_filter
97+
5098
def test_metrics_writer(self):
5199
with tempfile.TemporaryDirectory() as tempdir:
52100
# set up engine
@@ -78,6 +126,47 @@ def _update_metric(engine):
78126
# check logging output
79127
self.assertTrue(len(glob.glob(tempdir)) > 0)
80128

129+
@parameterized.expand([[True], [get_event_filter([1, 3])]])
130+
def test_metrics_writer_mock(self, iteration_log):
131+
with tempfile.TemporaryDirectory() as tempdir:
132+
# set up engine
133+
def _train_func(engine, batch):
134+
return [batch + 1.0]
135+
136+
engine = Engine(_train_func)
137+
138+
# set up dummy metric
139+
@engine.on(Events.EPOCH_COMPLETED)
140+
def _update_metric(engine):
141+
current_metric = engine.state.metrics.get("acc", 0.1)
142+
engine.state.metrics["acc"] = current_metric + 0.1
143+
engine.state.test = current_metric
144+
145+
# set up testing handler
146+
writer = SummaryWriter(log_dir=tempdir)
147+
stats_handler = TensorBoardStatsHandler(
148+
summary_writer=writer,
149+
iteration_log=iteration_log,
150+
epoch_log=False,
151+
output_transform=lambda x: {"loss": x[0] * 2.0},
152+
global_epoch_transform=lambda x: x * 3.0,
153+
state_attributes=["test"],
154+
)
155+
stats_handler._default_iteration_writer = MagicMock()
156+
stats_handler.attach(engine)
157+
158+
num_iters = 3
159+
max_epochs = 2
160+
engine.run(range(num_iters), max_epochs=max_epochs)
161+
writer.close()
162+
163+
if iteration_log is True:
164+
self.assertEqual(stats_handler._default_iteration_writer.call_count, num_iters * max_epochs)
165+
else:
166+
self.assertEqual(
167+
stats_handler._default_iteration_writer.call_count, 2
168+
) # 2 = len([1, 3]) from event_filter
169+
81170

82171
if __name__ == "__main__":
83172
unittest.main()

0 commit comments

Comments
 (0)