
Deep supervision loss wrapper class #5338

Merged
wyli merged 8 commits into Project-MONAI:dev from myron:ds_loss
Oct 17, 2022
Conversation

@myron
Collaborator

@myron myron commented Oct 15, 2022

This adds a DeepSupervisionLoss wrapper class around the main loss function to accept a list of tensors returned from a deeply supervised network. The final loss is computed as the sum of weighted losses for each deep supervision level (accounting for potential differences in shape between the targets and the deep-supervision outputs).

The wrapper class is designed to work with an arbitrary existing loss, e.g.

```
loss = DiceCELoss(to_onehot_y=True, softmax=True)
ds_loss = DeepSupervisionLoss(loss)
```

Whereas the existing loss accepts the input as a single Tensor, ds_loss accepts a list of Tensors (one for each output of a deeply supervised network). If a single Tensor is provided instead, ds_loss behaves exactly the same as the underlying loss.

I added unit tests too.
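To make the weighting scheme concrete, here is a minimal pure-Python sketch (not MONAI's actual implementation; the `2**-i` weight decay and the toy `mse` loss are assumptions for illustration only):

```python
def mse(pred, target):
    """Toy element-wise mean-squared error over equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def deep_supervision_loss(preds, target, base_loss=mse):
    """Accept a single prediction or a list of predictions (one per
    deep-supervision level) and return one scalar loss."""
    if not isinstance(preds[0], list):
        # Plain single-tensor input: behave exactly like the base loss.
        return base_loss(preds, target)
    # Weight each level's loss, decaying with depth (assumed 2**-i scheme).
    weights = [2.0 ** -i for i in range(len(preds))]
    # In the real wrapper each target would be resized (nearest
    # interpolation) to match the level's output shape; here all levels
    # share the target's shape for simplicity.
    return sum(w * base_loss(p, target) for w, p in zip(weights, preds))
```

With two levels that each incur a base loss of 1.0, the combined loss is `1.0 * 1.0 + 0.5 * 1.0 = 1.5`, while a plain single-tensor input reduces to the underlying loss unchanged.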

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@myron myron added this to the Auto3D Seg framework [P0 v1.0] milestone Oct 17, 2022
Contributor

@wyli wyli left a comment


Could you please help add an antialiasing option for the downsampling? (#3178) F.interpolate has the option for 2D; the 3D case could be done with the recent gaussian/median smoothing:

"MedianSmooth",
"GaussianSmooth",

@myron
Collaborator Author

myron commented Oct 17, 2022

> could you please help add an antialiasing option for the downsampling? (#3178) F.interpolate has the option for 2D as well. the 3d could be done with the recent gaussian/median smoothing
>
> "MedianSmooth",
> "GaussianSmooth",

Good question, but antialiasing is applicable only with "linear" interpolation, and here we use "nearest" interpolation of integer targets, so we can't use antialiasing (such as gaussian smoothing). At least in my use cases, the target is an image of integer class labels. I could add antialiasing support (via F.interpolate) for some future hypothetical use case of soft targets, but we wouldn't be able to test it in a real application yet (and wouldn't know the impact/accuracy of such antialiasing). Do you think we need support for a hypothetical use case?
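A toy 1-D example of the point about integer targets (illustrative only, not MONAI code):

```python
# Nearest-neighbour downsampling keeps integer class labels valid, while
# an averaging (antialiasing-style) filter can produce values that are
# not class ids at all.
labels = [0, 2, 2, 1, 1, 3]            # 1-D label "image", classes 0..3

# Downsample by 2 with nearest neighbour: pick every second element.
nearest = labels[::2]                  # still valid class labels

# Downsample by 2 by averaging each pair (what a smoothing filter does).
averaged = [(a + b) / 2 for a, b in zip(labels[::2], labels[1::2])]
# The (2, 1) pair averages to 1.5, which is not a valid class id.
```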

Another problem with F.interpolate antialiasing is that it would be much slower, and we need to run the loss function frequently and fast. Median smoothing would be slower still. For those cases, a better strategy would be to pre-compute targets at multiple scales for all images (before training) and use them all as targets. But that would require another loss function, or extending this one.
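The pre-computation strategy could look roughly like this (a hypothetical helper on 1-D lists, not an existing MONAI API):

```python
def downsample_nearest(labels, factor):
    """1-D nearest-neighbour downsample by an integer factor."""
    return labels[::factor]

def precompute_targets(labels, num_levels):
    """Build one target per deep-supervision level before training,
    halving the resolution at each level, so the training loop can
    skip repeated interpolation inside the loss."""
    return [downsample_nearest(labels, 2 ** i) for i in range(num_levels)]
```

Each level's pre-computed target would then be paired with the matching deep-supervision output inside the loss.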

Perhaps we can start using this class as-is for now, since we need it, and come back to extending it when we have a use case.

@wyli
Contributor

wyli commented Oct 17, 2022

/build

@wyli wyli enabled auto-merge (squash) October 17, 2022 16:31
@wyli wyli merged commit 982131c into Project-MONAI:dev Oct 17, 2022
wyli pushed a commit that referenced this pull request Oct 19, 2022
A small follow-up fix to #5338 to ensure ds_loss returns a scalar, not an array.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: myron <[email protected]>
