Implementation of a Masked Autoencoder for representation learning #8152
KumoLiu merged 9 commits into Project-MONAI:dev from
Conversation
Signed-off-by: Lucas Robinet <[email protected]>
Hi @Lucas-rbnt, thanks for the effort on this follow-up PR. @atbenmurray could you please re-review the content here?
@Lucas-rbnt @atbenmurray I shall do so
KumoLiu
left a comment
Thanks for the PR.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.
In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
As you mentioned here, I wonder whether you have verified that there is a significant difference between the two implementations. Does it have any impact on the final performance? Thanks.
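For reference, the argsort-based scheme described above can be sketched roughly as follows (a simplified illustration of the official MAE approach, not MONAI's actual code; the function name is hypothetical):

```python
import torch

def random_masking_argsort(x: torch.Tensor, mask_ratio: float = 0.75):
    """Sketch of noise-plus-double-argsort masking (official MAE style).

    x: (batch, num_tokens, dim). Returns the kept tokens, a binary mask
    in original token order (1 = masked), and the restore permutation.
    """
    B, N, D = x.shape
    len_keep = int(N * (1 - mask_ratio))

    noise = torch.rand(B, N)                         # one noise value per token
    ids_shuffle = torch.argsort(noise, dim=1)        # ascending: low noise is kept
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation (2nd argsort)

    ids_keep = ids_shuffle[:, :len_keep]
    x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))

    # build the mask in shuffled order, then un-shuffle it back
    mask = torch.ones(B, N)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, 1, ids_restore)
    return x_masked, mask, ids_restore
```

The second `argsort` recovers the inverse permutation, which the decoder later uses to put mask tokens back in their original positions.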
Co-authored-by: YunLiu <[email protected]> Signed-off-by: Lucas Robinet <[email protected]>
I think this is fine now, though the comments should be looked at and the conflict resolved; then we can trigger the blossom tests. Thanks!
Signed-off-by: Lucas Robinet <[email protected]>
/build
ericspod
left a comment
I'm good with this and look forward to an example notebook in Tutorials demonstrating its use!
This follows a previous PR (#7598).
In the previous PR, the official implementation was under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model.
In the official masked autoencoder implementation, noise is first generated and then sorted twice using torch.argsort. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices.
In our implementation, we use torch.multinomial to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder.
Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.
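The multinomial-plus-boolean-indexing idea can be sketched like this (an illustrative simplification, not the PR's exact code; the function name is hypothetical):

```python
import torch

def random_masking_multinomial(x: torch.Tensor, mask_ratio: float = 0.75):
    """Sketch of masking via torch.multinomial and boolean indexing.

    x: (batch, num_tokens, dim). Returns the visible tokens and a boolean
    mask in original token order (True = masked).
    """
    B, N, D = x.shape
    num_mask = int(N * mask_ratio)

    # draw `num_mask` distinct token indices per sample, uniformly at random
    weights = torch.ones(B, N)
    mask_ids = torch.multinomial(weights, num_mask, replacement=False)

    mask = torch.zeros(B, N, dtype=torch.bool)
    mask.scatter_(1, mask_ids, True)

    # boolean indexing keeps only visible tokens; every sample keeps the
    # same count, so the flat result reshapes cleanly back to 3D
    x_visible = x[~mask].reshape(B, N - num_mask, D)
    return x_visible, mask
```

Compared with the double-argsort scheme, this avoids explicit shuffle/restore permutations: the boolean mask itself carries the original token order, so the decoder can scatter mask tokens back with the same indexing.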
Description
Implementation of the Masked Autoencoder as described in the paper Masked Autoencoders Are Scalable Vision Learners by He et al.
Its effectiveness has already been demonstrated in the literature for medical tasks in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation.
The PR contains the architecture and associated unit tests.
Note: The output includes the prediction, which is a tensor of size ($BS$, $N_{tokens}$, $D$), and the associated mask ($BS$, $N_{tokens}$). The mask is used to apply the loss only to masked patches, but I'm not sure it's the "best" output format; what do you think?
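With that output format, the training loss can be computed on masked patches only, along these lines (a minimal sketch assuming pred and mask have the shapes described above; the helper name is illustrative):

```python
import torch

def masked_mse_loss(pred: torch.Tensor, target: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
    """MSE reconstruction loss averaged over masked patches only.

    pred, target: (BS, N_tokens, D); mask: (BS, N_tokens), 1 = masked.
    """
    loss = (pred - target) ** 2
    loss = loss.mean(dim=-1)                 # per-token loss: (BS, N_tokens)
    return (loss * mask).sum() / mask.sum()  # average over masked tokens only
```

Returning the mask alongside the prediction keeps the model itself loss-agnostic, which is one argument for this output format.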
Types of changes
- Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
- Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
- Documentation built with the make html command in the docs/ folder.