
Add SaliencyInferer #1985

@Nic-Ma


Is your feature request related to a problem? Please describe.
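A SaliencyInferer would run the model through GradCAM and return a saliency (class activation) map for the inputs, instead of returning the raw forward() output directly as a simple inferer does.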

Describe the solution you'd like

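from typing import Any, Optional, Sequence, Union

import torch
import torch.nn as nn

from monai.inferers.inferer import Inferer
from monai.visualize import GradCAM


class SaliencyInferer(Inferer):
    """
    SaliencyInferer runs the model through GradCAM and returns a saliency (class activation)
    map for the inputs instead of the raw ``forward()`` output.
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        target_layers: name(s) of the model layer(s) whose activations and gradients are used.
        class_idx: index of the class to explain; if None, the predicted (argmax) class is used.

    """

    def __init__(self, target_layers: Union[str, Sequence[str]], class_idx: Optional[int] = None) -> None:
        super().__init__()
        # GradCAM cannot run without knowing which layer(s) to hook;
        # raise instead of assert so the check survives `python -O`.
        if not target_layers:
            raise ValueError("Need to specify 'target_layers' for GradCAM.")
        self.target_layers = target_layers
        self.class_idx = class_idx

    def __call__(
        self,
        inputs: torch.Tensor,
        network: nn.Module,
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to explain. It must be an ``nn.Module`` so that GradCAM
                can register hooks on ``target_layers``; plain callables are not supported.
            args: unused, kept for compatibility with the ``Inferer`` API.
            kwargs: unused, kept for compatibility with the ``Inferer`` API.

        """
        # The CAM wrapper is built per call because the network is only known here.
        cam = GradCAM(nn_module=network, target_layers=self.target_layers)

        # Returns the post-processed saliency map (by default upsampled to the spatial size of ``inputs``).
        return cam(x=inputs, class_idx=self.class_idx)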
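To make clear what kind of map the inferer returns, the standard Grad-CAM weighting that `GradCAM` is built on can be sketched in NumPy for a single sample. This is a simplified illustration, not MONAI's implementation: the function name `grad_cam` is hypothetical, batching and upsampling are omitted, and the final min-max scaling to [0, 1] is an assumption chosen for readability (MONAI's `GradCAM` applies its own configurable post-processing).

```python
import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Hypothetical minimal Grad-CAM for one sample.

    activations, gradients: arrays of shape (C, H, W) from the target layer, where
    gradients are d(class score) / d(activations).
    Returns an (H, W) map min-max scaled to [0, 1] (illustrative choice).
    """
    # Channel weights: global-average-pool the gradients over the spatial axes -> (C,).
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum over channels -> (H, W), then ReLU keeps only positive evidence.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Min-max scale; a constant map (e.g. all zeros) is returned as zeros.
    span = cam.max() - cam.min()
    return (cam - cam.min()) / span if span > 0 else np.zeros_like(cam)


# Two channels on a 2x2 grid: channel 0 supports the class, channel 1 opposes it.
acts = np.array([[[1.0, 0.0], [0.0, 1.0]],
                 [[0.0, 2.0], [2.0, 0.0]]])
grads = np.array([[[1.0, 1.0], [1.0, 1.0]],
                  [[-1.0, -1.0], [-1.0, -1.0]]])
cam = grad_cam(acts, grads)  # → [[1, 0], [0, 1]]
```

Here the weights are `[1, -1]`, so the raw map is `acts[0] - acts[1]`; the ReLU removes the negative entries contributed by the opposing channel.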
