I think MVNLayer with across_channels = true is broken #1938
Description
I think MVNLayer is broken when across_channels: true is specified in MVNParameter, even though the unit test is passing.
In MVNLayerTest::TestForwardAcrossChannels, we create the layer parameter with:
layer_param.ParseFromString("mvn_param{across_channels: true}");
ParseFromString returns false, and layer_param.mvn_param().across_channels() is false after the call, even though we specified true in the prototxt. So the test actually runs with the default value of across_channels, which is false.
If instead in MVNLayerTest::TestForwardAcrossChannels we call:
CHECK(google::protobuf::TextFormat::ParseFromString("mvn_param{across_channels: true}", &layer_param ));
it succeeds, and layer_param.mvn_param().across_channels() is true. You will then see that MVNLayerTest::TestForwardAcrossChannels fails, because Forward_* produces NaN values.
I have no idea why layer_param.ParseFromString fails whereas `google::protobuf::TextFormat::ParseFromString` succeeds. I don't have the Protocol Buffers source installed and would rather debug Caffe than protobuf. If anyone knows, please educate me.
At first it seems surprising that TestForwardAcrossChannels succeeds even though across_channels is false, when you'd think it would need to be true for the test to pass. I think it's because the values in the bottom blob are drawn from a normal distribution using GaussianFiller, and the only difference between across_channels = true and false is that the mean and variance are computed from 20 values (a single channel) instead of 60 (all channels). So you still get zero mean and variance 1.0 in the top blob, which is what the test checks for, regardless of whether you compute them for a single channel or across channels.
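To make the paragraph above concrete, here is a minimal standalone sketch (not Caffe code) of why a zero-mean, unit-variance check over all 60 values cannot tell the two modes apart. The helpers ComputeStats, NormalizeGroups and MakeSample are hypothetical names for illustration, and the 3 channels x 20 values layout is assumed from the 20 vs. 60 counts mentioned above. If every equally sized channel is normalized on its own, the pooled values also come out with mean 0 and variance 1, so the check appears to pass whatever the input distribution is.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Mean and (population) variance of a range of values.
struct Stats {
  double mean;
  double variance;
};

// Computes mean and variance of values[begin, end).
Stats ComputeStats(const std::vector<double>& values,
                   std::size_t begin, std::size_t end) {
  const double count = static_cast<double>(end - begin);
  double sum = 0.0;
  for (std::size_t i = begin; i < end; ++i) sum += values[i];
  const double mean = sum / count;
  double squared_diff = 0.0;
  for (std::size_t i = begin; i < end; ++i) {
    squared_diff += (values[i] - mean) * (values[i] - mean);
  }
  return {mean, squared_diff / count};
}

// Normalizes each contiguous group of group_size values, in place, to zero
// mean and unit variance. group_size = 20 mimics per-channel normalization
// (across_channels = false); group_size = 60 mimics across_channels = true.
void NormalizeGroups(std::vector<double>* data, std::size_t group_size) {
  for (std::size_t start = 0; start + group_size <= data->size();
       start += group_size) {
    const Stats stats = ComputeStats(*data, start, start + group_size);
    const double stddev = std::sqrt(stats.variance);
    for (std::size_t i = start; i < start + group_size; ++i) {
      (*data)[i] = ((*data)[i] - stats.mean) / stddev;
    }
  }
}

// Builds one sample of 3 channels x 20 values. Each channel gets a different
// offset and scale so the two normalization modes give different outputs.
std::vector<double> MakeSample() {
  std::vector<double> sample;
  for (int channel = 0; channel < 3; ++channel) {
    for (int i = 0; i < 20; ++i) {
      sample.push_back((channel + 1) * 10.0 +
                       (channel + 1) * std::sin(0.7 * i + channel));
    }
  }
  return sample;
}
```

Calling NormalizeGroups(&sample, 20) and then ComputeStats(sample, 0, 60) gives a mean of about 0 and a variance of about 1, the same as after NormalizeGroups(&sample, 60). A test that should distinguish the modes would need to check something that differs between them, for example the per-channel means after across-channel normalization.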
I'm working on this in connection with #1895. I will say that trying to understand the vectorization approach, which uses BLAS calls to perform these ops, is making my head hurt :/