9 changes: 9 additions & 0 deletions docs/source/conf.py
Expand Up @@ -351,4 +351,13 @@ def run(self):
from ignite.utils import *

manual_seed(666)

def evaluator(metric, metric_name):
    def process_function(engine, batch):
        y_pred, y = batch
        return y_pred, y

    engine = Engine(process_function)
    metric.attach(engine, metric_name)
    return engine

"""
22 changes: 22 additions & 0 deletions ignite/metrics/accuracy.py
Expand Up @@ -126,6 +126,28 @@ def thresholded_output_transform(output):
device: specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

``y_pred`` and ``y`` should have the same shape.

.. testcode::

    metric_name = "accuracy"
    engine = evaluator(Accuracy(), metric_name)
Collaborator

@vfdev-5 vfdev-5 Oct 24, 2021
@Chandan-h-509 thanks for the update, but (here I agree with your previous comment) it is non-obvious what this evaluator is and what its particular API, evaluator(Accuracy(), metric_name), does. I think a tradeoff would be something like

            metric = Accuracy()
            metric.attach(evaluator, "accuracy")

            preds = torch.tensor([[1,0,0,1], ])
            target = torch.tensor([[1,0,0,0], ])
            state = evaluator.run([[preds, target], ])
            print(state.metrics["accuracy"])

Collaborator

In addition, I think it would be more helpful to show all possible options while computing accuracy:

  1. binary case
  2. multiclass case
  3. multilabel case
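The three cases the reviewer lists can be made concrete with a plain-Python sketch. These helper functions are illustrative only, not ignite's `Accuracy` (which infers the mode from input shapes and the `is_multilabel` flag):

```python
def binary_accuracy(y_pred, y, threshold=0.5):
    # y_pred: probabilities in [0, 1]; y: 0/1 ground-truth labels
    hits = sum((p >= threshold) == bool(t) for p, t in zip(y_pred, y))
    return hits / len(y)


def multiclass_accuracy(y_pred, y):
    # y_pred: per-sample class scores; y: ground-truth class indices
    argmax = [row.index(max(row)) for row in y_pred]
    return sum(a == t for a, t in zip(argmax, y)) / len(y)


def multilabel_accuracy(y_pred, y):
    # a sample counts as correct only if its whole label vector matches
    return sum(p == t for p, t in zip(y_pred, y)) / len(y)


print(binary_accuracy([0.9, 0.2, 0.4, 0.8], [1, 0, 0, 0]))      # 0.75
print(multiclass_accuracy([[0.1, 0.9], [0.8, 0.2]], [1, 1]))    # 0.5
print(multilabel_accuracy([[1, 0], [1, 1]], [[1, 0], [0, 1]]))  # 0.5
```

The key difference is what "one prediction" means: an element in the binary case, an argmax over class scores in the multiclass case, and an entire label vector in the multilabel case.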

Contributor Author

@Chandan-h-509 thanks for the update, but (here I agree with your previous comment) it is non-obvious to understand what it is evaluator with its particular API as evaluator(Accuracy(), metric_name). I think a tradeoff would be something like

            metric = Accuracy()
            metric.attach(evaluator, "accuracy")

            preds = torch.tensor([[1,0,0,1], ])
            target = torch.tensor([[1,0,0,0], ])
            state = evaluator.run([[preds, target], ])
            print(state.metrics["accuracy"])

@vfdev-5 I didn't understand the solution clearly.
Could you please elaborate on it?

Collaborator

@vfdev-5 vfdev-5 Oct 26, 2021

@Chandan-h-509 the docstring you are proposing in this PR does not look good, especially this line:

engine = evaluator(Accuracy(), metric_name)

because it is not obvious what the evaluator method is and what its API is.
I was proposing that you rework the docstring and doctest_global_setup so that the new docstring resembles https://github.com/pytorch/ignite/pull/2287/files#r735168912

Additionally, you need to add more docstrings for the other cases, as described in https://github.com/pytorch/ignite/pull/2287/files#r735169293

    preds = torch.Tensor([[1, 0, 0, 1]])
    target = torch.Tensor([[1, 0, 0, 0]])
    state = engine.run([[preds, target]])
    print(state.metrics[metric_name])

.. testoutput::

    0.75
"""

def __init__(
Expand Down