Add SaliencyInferer #1985
Closed
Description
Is your feature request related to a problem? Please describe.
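There is no inferer that returns an explanation of a model's prediction. The existing "normal" inference method only runs the model's forward() directly. SaliencyInferer would instead run the network through GradCAM and return a class activation (saliency) map for the inputs.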
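To illustrate the difference in return values without depending on MONAI or PyTorch, here is a small framework-free sketch. `simple_infer`, `saliency_infer`, and `linear_net` are hypothetical names. The saliency method is gradient × input with finite-difference gradients, a simpler technique used here in place of GradCAM. Both functions take the same `(inputs, network)` pair. The first returns the network's scores, while the second returns one relevance value per input element.

```python
from typing import Callable, List

Vector = List[float]
Network = Callable[[Vector], Vector]


def simple_infer(inputs: Vector, network: Network) -> Vector:
    # Plain inference: run the network forward and return its raw scores.
    return network(inputs)


def saliency_infer(inputs: Vector, network: Network, class_idx: int, eps: float = 1e-6) -> Vector:
    # Hypothetical saliency inference (gradient x input, not GradCAM).
    # d score[class_idx] / d inputs[i] is estimated by central differences.
    grads = []
    for i in range(len(inputs)):
        up = list(inputs)
        up[i] += eps
        down = list(inputs)
        down[i] -= eps
        grads.append((network(up)[class_idx] - network(down)[class_idx]) / (2 * eps))
    # Each input element's relevance is its gradient times its value.
    return [g * x for g, x in zip(grads, inputs)]


# Toy bias-free linear "network": score_c = sum_i W[c][i] * x[i].
W = [[1.0, -2.0, 0.5], [0.0, 3.0, -1.0]]


def linear_net(x: Vector) -> Vector:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


x = [2.0, 1.0, 4.0]
logits = simple_infer(x, linear_net)
saliency = saliency_infer(x, linear_net, class_idx=1)
print(logits)
print(saliency)
```

For a bias-free linear model, the gradient × input values for a class sum to that class's score. This gives a quick sanity check on the sketch.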
Describe the solution you'd like
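from typing import Any, Callable

import torch

from monai.inferers.inferer import Inferer
from monai.visualize import GradCAM


class SaliencyInferer(Inferer):
    """
    SaliencyInferer runs ``network`` through GradCAM and returns a class activation
    (saliency) map for ``inputs`` instead of the raw output of ``forward()``.
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        target_layers: name(s) of the layer(s) in ``network`` whose feature maps GradCAM uses.
        class_idx: index of the class to visualize; if None, GradCAM uses the argmax of the logits.
    """

    def __init__(self, target_layers=None, class_idx=None) -> None:
        super().__init__()
        # Raise instead of assert so the check still runs under ``python -O``.
        if target_layers is None:
            raise ValueError("Need to specify 'target_layers' for GradCAM.")
        self.target_layers = target_layers
        self.class_idx = class_idx

    def __call__(
        self,
        inputs: torch.Tensor,
        network: Callable[..., torch.Tensor],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to explain. It must be a ``torch.nn.Module``, because GradCAM
                locates ``target_layers`` by module name, so a plain lambda will not work here.
            args: accepted for compatibility with the ``Inferer`` API; currently not forwarded.
            kwargs: accepted for compatibility with the ``Inferer`` API; currently not forwarded.
        """
        # Build the CAM object on each call so that whichever network is passed in gets explained.
        cam = GradCAM(nn_module=network, target_layers=self.target_layers)
        return cam(x=inputs, class_idx=self.class_idx)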
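For reference, the map that GradCAM produces can be worked out by hand on toy numbers. This is a minimal dependency-free sketch of the standard Grad-CAM formula, not MONAI's implementation. Each channel weight alpha_k is the spatial mean of dy_c/dA^k, and the map is ReLU(sum_k alpha_k * A^k). The activation and gradient values below are made up. The sketch omits any upsampling to input size or normalization that a library implementation may apply.

```python
from typing import List

Map2D = List[List[float]]


def grad_cam(activations: List[Map2D], gradients: List[Map2D]) -> Map2D:
    """Grad-CAM map for one sample from target-layer activations A^k and the
    gradients dy_c/dA^k, both shaped [channels][height][width]."""
    height, width = len(activations[0]), len(activations[0][0])
    # alpha_k: global-average-pool the gradients of channel k over space.
    alphas = [sum(sum(row) for row in g) / (height * width) for g in gradients]
    # Weighted sum of the activation channels.
    cam = [[0.0] * width for _ in range(height)]
    for alpha, a in zip(alphas, activations):
        for i in range(height):
            for j in range(width):
                cam[i][j] += alpha * a[i][j]
    # ReLU keeps only the regions that push the class score up.
    return [[max(v, 0.0) for v in row] for row in cam]


# Two 2x2 channels (made-up values).
A = [[[1, 2], [3, 4]], [[4, 0], [0, 1]]]
G = [[[1, 1], [1, 1]], [[-2, -2], [-2, -2]]]  # alpha = [1.0, -2.0]
print(grad_cam(A, G))  # [[0.0, 2.0], [3.0, 2.0]]
```

The second channel has a negative weight, so it suppresses the top-left position (1 - 2 * 4 = -7), which the ReLU then clips to 0.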