[WIP] 561 dataset/partial metrics #565
mibaumgartner wants to merge 6 commits into Project-MONAI:dev
Conversation
thanks @mibaumgartner, this looks good. My understanding is that the PartialMetric interface is very useful for aggregating metrics from multiple batches/subjects/patients, though perhaps the class name could be reconsidered. I agree that we should include ROCAUC and Dice, but we could leave the intermediate caching for future PRs; my experience is that it reduces code readability a lot. We'll consider multiple types of metrics sharing the same intermediate cache only if we identify serious performance bottlenecks in metric computation. We may run into such cases, for example, when computing different types of scores that are all based on the same connected component analysis outcome; in that case we don't want to compute the CCA multiple times.
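The cache-sharing idea can be sketched as follows, with confusion counts standing in for the expensive shared intermediate (e.g. a connected component analysis). `SharedStatsCache` and its methods are hypothetical names, not MONAI API:

```python
import numpy as np

class SharedStatsCache:
    """Compute the expensive intermediate once, derive several scores from it."""

    def __init__(self, y_pred, y):
        # confusion counts stand in for the shared intermediate result
        self.tp = int(np.logical_and(y_pred, y).sum())
        self.fp = int(np.logical_and(y_pred, ~y).sum())
        self.fn = int(np.logical_and(~y_pred, y).sum())

    def dice(self):
        return 2 * self.tp / (2 * self.tp + self.fp + self.fn)

    def jaccard(self):
        return self.tp / (self.tp + self.fp + self.fn)

y_pred = np.array([1, 1, 0, 0], dtype=bool)
y = np.array([1, 0, 1, 0], dtype=bool)
cache = SharedStatsCache(y_pred, y)  # counts computed once
dice, jaccard = cache.dice(), cache.jaccard()  # two scores, one cache
```

Both scores reuse the same counts, so the expensive step runs exactly once.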
@wyli thank you for your quick feedback.
The overall idea is great. Thanks.
The general functionality should be complete now. The points I changed:
A few open points:
I'll update/add the docstrings and unit tests tomorrow once we've finalized the design :)
Thanks! I think the renaming looks good, and yes, we need a device argument for the cumulative metrics (defaulting to the same device as the input samples). They should return tensors for now; we probably need another PR to make ROCAUC a tensor implementation.
At first I thought this would probably be out of scope for this PR, but after thinking about it a little, there might be a fairly simple solution which would fit pretty well in here:

```python
class AggregatedMetric(CumulativeMetric):
    def __init__(self, per_batch_metric, aggregation_mode, device):
        ...

    def __call__(self):
        # return aggregated metrics
        ...

    def add_sample(self, ...):
        # compute per_batch_metric and save it for later aggregation
        ...
```

Edit: Not sure if this is the most user friendly approach though 🙈
after looking at your code I think another slightly different option is to isolate the intermediate caching interface:

```python
class ResultBuffer(ABC):
    def update(self, ...):
        # update the buffer with a new sample
        # this could support multi-thread/process contexts,
        # e.g. inference on multiple samples in parallel
        ...

    def get_samples(self, ...):
        # read out the buffer or part of the buffer
        # one implementation would be:
        # return torch.cat(self.y_pred_stored), torch.cat(self.y_stored)
        # or we could get the last 10 samples if needed:
        # return torch.cat(self.y_pred_stored[-10:]), torch.cat(self.y_stored[-10:])
        ...

    def reset(self):
        # reset the buffer
        ...
```

then the Dice implementations would be:

```python
class CumulativeDice(ResultBuffer, Metric): ...
class AggregateDice(ResultBuffer, Metric): ...
```

then in the near future, we could have fast aggregation Dice, something like:

```python
# start a loader for pairs of ground truth and prediction
pairwise_dataloader = torch.DataLoader(num_workers=5, ...)
agg_dice = AggregateDice()
with multiprocess.pool as p:  # this won't work in practice, just want to show the idea
    assert isinstance(agg_dice, ResultBuffer)
    p.map(agg_dice.update, [(g, t) for g, t in pairwise_dataloader])
agg_dice()  # get the final aggregated result
```
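A runnable variant of this parallel-aggregation idea: compute the per-pair Dice in parallel workers and aggregate in the parent, sidestepping the shared-buffer problem by returning scores instead of mutating shared state. A thread pool is used for simplicity, and all names are illustrative, not MONAI API:

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def pair_dice(pair):
    """Dice score for one (prediction, ground truth) pair of boolean masks."""
    y_pred, y = pair
    intersection = np.logical_and(y_pred, y).sum()
    return 2 * intersection / (y_pred.sum() + y.sum())

# stand-in for the pairwise dataloader
pairs = [
    (np.array([1, 1, 0], dtype=bool), np.array([1, 0, 0], dtype=bool)),
    (np.array([0, 1, 1], dtype=bool), np.array([0, 1, 1], dtype=bool)),
]

with ThreadPoolExecutor(max_workers=2) as pool:
    scores = list(pool.map(pair_dice, pairs))  # per-pair Dice in parallel
mean_dice = float(np.mean(scores))             # aggregate in the parent
```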
thanks for the discussion here, this will be closed in favour of #2314
Fixes #561, #497.
Description
This pull request provides a class-based API for metric evaluation and adds functionality to some metrics so they can be evaluated from previously stored intermediate values.
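As a sketch of the "stored intermediate values" idea, a cumulative Dice could accumulate per-sample confusion counts and reduce them once at the end. All names here are hypothetical, not the PR's actual code:

```python
import numpy as np

class CumulativeDice:
    """Accumulate per-sample tp/fp/fn counts; reduce once at the end."""

    def __init__(self):
        self.reset()

    def add_sample(self, y_pred, y):
        # store only the confusion counts, not the raw predictions
        y_pred, y = np.asarray(y_pred, dtype=bool), np.asarray(y, dtype=bool)
        self.tp += int(np.logical_and(y_pred, y).sum())
        self.fp += int(np.logical_and(y_pred, ~y).sum())
        self.fn += int(np.logical_and(~y_pred, y).sum())

    def evaluate(self):
        # global Dice over everything added so far
        return 2 * self.tp / (2 * self.tp + self.fp + self.fn)

    def reset(self):
        self.tp = self.fp = self.fn = 0

metric = CumulativeDice()
metric.add_sample([1, 0], [1, 1])  # adds tp=1, fn=1
metric.add_sample([1, 1], [0, 1])  # adds tp=1, fp=1
score = metric.evaluate()
```

Note this computes a single global Dice over all accumulated counts, which differs from averaging per-sample Dice scores.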
Status
Work in progress. The current status is intended as an early draft of the targeted functionality (so some docstrings are not perfect and tests are missing). It would be great to get feedback on the following points:

the `evaluate` function: most implementations will store an intermediate representation of the data (e.g. the Dice metric would store tp, fp, .. which cannot be passed to the functional interface). How should this be handled to reduce duplicate code?

Types of changes