
Conversation

@divyanshsinghvi commented Oct 28, 2019

Hi @yf225,

I have a few doubts related to implementation:

  1. What tests do I have to write?
  2. What does _load_state_from_dict do?
  3. Do I need to override the reset() function? I cannot see its utility.
  4. Could InstanceNormOptions be replaced with BatchNormOptions? However, I find
     that track_running_stats is not defined there; instead, stateful is defined.

InstanceNorm{1,2,3}d #25883

@ssnl changed the title from "Instance norm #25883" to "[cpp api] Instance norm #25883" on Oct 28, 2019

@ssnl commented Oct 28, 2019

_load_state_from_dict is used to resolve a backward-compatibility issue when loading old Python instance norm state dicts, because the state dict format of instance norm changed at one point. I believe it is not relevant for the C++ API.

@divyanshsinghvi changed the title from "[cpp api] Instance norm #25883" to "[cpp api] Instance norm (#25883)" on Oct 28, 2019
@yf225 left a comment:

@divyanshsinghvi Thanks so much for the awesome work! I believe we will need to rebase this on top of #28176 after it's merged, because #28176 fixes the track_running_stats vs. stateful issue that you mentioned, and also creates BatchNormImplBase, which we'll need to subclass InstanceNormImpl from.

namespace torch {
namespace nn {

/// Base class for all (dimesnsion-specialized) instanceNorm modules

nit:

Suggested change
/// Base class for all (dimesnsion-specialized) instanceNorm modules
/// Base class for all (dimension-specialized) InstanceNorm modules


Tensor forward(const Tensor& input);

/// The optons with which this `Module` was constructed.

nit:

Suggested change
/// The optons with which this `Module` was constructed.
/// The options with which this `Module` was constructed.


@yf225 commented Oct 29, 2019

For this PR we shouldn't need to override reset - we will get more clarity about it once we are able to rebase on top of #28176 :D


@yf225 commented Oct 30, 2019

@divyanshsinghvi I just rebased this PR on top of upstream master, which contains the changes in #28176. I think we can let InstanceNormImpl subclass the new BatchNormImplBase class, which mirrors the class hierarchy in Python. I also think we do need to define a separate InstanceNormOptions struct, because the InstanceNorm constructor's affine and track_running_stats parameters take different default values than the BatchNorm constructor's.
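A minimal sketch of what such a separate options struct could look like, assuming the C++ frontend's TORCH_ARG macro and mirroring the Python defaults (affine and track_running_stats default to False for InstanceNorm, versus True for BatchNorm); this is an illustration, not necessarily the exact struct the PR ends up with:

struct TORCH_API InstanceNormOptions {
  /* implicit */ InstanceNormOptions(int64_t num_features)
      : num_features_(num_features) {}

  // Same field set as BatchNormOptions, but with InstanceNorm's defaults
  // for affine and track_running_stats.
  TORCH_ARG(int64_t, num_features);
  TORCH_ARG(double, eps) = 1e-5;
  TORCH_ARG(double, momentum) = 0.1;
  TORCH_ARG(bool, affine) = false;               // BatchNorm defaults to true
  TORCH_ARG(bool, track_running_stats) = false;  // BatchNorm defaults to true
};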


@yf225 commented Oct 30, 2019

I think it should be sufficient to provide an empty override for the reset() function, because I believe not overriding it would cause InstanceNormImpl to become an abstract class, and we cannot create objects of abstract classes.
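For illustration, a minimal sketch of the empty override being discussed (assuming reset() is declared pure virtual further up the hierarchy, e.g. in torch::nn::Cloneable):

  void reset() override {}  // no-op, but makes the class concrete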


@yf225 commented Oct 30, 2019

For tests, I think we don't need to be comprehensive and can just check common inputs and outputs. Here is an example for F::batch_norm:

TEST_F(FunctionalTest, BatchNorm1d) {
  int num_features = 5;
  double eps = 1e-05;
  double momentum = 0.1;
  auto input = torch::randn({2, 5});
  auto mean = torch::randn(5);
  auto variance = torch::rand(5);
  auto weight = torch::ones({num_features});
  auto bias = torch::zeros({num_features});
  auto output = F::batch_norm(
      input, mean, variance,
      BatchNormOptions().weight(weight).bias(bias).momentum(momentum).eps(eps),
      /*training=*/false);
  auto expected = (input - mean) / torch::sqrt(variance + eps);
  ASSERT_TRUE(output.allclose(expected));
}

TEST_F(FunctionalTest, BatchNorm1dDefaultOptions) {
  auto input = torch::randn({2, 5});
  auto mean = torch::randn(5);
  auto variance = torch::rand(5);
  auto output = F::batch_norm(input, mean, variance);
  auto expected = (input - mean) / torch::sqrt(variance + 1e-5);
  ASSERT_TRUE(output.allclose(expected));
}

and we can do the same for the module tests as well:

TEST_F(ModulesTest, BatchNorm1d) {
  BatchNorm1d bn(BatchNorm1dOptions(5));
  bn->eval();
  auto input = torch::randn({2, 5}, torch::requires_grad());
  auto output = bn->forward(input);
  auto s = output.sum();
  s.backward();
  ASSERT_EQ(input.sizes(), input.grad().sizes());
  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
}
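By analogy, a hypothetical module test for the new InstanceNorm1d might look like the sketch below; it assumes the InstanceNormOptions name used elsewhere in this PR, and a 3D (N, C, L) input as required by the _check_input_dim discussion further down:

TEST_F(ModulesTest, InstanceNorm1d) {
  InstanceNorm1d norm(InstanceNormOptions(5));
  norm->eval();
  auto input = torch::randn({2, 5, 4}, torch::requires_grad());
  auto output = norm->forward(input);
  auto s = output.sum();
  s.backward();
  ASSERT_EQ(input.sizes(), input.grad().sizes());
}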

@divyanshsinghvi (author) replied:

> I think it should be sufficient to provide an empty override for the reset() function, because I believe not overriding it would cause InstanceNormImpl to become an abstract class, and we cannot create objects of abstract classes.

Sorry, I might be wrong in my understanding here, but since we are subclassing from BatchNormImplBase, that class has already overridden the reset function, so InstanceNormImpl shouldn't be abstract because of this. Though it is likely that we would not want the reset defined in BatchNormImplBase to be called, in which case we should redefine it.
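For reference, a sketch of the hierarchy being described, assuming (as in the C++ frontend) that torch::nn::Cloneable declares reset() as pure virtual; this is an outline, not the exact upstream declarations:

template <size_t D, typename Derived>
class BatchNormImplBase : public torch::nn::Cloneable<Derived> {
 public:
  void reset() override;  // concrete override of Cloneable's pure virtual
  virtual void _check_input_dim(const Tensor& input) = 0;  // still pure here
};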


@yf225 commented Nov 1, 2019

@divyanshsinghvi Thanks a lot for the investigation! Yes, I think you are right, and we don't need to override reset in InstanceNormImpl, because it's already overridden in BatchNormImplBase.

@yf225 left a comment:

@divyanshsinghvi Thanks so much for the awesome work! This is highly complex, and I appreciate your help with it. I left some comments that we can work through together.


/// Base class for all (dimension-specialized) instanceNorm modules
template <size_t D, typename Derived, typename BatchNormDerived>
class TORCH_API InstanceNormImpl : public torch::nn::BatchNormImplBase<D, BatchNormDerived> {

I suspect that we can do

template <size_t D, typename Derived>
class TORCH_API InstanceNormImpl : public torch::nn::BatchNormImplBase<D, Derived> {

  InstanceNormOptions options;
};

class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl> {

and then here we can write

class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {

class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl> {
 public:
  using InstanceNormImpl<1, InstanceNorm1dImpl, BatchNorm1dImpl>::InstanceNormImpl;
 private:

Is it correct that this will prevent users from subclassing InstanceNorm1dImpl and overriding _check_input_dim in their custom subclass? If so, I think we should make it protected instead.

@divyanshsinghvi (author) replied:

Yes, right.


void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
TORCH_CHECK(
input.dim() == 2,

I believe this should be

Suggested change
input.dim() == 2,
input.dim() != 2,

because TORCH_CHECK prints the message when the condition is false :D
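As a standalone illustration of those semantics (hypothetical values):

  // No-op when the condition is true; throws a c10::Error carrying the
  // message when it is false.
  TORCH_CHECK(x > 0, "expected a positive value (got ", x, ")");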

"(1, N * C, ...) from (N, C,...) and this makes",
"variances 0.");
TORCH_CHECK(
input.dim() != 3,

I believe this should be

Suggested change
input.dim() != 3,
input.dim() == 3,

because TORCH_CHECK prints the message when the condition is false :D

"variances 0.");
TORCH_CHECK(
input.dim() != 3,
"expected 3D input (got", input.dim(), "D input)");

nit:

Suggested change
"expected 3D input (got", input.dim(), "D input)");
"expected 3D input (got ", input.dim(), "D input)");


template <size_t D, typename Derived, typename BatchNormDerived>
InstanceNormImpl<D, Derived, BatchNormDerived>::InstanceNormImpl(const InstanceNormOptions& options_)
: BatchNormImplBase<D, BatchNormDerived>(BatchNormOptions(options_.num_features())),

I believe we need to write

Suggested change
: BatchNormImplBase<D, BatchNormDerived>(BatchNormOptions(options_.num_features())),
: BatchNormImplBase<D, BatchNormDerived>(BatchNormOptions(options_.num_features()).eps(options_.eps()).momentum(options_.momentum()).affine(options_.affine()).track_running_stats(options_.track_running_stats())),

otherwise some of the InstanceNormOptions parameters are not passed to the BatchNormImplBase constructor :D (Note that it has to be the constructor parameter options_, not the member options, since the member is not yet initialized at this point in the initializer list.)
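For readability, the same initializer could be split across lines (purely a formatting sketch of the suggestion above):

  : BatchNormImplBase<D, BatchNormDerived>(
        BatchNormOptions(options_.num_features())
            .eps(options_.eps())
            .momentum(options_.momentum())
            .affine(options_.affine())
            .track_running_stats(options_.track_running_stats())),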

@yf225 force-pushed the instanceNorm branch 9 times, most recently from 8dda49e to caa5ecc on November 19, 2019.
@facebook-github-bot commented:
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 left a comment:

@divyanshsinghvi Thanks a lot for the awesome work!




@yf225 merged this pull request in ec52d91.
