Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ignite/engine/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable


def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: Any) -> None:
Expand All @@ -21,7 +21,7 @@ def _check_signature(fn: Callable, fn_description: str, *args: Any, **kwargs: An
)


def _to_hours_mins_secs(time_taken: Union[float, int]) -> Tuple[int, int, float]:
def _to_hours_mins_secs(time_taken: float | int) -> tuple[int, int, float]:
"""Convert seconds to hours, mins, seconds and milliseconds."""
mins, secs = divmod(time_taken, 60)
hours, mins = divmod(mins, 60)
Expand Down
40 changes: 20 additions & 20 deletions ignite/handlers/clearml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import defaultdict
from datetime import datetime
from enum import Enum
from typing import Any, Callable, DefaultDict, List, Mapping, Optional, Tuple, Type, Union
from typing import Any, Callable, DefaultDict, Mapping, Type

from torch.optim import Optimizer

Expand Down Expand Up @@ -325,16 +325,16 @@ def global_step_transform(engine, event_name):
def __init__(
self,
tag: str,
metric_names: Optional[Union[List[str], str]] = None,
output_transform: Optional[Callable] = None,
global_step_transform: Optional[Callable[[Engine, Union[str, Events]], int]] = None,
state_attributes: Optional[List[str]] = None,
metric_names: list[str] | str | None = None,
output_transform: Callable | None = None,
global_step_transform: Callable[[Engine, str | Events], int] | None = None,
state_attributes: list[str] | None = None,
):
super(OutputHandler, self).__init__(
tag, metric_names, output_transform, global_step_transform, state_attributes
)

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler OutputHandler works only with ClearMLLogger")

Expand Down Expand Up @@ -392,10 +392,10 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
"""

def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: str | None = None):
super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler OptimizerParamsHandler works only with ClearMLLogger")

Expand Down Expand Up @@ -489,7 +489,7 @@ def has_bias_in_name(n, p):
optional argument `whitelist` added.
"""

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler WeightsScalarHandler works only with ClearMLLogger")

Expand Down Expand Up @@ -579,7 +579,7 @@ class WeightsHistHandler(BaseWeightsHandler):
optional argument `whitelist` added.
"""

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler 'WeightsHistHandler' works only with ClearMLLogger")

Expand Down Expand Up @@ -675,7 +675,7 @@ def is_in_fc_layer(n, p):
optional argument `whitelist` added.
"""

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler GradsScalarHandler works only with ClearMLLogger")

Expand Down Expand Up @@ -765,7 +765,7 @@ def has_shape_2_1(n, p):
optional argument `whitelist` added.
"""

def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: Union[str, Events]) -> None:
def __call__(self, engine: Engine, logger: ClearMLLogger, event_name: str | Events) -> None:
if not isinstance(logger, ClearMLLogger):
raise RuntimeError("Handler 'GradsHistHandler' works only with ClearMLLogger")

Expand Down Expand Up @@ -828,9 +828,9 @@ class ClearMLSaver(DiskSaver):

def __init__(
self,
logger: Optional[ClearMLLogger] = None,
output_uri: Optional[str] = None,
dirname: Optional[str] = None,
logger: ClearMLLogger | None = None,
output_uri: str | None = None,
dirname: str | None = None,
*args: Any,
**kwargs: Any,
):
Expand All @@ -850,7 +850,7 @@ def __init__(
if "atomic" not in kwargs:
kwargs["atomic"] = False

self._checkpoint_slots: DefaultDict[Union[str, Tuple[str, str]], List[Any]] = defaultdict(list)
self._checkpoint_slots: DefaultDict[str | tuple[str, str], list[Any]] = defaultdict(list)

super(ClearMLSaver, self).__init__(dirname=dirname, *args, **kwargs) # type: ignore[misc]

Expand Down Expand Up @@ -885,11 +885,11 @@ class _CallbacksContext:
def __init__(
self,
callback_type: Type[Enum],
slots: List,
slots: list,
checkpoint_key: str,
filename: str,
basename: str,
metadata: Optional[Mapping] = None,
metadata: Mapping | None = None,
) -> None:
self._callback_type = callback_type
self._slots = slots
Expand Down Expand Up @@ -930,7 +930,7 @@ def post_callback(self, action: str, model_info: Any) -> Any:

return model_info

def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
def __call__(self, checkpoint: Mapping, filename: str, metadata: Mapping | None = None) -> None:
try:
from clearml.binding.frameworks import WeightsFileHandler
except ImportError:
Expand Down Expand Up @@ -970,7 +970,7 @@ def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mappin
WeightsFileHandler.remove_post_callback(post_cb_id)

@idist.one_rank_only()
def get_local_copy(self, filename: str) -> Optional[str]:
def get_local_copy(self, filename: str) -> str | None:
"""Get artifact local copy.

.. warning::
Expand Down
Loading