C++ API: torch::nn::BatchNorm1d #28176
Conversation
yf225
left a comment
@nuka137 My sincere apologies for the delay and thanks so much for the investigation and the awesome work! I left some comments regarding the design choices. The scope of work is large for BatchNorm{1,2,3}d and thanks so much for working on it. :D
/// A momentum multiplier for the mean and variance.
/// Changing this parameter after construction __is effective__.
TORCH_ARG(double, momentum) = 0.1;
In the Python version I believe we allow None as the value for momentum, and to support it in the C++ version I believe we'll need to use c10::optional<double> as the type for momentum.
Thanks. I fixed it.
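For illustration, a minimal sketch of what the option looks like once momentum can be unset, with c10::nullopt standing in for Python's None (the exact doc comment wording is an assumption):

/// A momentum multiplier for the mean and variance.
/// Set to c10::nullopt to compute a cumulative moving average instead.
TORCH_ARG(c10::optional<double>, momentum) = 0.1;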
};

template <size_t D>
struct BatchNormOptionsv2 {
I feel that we can probably name it BatchNormBaseOptions
Good suggestion. I used BatchNormBaseOptions instead of BatchNormOptionsv2.
BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
  LOG(WARNING) << "torch::nn::BatchNorm module is deprecated."
               << "Use BatchNorm{1,2,3}d instead.";
I think TORCH_WARN might be a better way to print the warning message, which is consistent with other parts of the C++ frontend :D
I agree with that. I changed it to TORCH_WARN.
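A minimal sketch of the constructor with TORCH_WARN in place of LOG(WARNING); the warning text is taken from the snippet above and the rest of the constructor body is omitted:

BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
  TORCH_WARN(
      "torch::nn::BatchNorm module is deprecated. ",
      "Use BatchNorm{1,2,3}d instead.");
  // ... remaining initialization as before ...
}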
 public:
  using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;

  Tensor forward(const Tensor& input);
I feel that we might be able to even move forward to the BatchNormImplBase class :D The Python version of forward seems to call F.batch_norm, and I think we can follow the same design and rename F::batch_norm1d to F::batch_norm to match the Python version even better :D
Thanks. I fixed it.
I also fixed the arguments of F::batch_norm.
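A rough sketch, under the assumption that the shared forward lives in the base class and dispatches to F::batch_norm; the exact argument list is in the PR diff, and the order below is illustrative only:

template <size_t D, typename Derived>
Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
  _check_input_dim(input);  // per-module dimension check, discussed below
  return F::batch_norm(
      input, running_mean, running_var, weight, bias,
      this->is_training() || !options.track_running_stats(),
      options.momentum(), options.eps());
}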
Tensor BatchNorm1dImpl::forward(const Tensor& input) {
  TORCH_CHECK(
      input.dim() != 2 && input.dim() != 3,
      "expected 2D or 3D input (got %dD input)", input.dim());
I feel that we can add a virtual function _check_input_dim to the BatchNormImplBase class, override its implementation in BatchNorm1dImpl, and call this _check_input_dim function from forward, to match the Python version even better :D
OK. I added a pure virtual function _check_input_dim to the BatchNormImplBase class.
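A minimal sketch of that shape; note the condition is phrased positively here, since TORCH_CHECK asserts the condition that must hold, and the exact message layout is an assumption:

// Declared in BatchNormImplBase:
//   virtual void _check_input_dim(const Tensor& input) = 0;
// Overridden per module, e.g. for the 1d case:
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 2 || input.dim() == 3,
      "expected 2D or 3D input (got ", input.dim(), "D input)");
}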
Thanks for reviewing. I fixed all the issues.
yf225
left a comment
@nuka137 Thanks so much, and I really appreciate the awesome work! I left some comments.
  TORCH_ARG(double, momentum) = 0.1;
};

template <size_t D>
I think we might not need to templatize BatchNormBaseOptions over D, because the arguments in the options are not using D :D
I agree with that. I removed the template parameter.
  TORCH_ARG(bool, track_running_stats) = true;
};

using BatchNorm1dOptions = BatchNormBaseOptions<1>;
and then we can just write using BatchNorm1dOptions = BatchNormBaseOptions; :D
Thanks. I fixed it.
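For illustration, the de-templatized options plus the aliases might look roughly like this; the constructor and any default values not shown in the snippets above are assumptions, and the 2d/3d aliases would only follow in a later PR:

struct BatchNormBaseOptions {
  /* implicit */ BatchNormBaseOptions(int64_t num_features) : num_features_(num_features) {}
  TORCH_ARG(int64_t, num_features);
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(c10::optional<double>, momentum) = 0.1;
  TORCH_ARG(bool, affine) = true;
  TORCH_ARG(bool, track_running_stats) = true;
};

using BatchNorm1dOptions = BatchNormBaseOptions;
using BatchNorm2dOptions = BatchNormBaseOptions;
using BatchNorm3dOptions = BatchNormBaseOptions;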
  running_mean = this->register_buffer("running_mean", Tensor());
  running_var = this->register_buffer("running_var", Tensor());
  num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
}
I think we can move all of the initialization logic here to reset(), and the constructor can just call reset(), which is consistent with how other C++ layers behave :D
Moved all logic to the reset method, and deleted the reset_parameters and reset_track_running_stats methods.
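A rough sketch of what a consolidated reset() could look like, combining the buffer registration from the snippet above with the parameter initialization discussed below; tensor shapes, dtypes, and the requires_grad handling are assumptions:

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset() {
  if (options.affine()) {
    weight = this->register_parameter("weight", torch::empty({options.num_features()}));
    bias = this->register_parameter("bias", torch::empty({options.num_features()}));
  } else {
    weight = this->register_parameter("weight", Tensor(), /*requires_grad=*/false);
    bias = this->register_parameter("bias", Tensor(), /*requires_grad=*/false);
  }
  if (options.track_running_stats()) {
    running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
    running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
    num_batches_tracked = this->register_buffer(
        "num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
  } else {
    running_mean = this->register_buffer("running_mean", Tensor());
    running_var = this->register_buffer("running_var", Tensor());
    num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
  }
  if (options.affine()) {
    torch::nn::init::ones_(weight);
    torch::nn::init::zeros_(bias);
  }
}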
if (options.affine()) {
  torch::nn::init::ones_(weight);
  torch::nn::init::zeros_(bias);
}
I think we can remove the reset_parameters() function and move the logic into reset() (after the initialization logic), which is consistent with how other C++ layers behave :D
Moved all logic to the reset method, and deleted the reset_parameters and reset_track_running_stats methods.
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::BatchNorm" << D << "d("
         << "num_features=" << options.num_features() << ", "
I think we can remove the printing of num_features= here, to match the Python version even better :D
I fixed it.
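For context, the Python repr prints the feature count positionally (e.g. BatchNorm1d(100, eps=1e-05, ...)), so the adjusted pretty_print drops only the label. A sketch with momentum printing omitted (it is optional-typed) and remaining fields abbreviated:

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::BatchNorm" << D << "d("
         << options.num_features() << ", "
         << "eps=" << options.eps() << ", "
         << "affine=" << options.affine() << ", "
         << "track_running_stats=" << options.track_running_stats() << ")";
}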
    const Tensor& bias, bool training,
    double momentum, double eps) {
  if (training) {
    std::vector<int64_t> size = input.sizes().vec();
I suspect we could do
- std::vector<int64_t> size = input.sizes().vec();
+ auto size = input.sizes();
under the hood it uses ArrayRef as the type for size :D
Thanks. I followed your suggestion.
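For context, a small sketch of the difference: sizes() returns a non-owning c10::IntArrayRef view over the tensor's shape, while .vec() copies it into a std::vector<int64_t>:

auto size = input.sizes();       // IntArrayRef view, no allocation
int64_t num_channels = size[1];  // indexable like a vector
// size.vec() would materialize a std::vector<int64_t> copy, which is usually unnecessary.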
      training,
      momentum,
      eps,
      torch::cuda::cudnn_is_available());
I think the C++ equivalent of torch.backends.cudnn.enabled is at::globalContext().userEnabledCuDNN()
I replaced torch::cuda::cudnn_is_available() with at::globalContext().userEnabledCuDNN().
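Sketch of the adjusted tail of the call, assuming the surrounding torch::batch_norm arguments stay as in the snippet above:

  return torch::batch_norm(
      input, weight, bias, running_mean, running_var,
      training, momentum, eps,
      at::globalContext().userEnabledCuDNN());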
#include <torch/nn/functional/normalization.h>
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/vision.h>
#include <torch/nn/functional/batchnorm.h>
It would be awesome to put it above #include <torch/nn/functional/distance.h>, to sort them alphabetically :D
OK. I followed your suggestion.
It would be awesome to update the …

Thanks for your instructions.
@nuka137 Thanks so much for the awesome work! I took the liberty of merging BatchNormBaseOptions into BatchNormOptions because, after thinking about it more, I feel that forcing F::batch_norm to use BatchNormBaseOptions might not be a good idea, aesthetically speaking. Instead I renamed features to num_features and stateful to track_running_stats in BatchNormOptions, which breaks backward compatibility, but I think it is better for the long term.
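For illustration, the backward-compatibility break boils down to a field rename in BatchNormOptions, roughly as below (default values omitted since they are not stated here; sketch only):

- TORCH_ARG(int64_t, features);
+ TORCH_ARG(int64_t, num_features);

- TORCH_ARG(bool, stateful);
+ TORCH_ARG(bool, track_running_stats);

Existing callers that used .features(...) or .stateful(...) would need to switch to .num_features(...) and .track_running_stats(...).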
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Add torch::nn::BatchNorm1d function/module support for the C++ API.
torch::nn::BatchNorm{2,3}d will be added after this PR is merged.
Related Issue: #25883
Reviewer: @yf225
I would like to discuss the items below.
- num_batches_tracked in BatchNormImplBase: num_batches_tracked is needed to calculate momentum when the momentum argument is not given in the Python API (see the sketch after this list). But in the C++ API the momentum argument has a default value, so num_batches_tracked is only used to count BatchNorm1d::forward() calls. I think it is no longer necessary for users.
- BatchNorm{1,2,3}dOptions: BatchNormOptions is used for the deprecated BatchNorm module. However, it is hard to reuse it for BatchNorm{1,2,3}dOptions because the modules' arguments disagree. I added a BatchNormOptionsv2 template class for the BatchNorm{1,2,3}dOptions, but I'm not sure whether this design is good or not.
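For reference, a rough C++ rendering of what the Python module does with num_batches_tracked when momentum is None; this mirrors the Python logic only and is not code from this PR:

double exponential_average_factor = 0.0;
if (this->is_training() && options.track_running_stats()) {
  if (num_batches_tracked.defined()) {
    num_batches_tracked += 1;
    if (!options.momentum().has_value()) {
      // momentum == None in Python: fall back to a cumulative moving average
      exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
    } else {
      exponential_average_factor = options.momentum().value();
    }
  }
}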