Skip to content

Conversation

@nuka137
Copy link
Contributor

@nuka137 nuka137 commented Oct 30, 2019

Add torch::nn::BatchNorm{2,3}d module and functional support for the C++ API.

Related Issue: #25883 #28176

Reviewer: @yf225

@yf225 yf225 added the module: cpp Related to C++ API label Nov 1, 2019
Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nuka137 Thanks so much for the awesome work and thorough tests! I have one very minor comment.

"expected 5D input (got ", input.dim(), "D input)");
}

template class BatchNormImplBase<3, BatchNorm3dImpl>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From code organization point of view, it would be awesome to have all template specializations next to each other for easier discovery:

template class BatchNormImplBase<1, BatchNorm1dImpl>;
template class BatchNormImplBase<2, BatchNorm2dImpl>;
template class BatchNormImplBase<3, BatchNorm3dImpl>;

Thanks so much for your help!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Fixed it.

@nuka137
Copy link
Contributor Author

nuka137 commented Nov 1, 2019

@yf225

Thanks for your review. I fixed the source code you commented.
Could you check it again?

Copy link
Contributor

@yf225 yf225 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nuka137 Thanks so much for the awesome work!

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Copy link
Contributor

ezyang commented Nov 1, 2019

TEsts are failing:


Nov 01 17:32:00 [ RUN      ] ModulesTest.BatchNorm2dStateful
Nov 01 17:32:00 /var/lib/jenkins/workspace/test/cpp/api/modules.cpp:1224: Failure
Nov 01 17:32:00 Expected equality of these values:
Nov 01 17:32:00   bn->num_batches_tracked.dim()
Nov 01 17:32:00     Which is: 0
Nov 01 17:32:00   1

@yf225
Copy link
Contributor

yf225 commented Nov 1, 2019

@nuka137 Sorry it was my fault that causes this PR to be reverted - I will re-open this PR and make fixes directly to this PR, and then makes sure it passes CI before merging again.

@yf225 yf225 reopened this Nov 1, 2019
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@yf225 merged this pull request in a68c1e1.

@jerryzh168
Copy link
Contributor

jerryzh168 commented Nov 2, 2019

Can we split aten function to aten::batch_norm_1d, aten::batch_norm_2d and aten::batch_norm_3d and move the dimension check inside the aten function? This will simplify the other works that needs to match batchnorm module or op. cc @ZolotukhinM

@yf225
Copy link
Contributor

yf225 commented Nov 5, 2019

@jerryzh168 Yes I think that would be beneficial - the C++ API implementation here is just strictly following Python API implementation, and we can change it once the Python implementation is changed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: cpp Related to C++ API

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants