Conversation

Contributor

@nuka137 nuka137 commented Oct 16, 2019

Add torch::nn::BatchNorm1d function/module support for the C++ API.
torch::nn::BatchNorm{2,3}d will be added after this PR is merged.

Related Issue: #25883

Reviewer: @yf225

I would like to discuss the items below.

  • Necessity of num_batches_tracked in BatchNormImplBase
    • num_batches_tracked is needed to calculate the momentum when the momentum argument is not given in the Python API, but in the C++ API the momentum argument has a default value (see the sketch after this list).
    • Beyond that, num_batches_tracked is only used to count BatchNorm1d::forward() calls, so I think it is no longer necessary for users.
  • The design of BatchNorm{1,2,3}dOptions
    • We already have BatchNormOptions, used by the deprecated BatchNorm module. However, it is hard to reuse it for BatchNorm{1,2,3}dOptions because the modules' arguments disagree.
    • In this PR, I introduce a BatchNormOptionsv2 template class for BatchNorm{1,2,3}dOptions, but I'm not sure whether this design is good.
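For reference, a minimal sketch of the Python-side fallback that num_batches_tracked enables, transliterated to C++ (this assumes momentum becomes a c10::optional<double>, as discussed further down; member names follow the snippets in this review):

// Sketch: cumulative-moving-average fallback when momentum is not set,
// mirroring torch.nn.modules.batchnorm._BatchNorm.forward.
double exponential_average_factor = 0.0;
if (this->is_training() && options.track_running_stats()) {
  num_batches_tracked += 1;
  if (!options.momentum().has_value()) {
    // momentum == None in Python: use 1 / num_batches_tracked
    exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
  } else {
    exponential_average_factor = options.momentum().value();
  }
}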

Contributor

@yf225 yf225 left a comment

@nuka137 My sincere apologies for the delay, and thanks so much for the investigation and the awesome work! I left some comments regarding the design choices. The scope of work for BatchNorm{1,2,3}d is large, and I really appreciate you taking it on. :D


/// A momentum multiplier for the mean and variance.
/// Changing this parameter after construction __is effective__.
TORCH_ARG(double, momentum) = 0.1;
Contributor

In the Python version I believe we allow None as value for momentum, and to support it in C++ version I believe we'll need to use c10::optional<double> as type for momentum.
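A minimal sketch of that change to the TORCH_ARG line quoted above:

/// A momentum multiplier for the mean and variance.
/// `c10::nullopt` selects the cumulative moving average, matching
/// `momentum=None` in the Python API.
TORCH_ARG(c10::optional<double>, momentum) = 0.1;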

Contributor Author

Thanks. I fixed it.

};

template <size_t D>
struct BatchNormOptionsv2 {
Contributor

I feel that we can probably name it BatchNormBaseOptions

Contributor Author

Good suggestion. I now use BatchNormBaseOptions instead of BatchNormOptionsv2.


BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
  LOG(WARNING) << "torch::nn::BatchNorm module is deprecated. "
               << "Use BatchNorm{1,2,3}d instead.";
Contributor

I think TORCH_WARN might be a better way to print the warning message, which is consistent with other parts of the C++ frontend :D
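For instance, the constructor's warning above could become (note the trailing space so the two literals do not run together):

TORCH_WARN(
    "torch::nn::BatchNorm module is deprecated. ",
    "Use BatchNorm{1,2,3}d instead.");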

Contributor Author

I agree. Changed to TORCH_WARN.

public:
using BatchNormImplBase<1, BatchNorm1dImpl>::BatchNormImplBase;

Tensor forward(const Tensor& input);
Contributor

I feel that we might be able to even move forward to the BatchNormImplBase class :D The Python version of forward seems to call F.batch_norm, and I think we can follow the same design and rename F::batch_norm1d to F::batch_norm to match the Python version even better :D
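A rough sketch of that refactor; the F::batch_norm argument order here is an assumption, since this PR was still settling the functional's signature:

template <size_t D, typename Derived>
Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
  // momentum may need unwrapping if it becomes c10::optional<double>
  return F::batch_norm(
      input, running_mean, running_var, weight, bias,
      this->is_training(), options.momentum(), options.eps());
}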

Contributor Author

Thanks. I fixed it.
I also fixed the arguments of F::batch_norm.

Tensor BatchNorm1dImpl::forward(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 2 || input.dim() == 3,
      "expected 2D or 3D input (got ", input.dim(), "D input)");
Contributor

I feel that we can add a virtual function _check_input_dim to the BatchNormImplBase class, override its implementation in BatchNorm1dImpl, and call this _check_input_dim function from forward, to match the Python version even better :D
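A sketch of that virtual-dispatch pattern (the name _check_input_dim is taken from the Python module):

// In BatchNormImplBase: subclasses supply the dimension check.
virtual void _check_input_dim(const Tensor& input) = 0;

// In BatchNorm1dImpl:
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 2 || input.dim() == 3,
      "expected 2D or 3D input (got ", input.dim(), "D input)");
}

The base-class forward would then call _check_input_dim(input) before dispatching to F::batch_norm.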

Contributor Author

OK. I added a pure virtual function _check_input_dim to the BatchNormImplBase class.

Contributor Author

nuka137 commented Oct 23, 2019

@yf225

Thanks for reviewing. I fixed all the issues.
I also found that there was no PrettyPrintBatchNorm1d test in the previous commit, so I added one in the new commit.
Could you check that as well?

Contributor

@yf225 yf225 left a comment

@nuka137 Thanks so much, I really appreciate the awesome work! I left some comments.

TORCH_ARG(double, momentum) = 0.1;
};

template <size_t D>
Contributor

I think we might not need to templatize BatchNormBaseOptions over D, because the arguments in the options do not use D :D
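So the struct could simply be the following (a sketch assembling the fields quoted in this review, with momentum as c10::optional<double> per the earlier comment):

struct BatchNormBaseOptions {
  BatchNormBaseOptions(int64_t num_features) : num_features_(num_features) {}
  TORCH_ARG(int64_t, num_features);
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(c10::optional<double>, momentum) = 0.1;
  TORCH_ARG(bool, affine) = true;
  TORCH_ARG(bool, track_running_stats) = true;
};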

Contributor Author

I agree. I removed the template parameter.

TORCH_ARG(bool, track_running_stats) = true;
};

using BatchNorm1dOptions = BatchNormBaseOptions<1>;
Contributor

and then we can just write using BatchNorm1dOptions = BatchNormBaseOptions; :D

Contributor Author

Thanks. I fixed it.

running_mean = this->register_buffer("running_mean", Tensor());
running_var = this->register_buffer("running_var", Tensor());
num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
}
Contributor

I think we can move all of the initialization logic here to reset(), and the constructor can just call reset(), which is consistent with how other C++ layers behave :D
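A sketch of that shape, assuming the member and option names used elsewhere in this review; registering weight with ones and bias with zeros here also subsumes the reset_parameters() logic shown in the next snippet:

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::reset() {
  if (options.affine()) {
    weight = this->register_parameter(
        "weight", torch::ones({options.num_features()}));
    bias = this->register_parameter(
        "bias", torch::zeros({options.num_features()}));
  }
  if (options.track_running_stats()) {
    running_mean = this->register_buffer(
        "running_mean", torch::zeros({options.num_features()}));
    running_var = this->register_buffer(
        "running_var", torch::ones({options.num_features()}));
    num_batches_tracked = this->register_buffer(
        "num_batches_tracked", torch::zeros({}, torch::kLong));
  }
}

The constructor then reduces to initializing options and calling reset().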

Contributor Author

I moved all the logic into the reset method and deleted the reset_parameters and reset_track_running_stats methods.

if (options.affine()) {
torch::nn::init::ones_(weight);
torch::nn::init::zeros_(bias);
}
Contributor

I think we can remove the reset_parameters() function and move the logic into reset() (after the initialization logic), which is consistent with how other C++ layers behave :D

Contributor Author

I moved all the logic into the reset method and deleted the reset_parameters and reset_track_running_stats methods.

void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::BatchNorm" << D << "d("
<< "num_features=" << options.num_features() << ", "
Contributor

I think we can remove the printing of num_features= here, to match the Python version even better :D
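A sketch with the label dropped, printing the value positionally as Python's repr does (e.g. "BatchNorm1d(100, eps=1e-05, ...)"); momentum is elided here since the optional would need unwrapping before streaming:

template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
  stream << std::boolalpha
         << "torch::nn::BatchNorm" << D << "d("
         << options.num_features() << ", "
         << "eps=" << options.eps() << ", "
         << "affine=" << options.affine() << ", "
         << "track_running_stats=" << options.track_running_stats() << ")";
}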

Contributor Author

I fixed it.

const Tensor& bias, bool training,
double momentum, double eps) {
if (training) {
std::vector<int64_t> size = input.sizes().vec();
Contributor

I suspect we could do:

Suggested change:
- std::vector<int64_t> size = input.sizes().vec();
+ auto size = input.sizes();

Under the hood this uses ArrayRef as the type for size :D

Contributor Author

Thanks. I followed your suggestion.

training,
momentum,
eps,
torch::cuda::cudnn_is_available());
Contributor

I think the C++ equivalent of torch.backends.cudnn.enabled is at::globalContext().userEnabledCuDNN()
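A sketch of the call site with that flag; torch::batch_norm's trailing argument is its cudnn_enabled flag, and the surrounding names follow the functional's parameters quoted above:

return torch::batch_norm(
    input, weight, bias, running_mean, running_var,
    training, momentum, eps,
    at::globalContext().userEnabledCuDNN());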

Contributor Author

I replaced torch::cuda::cudnn_is_available() with at::globalContext().userEnabledCuDNN().

#include <torch/nn/functional/normalization.h>
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/vision.h>
#include <torch/nn/functional/batchnorm.h>
Contributor

It would be awesome to put it above #include <torch/nn/functional/distance.h>, to sort them alphabetically :D

Contributor Author

OK. I followed your suggestion.

Contributor

yf225 commented Oct 24, 2019

It would be awesome to update the BatchNorm1d entry in test/cpp_api_parity/parity-tracker.md as well. Thanks so much for the great work!

Contributor Author

nuka137 commented Oct 25, 2019

@yf225

Thanks for your suggestions.
I followed all of them except for removing the template parameter from BatchNormImplBase.
I think BatchNormImplBase still needs the template parameter D to output the dimension in pretty_print.
What do you think?

Contributor Author

nuka137 commented Oct 25, 2019

@yf225

Hi, I found that this PR includes a change to the PrettyPrintHardtanh test.
I believe this was introduced accidentally in commit ada8d54, because the PrettyPrintHardtanh test is failing now.
If you don't mind, I will fix it in this PR.

@yf225 yf225 added the module: cpp (Related to C++ API) label on Oct 29, 2019
Contributor

@yf225 yf225 left a comment

@nuka137 Thanks so much for the awesome work! I took the liberty of merging BatchNormBaseOptions into BatchNormOptions, because after thinking about it more I feel that forcing F::batch_norm to use BatchNormBaseOptions might not be a good idea, aesthetically speaking. Instead I renamed features to num_features and stateful to track_running_stats in BatchNormOptions, which breaks backward compatibility but I think is better in the long term.
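For reference, a sketch of the merged BatchNormOptions with those renames applied (old names in comments; the other fields follow the snippets quoted in this review):

struct BatchNormOptions {
  BatchNormOptions(int64_t num_features) : num_features_(num_features) {}
  TORCH_ARG(int64_t, num_features);               // was: features
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(c10::optional<double>, momentum) = 0.1;
  TORCH_ARG(bool, affine) = true;
  TORCH_ARG(bool, track_running_stats) = true;    // was: stateful
};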

Contributor

@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 merged this pull request in cbc234b.

facebook-github-bot pushed a commit that referenced this pull request Nov 1, 2019
Summary:
Add torch::nn::BatchNorm{2,3}d module and functional support for the C++ API.

Related Issue: #25883 #28176

Reviewer: yf225
Pull Request resolved: #28936

Differential Revision: D18266918

Pulled By: yf225

fbshipit-source-id: f432904c72985d52ec52cb992cceb372b6ff0244
facebook-github-bot pushed a commit that referenced this pull request Nov 2, 2019
Summary:
Add torch::nn::BatchNorm{2,3}d module and functional support for the C++ API.

Related Issue: #25883 #28176

Reviewer: yf225
Pull Request resolved: #28936

Differential Revision: D18274584

Pulled By: yf225

fbshipit-source-id: 3784eee9f8947f6c7c9f1699544a3d36a1a019b7