
2493 Fix decollate logic in integration test #2494

Merged
Nic-Ma merged 13 commits into Project-MONAI:dev from Nic-Ma:fix-integration-decollate
Jul 2, 2021

Conversation

Nic-Ma commented Jul 1, 2021

Fixes #2493.

Description

This PR adds the missing decollate logic to the integration tests and also enhances the plot API to support the new data shape.

Status

Ready

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

Nic-Ma requested a review from yiheng-wang-nv, July 1, 2021 07:28
Nic-Ma commented Jul 1, 2021

/black

Nic-Ma commented Jul 1, 2021

/integration-test

Nic-Ma requested reviews from ericspod, rijobro and wyli, July 1, 2021 07:30
Nic-Ma commented Jul 1, 2021

/black

Nic-Ma commented Jul 1, 2021

Hi @rijobro and @wyli ,

This PR also slightly changed the logic of decollate_batch:
It removes the logic that converts a scalar Tensor to a number; otherwise the output data type is uncertain, and the following post transforms, which expect Tensor input, may not work.
I didn't see a use case that needs to convert a scalar tensor to a number when decollating, and I think that conversion logic should not live in the decollate utility API.
What do you think?

Thanks.
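The intent can be illustrated with a minimal sketch in plain torch (this is not the actual decollate_batch implementation, just an assumed naive per-item split):

```python
import torch

# Minimal sketch (not the actual decollate_batch implementation): split a
# collated batch per item, keeping scalar values as 0-dim Tensors rather
# than converting them to Python numbers.
batch = {"loss": torch.tensor([0.5, 0.7])}
items = [{k: v[i] for k, v in batch.items()} for i in range(len(batch["loss"]))]
# each per-item value stays a Tensor, so post transforms that expect
# Tensor input keep working
print(type(items[0]["loss"]))  # <class 'torch.Tensor'>
```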

Nic-Ma commented Jul 1, 2021

> Hi @rijobro and @wyli ,
>
> This PR also slightly changed the logic of decollate_batch:
> Removed the logic to convert scalar Tensor to number, otherwise, the output data type is uncertain and the following post transforms may not work, which expect Tensor input.
> I didn't see a use case that needs to convert scalar tensor to number when decollating, and I think the converting logic should not be involved in the decollate utility API.
> What do you think?
>
> Thanks.

Please ignore the comment above; I reverted the change so it is non-breaking.

Thanks.

Nic-Ma commented Jul 1, 2021

Hi @rijobro @wyli ,

I found a blocking issue while developing this PR:
Currently, our decollate logic has an option to convert scalar Tensors to regular numbers. If it is True, the prediction or label of a classification task will be decollated into a number instead of a Tensor, which is not compatible with the post transforms; if it is False, many saved transform parameters become Tensors when inverting and raise errors.

So the workflows of our current dev branch can't work with classification tasks.
What would you suggest I modify in this PR to fix this issue?

Thanks.

Nic-Ma force-pushed the fix-integration-decollate branch from 1cb8c03 to 4f599ce, July 1, 2021 15:34
rijobro commented Jul 1, 2021

Hi @Nic-Ma, do we not have the possibility of leaving things as they are? As in, if part of the input is a Tensor, return it as a Tensor; if it's a float (for example), return it as a float. That would avoid the binary problem you're seeing, wouldn't it?

Nic-Ma commented Jul 1, 2021

Hi @rijobro @wyli ,

Yes, this idea can fix the binary issue I am facing, but as the input label is a long / int for the classification task, it will be decollated to a number, which is not compatible with the post transforms.
Now I am thinking of adding the logic below to the beginning of several post transforms to ensure the input data is a channel-first tensor:

```python
# ensure input data is a PyTorch Tensor
img = torch.as_tensor(img)
# ensure `channel-first` Tensor
img = img.unsqueeze(0) if img.ndim == 0 else img
```

What do you think?

Thanks.
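For example, applying that snippet to a classification label that came back as a plain number after decollating (a sketch of the assumed worst case) normalizes it to a channel-first Tensor:

```python
import torch

# a plain Python int, e.g. a classification label after decollating
img = 3
# ensure input data is a PyTorch Tensor
img = torch.as_tensor(img)                        # tensor(3), 0-dim
# ensure `channel-first` Tensor
img = img.unsqueeze(0) if img.ndim == 0 else img  # tensor([3]), shape (1,)
print(img, img.shape)
```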

rijobro commented Jul 1, 2021

I'm a little lost, I think we're talking about classification labels, but in your code snippet you used img. Have I missed something?

Are you suggesting adding an extra dimension to the classification labels so that when they get decollated they stay as a tensor?

wyli commented Jul 1, 2021

I think the root cause is that both non-tensor and tensor values are collated into torch tensors by default (https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L67-L72);
we can't tell whether a value was originally a non-tensor or a tensor at decollate time...
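This is easy to reproduce; a short sketch using PyTorch's default collate function (imported here from the private `_utils.collate` module that the link above points to):

```python
import torch
from torch.utils.data._utils.collate import default_collate

# both plain Python ints and tensors are collated into torch.Tensor,
# so the original type is lost before decollate ever sees the data
batch = [{"label": 3, "interp_order": 0}, {"label": 7, "interp_order": 1}]
collated = default_collate(batch)
print(type(collated["label"]), type(collated["interp_order"]))
```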

madil90 commented Jul 1, 2021

/build

Nic-Ma commented Jul 2, 2021

Hi @rijobro and @wyli ,

As @wyli summarized, the root cause is that we don't know the original data type when decollating.
So we can:

  1. Try to add the channel dim to the label in the classification task, as @rijobro suggested (but I can't find the right place to add this logic; we would have to add it after the model forward and loss computation, and some PyTorch losses don't support labels with a channel dim).
  2. Or add some logic to convert non-tensor data back to tensors, as I said in the previous comment (I feel this is also not easy; I can't find the right place to convert it, because we need to make sure the predictions and labels are tensors for the post transforms, metrics, visualization, saving, etc.).

Do you guys have any other ideas?

Thanks.

Nic-Ma commented Jul 2, 2021

Hi @ericspod ,

I think we are a little bit blocked here; feel free to join the discussion if you have any ideas for this decollate problem.
@wyli Is there some way to handle the pred and label of the classification task so they remain Tensors after decollating?
Thanks in advance.

wyli commented Jul 2, 2021

> Hi @ericspod ,
>
> I think we are a little bit blocked here, welcome to join the discussion if you have any idea for this decollate problem.
> @wyli Is there some way to hack the pred and label of classification task to make it a Tensor after decollating?
> Thanks in advance.

just for clarification,

  1. We want some transform parameters to remain non-tensors after decollating, e.g. we expect {"interp_order": 0} and {"interp_order": (0, 1)}.
  2. We want to decollate the labels into tensors, e.g. we expect {"label": tensor(0)} and {"class_label": tensor([0, 1])}.

So a workaround would be making the labels always at least 2D tensors? We ensure the classification labels are channel-first with spatial_dim=1, e.g. we expect {"label": tensor([[0]])} and {"class_label": tensor([[0, 1]])}.

for the pytorch loss functions, I think the APIs may assume different input dimensions anyway...e.g.
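The tensor side of the workaround can be sketched as follows (a hypothetical naive per-item split, not the MONAI API):

```python
import torch

# Hypothetical sketch of the workaround: keep classification labels at
# least 2D (channel-first, spatial_dim=1), so splitting along the batch
# dimension still yields Tensors, never 0-dim scalars or plain numbers.
batch_labels = torch.tensor([[1], [3], [1]])  # shape (3, 1)
decollated = list(batch_labels)               # naive split along dim 0
print(decollated[0], decollated[0].shape)     # tensor([1]) torch.Size([1])
```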

Nic-Ma commented Jul 2, 2021

Hi @wyli ,

Thanks for your summary. Actually, I also tested the PyTorch loss functions with a channel dim in the labels, and it failed:

```python
>>> import torch
>>> loss = torch.nn.CrossEntropyLoss()
>>>
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, 1, dtype=torch.long).random_(5)
>>> print(target)
tensor([[1],
        [3],
        [1]])
>>> output = loss(input, target)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 1120, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 2824, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: 1D target tensor expected, multi-target not supported
```

Thanks.
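For completeness, the failure above disappears if the channel dim is squeezed before the loss call, which shows the extra dim would have to be managed somewhere in the workflow (a sketch, not a proposed fix):

```python
import torch

loss = torch.nn.CrossEntropyLoss()
inp = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, 1, dtype=torch.long).random_(5)
# squeezing the channel dim restores the 1D target the loss expects
out = loss(inp, target.squeeze(1))
print(out.shape)  # torch.Size([]) -- a scalar loss
```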

wyli commented Jul 2, 2021

> > Hi @ericspod ,
> > I think we are a little bit blocked here, welcome to join the discussion if you have any idea for this decollate problem.
> > @wyli Is there some way to hack the pred and label of classification task to make it a Tensor after decollating?
> > Thanks in advance.
>
> just for clarification,
>
> 1. we want to have some transform parameters remain non-tensors after decollating, e.g. we expect {"interp_order": 0} and {"interp_order": (0, 1)}.
> 2. we want to decollate the labels into tensor, e.g. we expect {"label": tensor(0)} and {"class_label": tensor([0, 1])}
>
> so, a workaround would be making the labels to be always at least 2d tensor? we ensure the classification labels are channel-first, spatial_dim=1, e.g. we expect {"label": tensor([[0]])} and {"class_label": tensor([[0, 1]])}
>
> for the pytorch loss functions, I think the APIs may assume different input dimensions anyway...e.g.

OK, in the worst case, for case 2 we return numpy arrays for the classification label outputs; what would be the issue? I don't think we need complicated postprocessing transforms for the label vector, so I don't think it is a blocking issue.

Nic-Ma commented Jul 2, 2021

Hi @wyli ,

Do you mean changing the 0-dim tensors into numpy arrays?
We need the label to be a Tensor for post transforms like AsDiscrete(one_hot=True), and metrics like AUC, etc.

Thanks.

wyli commented Jul 2, 2021

> Hi @wyli ,
>
> Do you mean to change 0-dim tensor into numpy array?
> We need label to be Tensor for post transforms like AsDiscrete(one_hot=True), and metrics like AUC, etc.
>
> Thanks.

If a tensor is needed, we can still add a ToTensorD in the postprocessing transforms, and problem solved, correct?
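The idea can be sketched without MONAI, with a hypothetical `ensure_tensor` helper standing in for what a ToTensor-style transform does:

```python
import torch

def ensure_tensor(x):
    # hypothetical helper mirroring a ToTensor-style transform: accept
    # numbers, sequences, or tensors and always return a torch.Tensor
    return x if isinstance(x, torch.Tensor) else torch.as_tensor(x)

# a decollated item where the label came back as a plain Python int
item = {"label": 2, "pred": torch.tensor([0.1, 0.9])}
item = {k: ensure_tensor(v) for k, v in item.items()}
print({k: type(v).__name__ for k, v in item.items()})
```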

Nic-Ma commented Jul 2, 2021

Hi @wyli ,

Thanks for your suggestion.
Yes, ToTensor may be the best solution for now. Where do you think we should document this for users, so they understand why they need to add ToTensor again in the post transforms?

Thanks.

wyli commented Jul 2, 2021

> Hi @wyli ,
>
> Thanks for your suggestion.
> Yes, ToTensor may be the best solution for now, where do you think we should document something to users to make them understand why they need to add ToTensor again in post transform?
>
> Thanks.

Cool, now it's just a documentation issue. I think we need a migration guide to explain the whole decollate feature; it's a major change since 0.5.3 (I've already received questions about it).

wyli commented Jul 2, 2021

> Hi @wyli ,
>
> Thanks for your suggestion.
> Yes, ToTensor may be the best solution for now, where do you think we should document something to users to make them understand why they need to add ToTensor again in post transform?
>
> Thanks.

We could rename the "ToTensor" transform to "EnsureTensor"; maybe that sounds more elegant.

Nic-Ma commented Jul 2, 2021

Hi @wyli ,

The plan sounds good to me. Let me work on the following PRs next week.

  1. EnsureTensor() indeed sounds better, especially considering we already have the EnsureChannelFirst() transform. I think we can keep the ToTensor() transform and make this non-breaking, just like our AddChannel(), etc.
    Submitted a ticket to track it: Add EnsureTensor transform #2511.
  2. Let me prepare a "Decollate data" page and let's put it in the README for the next release.
    Submitted a ticket to track it: Add document about migrating to decollate feature (7/July) #2510.

Thanks.

Nic-Ma commented Jul 2, 2021

/black

Nic-Ma commented Jul 2, 2021

/integration-test

Nic-Ma commented Jul 2, 2021

Integration tests passed locally with a V100 GPU.

Thanks.

wyli left a comment

thanks! we'll have some follow-ups as discussed in the PR comments

Nic-Ma enabled auto-merge (squash), July 2, 2021 15:04
Nic-Ma merged commit 6814239 into Project-MONAI:dev, Jul 2, 2021
Nic-Ma deleted the fix-integration-decollate branch, July 2, 2021 23:35


Development

Successfully merging this pull request may close these issues.

Segmentation / Classification integration tests forget to use decollate

6 participants