[C++ API] InstanceNorm{1,2,3}d #28790
Conversation
yf225
left a comment
@divyanshsinghvi Thanks so much for the awesome work! I believe we will need to rebase this on top of #28176 after it's merged, because #28176 fixes the `track_running_stats` vs. `stateful` issue that you mentioned, and also creates `BatchNormImplBase`, which we'll need to subclass `InstanceNormImpl` from.
```cpp
namespace torch {
namespace nn {

/// Base class for all (dimesnsion-specialized) instanceNorm modules
```
nit:
```diff
- /// Base class for all (dimesnsion-specialized) instanceNorm modules
+ /// Base class for all (dimension-specialized) InstanceNorm modules
```
```cpp
Tensor forward(const Tensor& input);
```
```cpp
/// The optons with which this `Module` was constructed.
```
nit:
```diff
- /// The optons with which this `Module` was constructed.
+ /// The options with which this `Module` was constructed.
```
For this PR we shouldn't need to override …
@divyanshsinghvi I just rebased this PR on top of upstream master, which contains the changes in #28176. I think we can let …
I think it should be sufficient to provide an empty override for the …
For tests, I think we don't need to be comprehensive and can just check for common input / output. Examples: pytorch/test/cpp/api/functional.cpp, lines 1269–1294 (at a8b63ca), and pytorch/test/cpp/api/modules.cpp, lines 1136–1147 (at a8b63ca).
Sorry, I might be wrong in my understanding here, but since we are using …
@divyanshsinghvi Thanks a lot for the investigation! Yes, I think you are right and we don't need to override …
yf225
left a comment
@divyanshsinghvi Thanks so much for the awesome work! This is highly complex, and I appreciate your help on it. I left some comments that we can work through together.
```cpp
/// Base class for all (dimension-specialized) instanceNorm modules
template <size_t D, typename Derived, typename BatchNormDerived>
class TORCH_API InstanceNormImpl : public torch::nn::BatchNormImplBase<D, BatchNormDerived> {
```
I suspect that we can do

```cpp
template <size_t D, typename Derived>
class TORCH_API InstanceNormImpl : public torch::nn::BatchNormImplBase<D, Derived> {
```

```cpp
  InstanceNormOptions options;
};
```
```cpp
class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl> {
```
and then here we can write

```cpp
class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {
```

```cpp
class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl> {
 public:
  using InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl>::InstanceNormImpl;
 private:
```
Is it correct that this will prevent users from subclassing InstanceNorm1dImpl and overriding _check_input_dim in their custom subclass? If so, I think we should make it protected instead.
Yes, right.
```cpp
void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
  TORCH_CHECK(
      input.dim() == 2,
```
I believe this should be

```diff
-      input.dim() == 2,
+      input.dim() != 2,
```

because TORCH_CHECK prints the message when the condition is false :D
```cpp
      "(1, N * C, ...) from (N, C,...) and this makes",
      "variances 0.");
  TORCH_CHECK(
      input.dim() != 3,
```
I believe this should be

```diff
-      input.dim() != 3,
+      input.dim() == 3,
```

because TORCH_CHECK prints the message when the condition is false :D
```cpp
      "variances 0.");
  TORCH_CHECK(
      input.dim() != 3,
      "expected 3D input (got", input.dim(), "D input)");
```
nit:
```diff
-      "expected 3D input (got", input.dim(), "D input)");
+      "expected 3D input (got ", input.dim(), "D input)");
```
```cpp
template <size_t D, typename Derived, typename BatchNormDerived>
InstanceNormImpl<D, Derived, BatchNormDerived>::InstanceNormImpl(const InstanceNormOptions& options_)
    : BatchNormImplBase<D, BatchNormDerived>(BatchNormOptions(options_.num_features())),
```
I believe we need to write

```diff
-    : BatchNormImplBase<D, BatchNormDerived>(BatchNormOptions(options_.num_features())),
+    : BatchNormImplBase<D, BatchNormDerived>(
+          BatchNormOptions(options_.num_features())
+              .eps(options_.eps())
+              .momentum(options_.momentum())
+              .affine(options_.affine())
+              .track_running_stats(options_.track_running_stats())),
```

(note the use of the constructor parameter `options_`, since the `options` member is not yet initialized at this point in the initializer list), otherwise some of the InstanceNormOptions parameters are not passed to the BatchNormImplBase constructor :D
Force-pushed 8dda49e to caa5ecc.
facebook-github-bot
left a comment
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
yf225
left a comment
@divyanshsinghvi Thanks a lot for the awesome work!
Hi @yf225,
I have a few doubts related to the implementation:
- `track_running_stats` is not defined; instead `stateful` is defined. (See InstanceNorm{1,2,3}d #25883.)