Add a handle_buffers option for EMAHandler #2592
Conversation
@sandylaker thanks for the quick PR! As for the problem with integral params in buffers, does it mean that PyTorch should have a similar issue here: https://github.com/pytorch/pytorch/blob/080cf84bed46c6c118c37fc2fa6fbd484fd9b4cd/torch/optim/swa_utils.py#L124-L132 ?
Yeah, they should, but the default usage for the SWA model is to update the batchnorm statistics after training rather than during it, via https://github.com/pytorch/pytorch/blob/080cf84bed46c6c118c37fc2fa6fbd484fd9b4cd/torch/optim/swa_utils.py#L136 . I would also hypothesize that most code and examples set the model to .train() before transfer, so the problem is less likely to occur. Actually, I once trained a common mean-teacher recipe and checked the absolute difference between the online (student) model and the EMA model.
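For context, a minimal sketch of the standard SWA workflow being referred to, where BatchNorm statistics are recomputed once after training via `torch.optim.swa_utils.update_bn` (the `model`, `num_epochs`, `train_one_epoch`, and `train_loader` names here are placeholders, not from the thread):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

swa_model = AveragedModel(model)  # wraps `model`; buffers are copied, not averaged

for epoch in range(num_epochs):         # placeholder training loop
    train_one_epoch(model)              # hypothetical helper
    swa_model.update_parameters(model)  # averages parameters only

# Recompute BatchNorm running statistics in a single pass after training,
# instead of trying to average them during training:
update_bn(train_loader, swa_model)
```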
@vfdev-5 I did not test the integral-buffer case, but consider:

```python
>>> num_batches_tracked = torch.tensor(0, dtype=torch.int64)
>>> num_batches_tracked.copy_(num_batches_tracked * 0.9998 + 1 * 0.0002)
tensor(0)
```

Here 0 * 0.9998 + 1 * 0.0002 = 0.0002, which truncates back to 0 when written into an int64 tensor, so the counter never increases. I think the rounding error will be more severe when the online value is small. In the current implementation, the …
@sandylaker if you check the "Solutions" section of #2590 again, it looks like RicherMans already tried updating floating-point buffers and copying integral ones (see his code snippet). For this PR I think we could add an arg with 3 options (sketched below):
1) "copy" = keep the current (copy) behaviour,
2) "update_buffers" = update buffers according to RicherMans's option 1, and
3) "ema_train" = RicherMans's option 2.
What do you think?
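To make the three options concrete, here is a minimal sketch of what the EMA update step could look like under each of them (option names are taken from this discussion; the merged implementation may differ):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, momentum, handle_buffers="copy"):
    # Parameters are always EMA-updated, regardless of the option.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(1.0 - momentum).add_(p, alpha=momentum)

    if handle_buffers == "copy":
        # 1) "copy": keep the current behaviour and copy all buffers
        #    verbatim, integral ones included.
        for ema_b, b in zip(ema_model.buffers(), model.buffers()):
            ema_b.copy_(b)
    elif handle_buffers == "update_buffers":
        # 2) "update_buffers": EMA-update floating-point buffers
        #    (running_mean, running_var) but copy integral buffers such
        #    as num_batches_tracked, avoiding the rounding-to-zero issue
        #    shown above.
        for ema_b, b in zip(ema_model.buffers(), model.buffers()):
            if ema_b.dtype.is_floating_point:
                ema_b.mul_(1.0 - momentum).add_(b, alpha=momentum)
            else:
                ema_b.copy_(b)
    # 3) "ema_train": buffers are neither copied nor averaged; the EMA
    #    model is instead kept in train() mode so its BatchNorm layers
    #    accumulate their own running statistics.
```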
I actually did already check all the proposed methods during my investigations. To be more precise about my experiments: I do sound event classification for the DCASE 2022 task, where the baseline uses a mean-teacher approach that I have reimplemented using EMAHandler. The results are as follows:

[results table from the original comment]

So, as I originally stated, the momentum update of buffers such as the running mean and variance clearly works better than the simple copying mechanism. Of course, this is a mean-teacher learning case, and my results might not be representative of other cases where one might want to "freeze" the buffers. What do you guys think?
@RicherMans thanks for sharing the details! I think we could provide all 3 options so that, depending on the use case, users can pick the appropriate configuration (as suggested above: #2592 (comment)). We have to ensure that the 3rd option …
@sandylaker can we move forward with this PR and implement the 3 options described in #2592 (comment)? Thank you
@vfdev-5 Hi, sorry for the late reply; I have been busy during the week. I will implement it this weekend.
Changed the title: use_buffers option for EMAHandler → handle_buffers option for EMAHandler
vfdev-5 left a comment:
Looks great, thanks a lot @sandylaker!
@RicherMans could you please check whether this PR solves the issue you had?
I merged this PR; if we need more changes, we can do that in a follow-up PR.
Sorry @vfdev-5 for my late response. Thanks, guys, for the good work! (Even though I hope I can also commit something myself one day.)
That would be awesome! By the way, if you need some help getting started, feel free to join our Discord: https://pytorch-ignite.ai/chat
Fixes #2590
Description:
Add a handle_buffers option for EMAHandler.

cc @vfdev-5 @RicherMans
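As a rough usage sketch (assuming a `model` and a `trainer` engine already exist; the exact option names should be checked against the merged API):

```python
from ignite.engine import Events
from ignite.handlers import EMAHandler

# Update the EMA model every iteration; handle_buffers selects how the
# buffers are treated ("copy", "update_buffers", or "ema_train" in the
# naming used in this discussion).
ema_handler = EMAHandler(model, momentum=0.0002, handle_buffers="update_buffers")
ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED)

ema_model = ema_handler.ema_model  # use this model for evaluation
```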
Check list: