
[WIP] 561 dataset/partial metrics #565

Closed
mibaumgartner wants to merge 6 commits into Project-MONAI:dev from mibaumgartner:561-dataset-metrics

Conversation

@mibaumgartner
Collaborator

Fixes #561, #497.

Description

This pull request provides a class-based API for metric evaluation and adds functionality so that some metrics can be evaluated on previously stored intermediate values.

Status

Work in progress. The current state is an early draft of the targeted functionality (so some docstrings are imperfect and tests are missing). It would be great to get feedback on the following points:

  • General structure and design of the changed hierarchy inside the metrics package
  • Design of the exemplary ROCAUC class API (even though the full advantage of partial metrics cannot be seen here 🙈 )
    • It is not always possible to call the functional method inside the evaluate function, because most functions will store an intermediate representation of the data (e.g. the Dice metric would store tp, fp, ... which cannot be passed to the functional interface). How should this be handled to reduce duplicated code?
  • Which metrics to include here (I would propose ROCAUC and Dice)
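To make the proposal concrete, here is a minimal pure-Python sketch of what such a class-based, cumulative metric API could look like. All names here (`CumulativeROCAUC`, `add_sample`) are illustrative, not MONAI's actual interface, and plain lists stand in for tensors:

```python
class CumulativeROCAUC:
    """Illustrative sketch of the proposed class-based API: samples are
    added batch by batch, and the metric is computed once at the end
    over everything stored."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._scores, self._labels = [], []

    def add_sample(self, y_pred, y):
        # store intermediate values for later evaluation
        self._scores.extend(y_pred)
        self._labels.extend(y)

    def __call__(self):
        pos = [s for s, l in zip(self._scores, self._labels) if l == 1]
        neg = [s for s, l in zip(self._scores, self._labels) if l == 0]
        # AUC as the probability that a random positive outranks a
        # random negative (ties count as 0.5)
        wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
                   for p in pos for n in neg)
        return wins / (len(pos) * len(neg))


metric = CumulativeROCAUC()
metric.add_sample([0.1, 0.9], [0, 1])  # batch 1
metric.add_sample([0.8, 0.3], [1, 0])  # batch 2
auc = metric()  # evaluated over both stored batches
```

The key property is that `__call__` sees all samples ever added, which is exactly what the per-batch functional interface cannot offer.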

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or new feature that would cause existing functionality to change)
  • New tests added to cover the changes
  • Docstrings/Documentation updated

@wyli
Contributor

wyli commented Jun 16, 2020

thanks @mibaumgartner, this looks good. My understanding is that the PartialMetric interface is very useful for aggregating metrics from multiple batches/subjects/patients. Perhaps the class name could be CumulativeMetrics or AggregateMetrics?

I agree that we include ROCAUC and Dice. But the class name Dice is used to indicate single-batch output; I think we need another class name for this cumulative Dice...

we could leave the intermediate caching for future PRs; my experience is that it'll reduce code readability a lot. We'll consider multiple types of metrics sharing the same intermediate cache only if we identify serious performance bottlenecks in metrics computation. We may run into these cases, for example, when computing different types of scores that are all based on the same connected component analysis outcome; in that case we don't want to compute the CCA multiple times.
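A toy illustration of this shared-cache idea (hypothetical names, plain Python standing in for the real tensor code): the expensive pass over the data runs exactly once, and several scores are derived from the cached counts:

```python
class ConfusionCache:
    """Toy sketch of metrics sharing one intermediate cache: the costly
    pass (here tp/fp/fn counting; in the CCA example, the connected
    component analysis) happens once, and multiple scores reuse it."""

    def __init__(self):
        self.tp = self.fp = self.fn = 0

    def add_pair(self, pred, target):
        # the expensive pass over the data happens exactly once here
        for p, t in zip(pred, target):
            self.tp += int(p == 1 and t == 1)
            self.fp += int(p == 1 and t == 0)
            self.fn += int(p == 0 and t == 1)

    # two different scores derived from the same cached counts:
    def dice(self):
        return 2 * self.tp / (2 * self.tp + self.fp + self.fn)

    def jaccard(self):
        return self.tp / (self.tp + self.fp + self.fn)


cache = ConfusionCache()
cache.add_pair([1, 1, 0, 0], [1, 0, 1, 0])  # tp=1, fp=1, fn=1
```

The readability cost wyli mentions shows up as soon as the cached representation leaks into the metric signatures, which is why deferring this to a later PR seems reasonable.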

@mibaumgartner
Collaborator Author

@wyli thank you for your quick feedback. CumulativeMetrics is a much better name 👍
I'll add the changes tomorrow :)

@Nic-Ma
Contributor

Nic-Ma commented Jun 17, 2020

The overall idea is great.
As @myron already updated the MeanDice metric with a class API, could you please update your PR to unify everything?

Thanks.

@mibaumgartner
Collaborator Author

The general functionality should be complete now. The points I changed:

  • update to master
  • CumulativeMetric now extends Metric by the ability to add samples. The final computation is still just a call.
  • Renamed DiceMetric to Dice
  • extracted a subfunction from compute_dice to avoid code duplication

A few open points:

  • Was the renaming from DiceMetric to Dice okay?
  • Should cumulative metrics get individual samples or batches? (the current implementation assumes batches, and the method should potentially be renamed to add_batch)
  • How should the device of the intermediate representation be handled for cumulative metrics? Should we introduce a parameter to decide this in the init?
    (e.g. device=None does not touch the device, while device='cpu' automatically pushes the intermediate values to the CPU to avoid GPU memory leaks)
  • Probably more important for future PRs: What should the metrics return? (float, tensor, numpy)
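The device question above could look like the following minimal sketch (hypothetical class name; assumes torch is available, and only the storage side is shown):

```python
import torch


class CumulativeMetricBase:
    """Hypothetical sketch of the proposed device option: device=None
    leaves stored tensors where they are, while device='cpu' moves
    each batch off the GPU on arrival to avoid accumulating GPU
    memory over many iterations."""

    def __init__(self, device=None):
        self.device = device
        self._batches = []

    def add_sample(self, batch: torch.Tensor):
        batch = batch.detach()
        if self.device is not None:
            batch = batch.to(self.device)
        self._batches.append(batch)

    def stacked(self) -> torch.Tensor:
        # concatenate everything stored so far for the final computation
        return torch.cat(self._batches)


m = CumulativeMetricBase(device="cpu")
m.add_sample(torch.ones(2))
m.add_sample(torch.zeros(3))
```

With device=None the intermediate values keep the device of the inputs, matching wyli's suggested default below.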

I'll update/add the docstrings and unit tests tomorrow once we've finalized the design :)

@wyli
Contributor

wyli commented Jun 19, 2020

Thanks! I think the renaming looks good, and yes, we need a device argument for the cumulative metrics (defaulting to the same device as the input samples). They should return tensors for now; we'll probably need another PR to make ROCAUC a tensor implementation.

@wyli
Contributor

wyli commented Jun 19, 2020

for the CumulativeMetric I'm also thinking about the use case where, for example, during each training/inference iteration we compute the Dice score for each item in the batch, and after k iterations/epochs we aggregate the scores across the batches by computing the mean/median/95th percentile...
Do we implement this case using the proposed API, or is it out of the scope of this PR?

@mibaumgartner
Collaborator Author

mibaumgartner commented Jun 19, 2020

At first I thought this would probably be out of scope for this PR, but after thinking about it a bit, there might be a fairly simple solution that would fit in here quite well.
The aggregation of batch-specific results is fairly abstract, so it could probably be shared by many metrics. I think the best solution would be to create a wrapper by subclassing CumulativeMetric which can take any other metric and aggregate its results over multiple batches. It could look something like this:

class AggregatedMetric(CumulativeMetric):
    def __init__(self, per_batch_metric, aggregation_mode, device=None):
        ...

    def __call__(self):
        # return the aggregated metrics
        ...

    def add_sample(self, *args, **kwargs):
        # compute per_batch_metric and save its result for later aggregation
        ...

Edit: Not sure if this is the most user-friendly approach though 🙈
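A runnable toy version of this wrapper idea (hypothetical names, the CumulativeMetric base elided, and a plain callable standing in for a MONAI metric):

```python
from statistics import mean, median


class AggregatedMetric:
    """Toy sketch of the wrapper: wrap any per-batch metric callable
    and aggregate its per-batch results on demand."""

    def __init__(self, per_batch_metric, aggregation_mode="mean"):
        self.per_batch_metric = per_batch_metric
        self.aggregation_mode = aggregation_mode
        self._results = []

    def add_sample(self, y_pred, y):
        # compute the wrapped metric now; keep only its scalar result
        self._results.append(self.per_batch_metric(y_pred, y))

    def __call__(self):
        agg = {"mean": mean, "median": median}[self.aggregation_mode]
        return agg(self._results)


# usage with a toy per-batch "accuracy" metric
def accuracy(y_pred, y):
    return sum(int(p == t) for p, t in zip(y_pred, y)) / len(y)

m = AggregatedMetric(accuracy, aggregation_mode="mean")
m.add_sample([1, 0, 1, 1], [1, 0, 0, 1])  # per-batch accuracy 0.75
m.add_sample([0, 0], [1, 0])              # per-batch accuracy 0.5
```

Because only the per-batch scalar is stored, this wrapper sidesteps the intermediate-representation problem at the cost of not being able to recompute the metric over the pooled raw data.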

@wyli
Contributor

wyli commented Jun 19, 2020

after looking at your code I think another slightly different option is to isolate the intermediate caching interface:

from abc import ABC


class ResultBuffer(ABC):
    def update(self, *args):
        # update the buffer with a new sample;
        # this could support a multi-thread/process context,
        # e.g. inference on multiple samples in parallel
        ...

    def get_samples(self, *args):
        # read out the buffer or part of the buffer;
        # one implementation would be:
        #     return torch.cat(self.y_pred_stored), torch.cat(self.y_stored)
        # or we could get the last 10 samples if needed:
        #     return torch.cat(self.y_pred_stored[-10:]), torch.cat(self.y_stored[-10:])
        ...

    def reset(self):
        # reset the buffer
        ...

then the Dice implementations would be:

class CumulativeDice(ResultBuffer, Metric):
    ...


class AggregateDice(ResultBuffer, Metric):
    ...

then in the near future, we could have fast aggregation Dice, something like:

# start a loader for pairs of ground truth and prediction
pairwise_dataloader = torch.utils.data.DataLoader(num_workers=5, ...)

agg_dice = AggregateDice()
with multiprocessing.Pool() as p:  # this won't work in practice, just to show the idea
    assert isinstance(agg_dice, ResultBuffer)
    p.map(agg_dice.update, [(g, t) for g, t in pairwise_dataloader])
agg_dice()  # get the final aggregated result
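A variant of this idea that does run today: unlike multiprocessing.Pool workers, threads share memory, so a lock-protected buffer lets update be mapped over a thread pool. All names are hypothetical and plain lists stand in for tensors:

```python
from concurrent.futures import ThreadPoolExecutor
from threading import Lock


class AggregateDice:
    """Toy thread-safe variant of the sketch above: update() can be
    called concurrently because the shared buffer is lock-protected."""

    def __init__(self):
        self._lock = Lock()
        self._scores = []

    def update(self, pair):
        pred, target = pair
        # per-pair Dice on binary lists: 2*tp / (|pred| + |target|)
        tp = sum(int(p == 1 and t == 1) for p, t in zip(pred, target))
        dice = 2 * tp / (sum(pred) + sum(target))
        with self._lock:
            self._scores.append(dice)

    def __call__(self):
        # mean of the buffered per-pair scores
        return sum(self._scores) / len(self._scores)


pairs = [([1, 1, 0], [1, 0, 0]), ([0, 1, 1], [0, 1, 1])]
agg_dice = AggregateDice()
with ThreadPoolExecutor(max_workers=2) as ex:
    list(ex.map(agg_dice.update, pairs))
```

The mean is order-independent, so concurrent appends are harmless here; order-sensitive aggregations would need the buffer keyed by sample id.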

@wyli
Contributor

wyli commented Jun 11, 2021

thanks for the discussion here, this will be closed in favour of #2314

Development

Successfully merging this pull request may close this issue: Additional Metric Interface to compute quantities over dataset