
1651 implement RegUNet #1658

Merged
wyli merged 6 commits into Project-MONAI:master from kate-sann5100:1651-regunet
Mar 9, 2021

Conversation

@kate-sann5100 (Collaborator) commented Feb 26, 2021

Signed-off-by: kate-sann5100 [email protected]

Fixes #1651.

Description

Implements RegUNet, which will serve as the parent class of LocalNet and GlobalNet.

Status

Ready/Work in progress/Hold

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh --codeformat --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@wyli (Contributor) commented Mar 1, 2021

Is it possible to unify this network and the existing UNet? Otherwise we should highlight the key differences between them, so that users won't get confused. Could you help us with this, @ericspod?

@ericspod (Member) commented Mar 4, 2021

Our basic UNet is structured around template methods called by the constructor when putting together what is effectively a tower of layers. Each layer consists of a downsampling block, a skip connection, and an upsampling block. Data thus flows through the downsampling block, down to the next layer, then back up from that layer through the upsampling block; think of squeezing the U structure of a UNet into a tower. The important thing is that the class represents the abstract structure and relies on methods to provide the blocks; for example, _get_down_layer defines the block for one layer of the down path.
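A minimal plain-Python sketch of that template-method structure. The class and method names mirror MONAI's UNet hooks (_get_down_layer etc.) but everything else is hypothetical: real blocks are torch.nn modules, whereas simple callables are used here so the shape of the pattern stands alone.

```python
class TowerUNet:
    """Sketch of a UNet built as a tower via template methods."""

    def __init__(self, depth):
        # The constructor assembles the whole network by calling the
        # template methods, recursing from the outermost layer to the bottom.
        self.depth = depth
        self.model = self._build(level=0)

    def _build(self, level):
        if level == self.depth:
            return self._get_bottom_layer(level)
        down = self._get_down_layer(level)
        sub = self._build(level + 1)   # recurse into the inner layers
        up = self._get_up_layer(level)

        def layer(x, down=down, sub=sub, up=up):
            d = down(x)
            # Skip connection: the down output is combined (summed here,
            # for simplicity) with the subpath output before the up block.
            return up(d + sub(d))

        return layer

    # Template methods -- subclasses override these to supply real blocks.
    def _get_down_layer(self, level):
        return lambda x: x * 2

    def _get_bottom_layer(self, level):
        return lambda x: x + 1

    def _get_up_layer(self, level):
        return lambda x: x - 1

    def __call__(self, x):
        return self.model(x)
```

The point is that forward execution needs no custom logic: calling the assembled `self.model` walks the whole U structure, because the skip connections were wired in at construction time.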

The way to adapt other UNet classes is to subclass our basic UNet and provide overrides for these methods, then call the inherited constructor to create the structure. If the arguments passed to the methods aren't sufficient for what you're doing, you can set some other state before calling the inherited constructor for the methods to pick up on. If other outputs are needed, for example DynUNet collecting outputs throughout the upsampling path, then some additional structure would need to be added to collect them. I've been meaning to have a go at adapting that class, but it's still being worked on.

With your RegUNet you have a lot of that already factored out; the difference is that you're explicitly handling the forward pass through the U structure, keeping track of the skip connections manually. I think this class can be adapted relatively easily by overriding the UNet methods to do what you currently have in the build_* methods, which should let you avoid a custom forward method entirely.
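For contrast, a hedged plain-Python sketch of the explicit style described here: build_*-style loops generate the blocks, and a forward method tracks the skip connections by hand. Names and blocks are illustrative, not the actual RegUNet API.

```python
class ExplicitUNet:
    """Sketch of a UNet with loop-built blocks and a manual forward pass."""

    def __init__(self, depth):
        # build_*-style: blocks generated with plain for-loops, not recursion.
        self.encode = [self._build_down_block(i) for i in range(depth)]
        self.bottom = self._build_bottom_block()
        self.decode = [self._build_up_block(i) for i in range(depth)]

    def _build_down_block(self, level):
        return lambda x: x * 2

    def _build_bottom_block(self):
        return lambda x: x + 1

    def _build_up_block(self, level):
        return lambda x: x - 1

    def forward(self, x):
        skips = []
        for block in self.encode:
            x = block(x)
            skips.append(x)          # track skip connections manually
        x = self.bottom(x)
        for block, skip in zip(self.decode, reversed(skips)):
            x = block(x + skip)      # combine with the stored skip
        return x
```

With the same toy blocks this computes the same function as the recursive tower; the trade-off is purely between where the wiring lives (constructor vs. forward).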

@kate-sann5100 (Collaborator, Author) commented Mar 4, 2021

I have drawn the architectures of RegUNet, DynUNet, UNet, and BasicUNet in the following diagram.
[Diagram: comparison of the UNet architectures]
As shown:

  1. DynUNet and RegUNet are a bit more complicated than UNet and BasicUNet because they need to store decoded features at different levels.
  2. DynUNet, UNet, and BasicUNet generate blocks recursively, while RegUNet generates blocks with for loops. The recursive approach allows a simpler forward method, while the for-loop implementation is more straightforward to understand. I don't have a strong preference between the two.
  3. I suggest we first decide which block-generation method to adopt and which uniform methods to define.

@wyli mentioned this pull request Mar 5, 2021
@ericspod (Member) commented Mar 5, 2021

What I would suggest architecturally is to wrap the decode layers in a simple class which captures the output of the forward pass as a local variable. Once a forward pass through all the layers has been done, a loop can go through these and collect the outputs, which in RegUNet you sum together; in DynUNet these would be passed through the supervision heads. I think this will work and be compatible with TorchScript.
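A small sketch of that wrapper idea, assuming plain callables in place of torch modules. `OutputCapture` is a hypothetical name, not MONAI API; the point is only that each wrapper remembers its last forward result so a later loop can collect all intermediate decode outputs.

```python
class OutputCapture:
    """Wraps a block and records the result of its last forward call."""

    def __init__(self, block):
        self.block = block
        self.output = None

    def __call__(self, x):
        self.output = self.block(x)  # remember the forward-pass result
        return self.output


# Usage sketch: run the decode path, then gather the stored outputs.
decode = [OutputCapture(lambda x, k=k: x + k) for k in range(3)]
x = 0
for layer in decode:
    x = layer(x)
# RegUNet-style collection: sum the captured intermediate outputs.
# (DynUNet-style would pass each one through a supervision head instead.)
total = sum(layer.output for layer in decode)
```

Because the captured value is an ordinary instance attribute assigned in `__call__`, this pattern keeps the forward pass itself linear, which is the property that makes it friendly to TorchScript-style tracing.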

@kate-sann5100 (Collaborator, Author) commented

@ericspod
That sounds great. Do you want to write the DynUNet adaptation, or I'm happy to write it?

@wyli (Contributor) left a review comment


Thanks! We've agreed that it will take more iterations to unify the different net APIs. (This PR supports TorchScript, with a unit test.)

@wyli enabled auto-merge (squash) March 9, 2021 14:52
@wyli merged commit 78ec66f into Project-MONAI:master Mar 9, 2021