Conversation
Signed-off-by: kate-sann5100 <[email protected]>
Signed-off-by: kate-sann5100 <[email protected]>
Signed-off-by: kate-sann5100 <[email protected]>
|
is it possible to unify this network and the existing unet? otherwise we should highlight the key differences between them, so that the users wouldn't get confused... could you help us on this @ericspod? |
|
The way our basic UNet is structured is around template methods called by the constructor when putting together what is a tower of layers. Each layer consists of a down sampling block, and skip connect, and the upsampling block. The path of data thus flows through the downsampling block, down to the next layer, when back up from that layer through the upsampling block. Think of squeezing the U structure of UNet into a tower. The important thing is the class represents the abstract structure and relies on methods to provide the blocks, for example _get_down_layer defines the block for one layer of the down path. The way to adapt other UNet classes is to subclass our basic UNet and provide overrides for these methods, then in the constructor call the inherited constructor to create the structure. If the arguments passed to them aren't sufficient for what you're doing you can set some other state in the before calling the inherited constructor for the methods to pick up on. If other outputs are needed, for example DynUNet collecting outputs throughout the upsampling path, then some additional structure would need to be added to collect these outputs, I've been meaning to have a go at adapting that class but it's still being worked on. With your RegUNet you have a lot of that already factored out, the difference is that you're explicitly handling the forward pass through the U structure, keeping track of the skip connections manually. I think this class can be adapted relatively easily by overriding the methods from the UNet to do what you have currently in the |
|
What I would suggest architecturally is to wrap the decode layers in a simple class which captures the output of the forward pass as a local variable. Once a forward pass through all the layers has been done a loop can go through these and collect the outputs which in RegUNet you sum together, in SynUNet these would be passed through the superheads. I think this will work and be compatible with Torchscript. |
|
@ericspod |
wyli
left a comment
There was a problem hiding this comment.
thanks! we've agreed that it'll take more iterations to unify the different net APIs. (this PR supports torchscript with a unit test)

Signed-off-by: kate-sann5100 [email protected]
Fixes # 1651.
Description
Implements RegUNet which will serve as the parent class of the LocalNet and GlobalNet
Status
Ready/Work in progress/Hold
Types of changes
./runtests.sh --codeformat --coverage../runtests.sh --quick.make htmlcommand in thedocs/folder.