Skip to content

Commit dee01c8

Browse files
committed
test_net.cpp: add TestForcePropagateDown
1 parent 389db96 commit dee01c8

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

src/caffe/test/test_net.cpp

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
716716
InitNetFromProtoString(proto);
717717
}
718718

719+
virtual void InitForcePropNet(bool test_force_true) {
720+
string proto =
721+
"name: 'ForcePropTestNetwork' "
722+
"layer { "
723+
" name: 'data' "
724+
" type: 'DummyData' "
725+
" dummy_data_param { "
726+
" shape { "
727+
" dim: 5 "
728+
" dim: 2 "
729+
" dim: 3 "
730+
" dim: 4 "
731+
" } "
732+
" data_filler { "
733+
" type: 'gaussian' "
734+
" std: 0.01 "
735+
" } "
736+
" shape { "
737+
" dim: 5 "
738+
" } "
739+
" data_filler { "
740+
" type: 'constant' "
741+
" value: 0 "
742+
" } "
743+
" } "
744+
" top: 'data' "
745+
" top: 'label' "
746+
"} "
747+
"layer { "
748+
" name: 'innerproduct' "
749+
" type: 'InnerProduct' "
750+
" inner_product_param { "
751+
" num_output: 1 "
752+
" weight_filler { "
753+
" type: 'gaussian' "
754+
" std: 0.01 "
755+
" } "
756+
" } "
757+
" bottom: 'data' "
758+
" top: 'innerproduct' ";
759+
if (test_force_true) {
760+
proto += " propagate_down: true ";
761+
}
762+
proto +=
763+
"} "
764+
"layer { "
765+
" name: 'loss' "
766+
" bottom: 'innerproduct' "
767+
" bottom: 'label' "
768+
" top: 'cross_entropy_loss' "
769+
" type: 'SigmoidCrossEntropyLoss' "
770+
"} ";
771+
InitNetFromProtoString(proto);
772+
}
773+
719774
int seed_;
720775
shared_ptr<Net<Dtype> > net_;
721776
};
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
23712426
}
23722427
}
23732428

2429+
TYPED_TEST(NetTest, TestForcePropagateDown) {
2430+
this->InitForcePropNet(false);
2431+
vector<bool> layer_need_backward = this->net_->layer_need_backward();
2432+
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
2433+
const string& layer_name = this->net_->layer_names()[layer_id];
2434+
const vector<bool> need_backward =
2435+
this->net_->bottom_need_backward()[layer_id];
2436+
if (layer_name == "data") {
2437+
ASSERT_EQ(need_backward.size(), 0);
2438+
EXPECT_FALSE(layer_need_backward[layer_id]);
2439+
} else if (layer_name == "innerproduct") {
2440+
ASSERT_EQ(need_backward.size(), 1);
2441+
EXPECT_FALSE(need_backward[0]); // data
2442+
EXPECT_TRUE(layer_need_backward[layer_id]);
2443+
} else if (layer_name == "loss") {
2444+
ASSERT_EQ(need_backward.size(), 2);
2445+
EXPECT_TRUE(need_backward[0]); // innerproduct
2446+
EXPECT_FALSE(need_backward[1]); // label
2447+
EXPECT_TRUE(layer_need_backward[layer_id]);
2448+
} else {
2449+
LOG(FATAL) << "Unknown layer: " << layer_name;
2450+
}
2451+
}
2452+
this->InitForcePropNet(true);
2453+
layer_need_backward = this->net_->layer_need_backward();
2454+
for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
2455+
const string& layer_name = this->net_->layer_names()[layer_id];
2456+
const vector<bool> need_backward =
2457+
this->net_->bottom_need_backward()[layer_id];
2458+
if (layer_name == "data") {
2459+
ASSERT_EQ(need_backward.size(), 0);
2460+
EXPECT_FALSE(layer_need_backward[layer_id]);
2461+
} else if (layer_name == "innerproduct") {
2462+
ASSERT_EQ(need_backward.size(), 1);
2463+
EXPECT_TRUE(need_backward[0]); // data
2464+
EXPECT_TRUE(layer_need_backward[layer_id]);
2465+
} else if (layer_name == "loss") {
2466+
ASSERT_EQ(need_backward.size(), 2);
2467+
EXPECT_TRUE(need_backward[0]); // innerproduct
2468+
EXPECT_FALSE(need_backward[1]); // label
2469+
EXPECT_TRUE(layer_need_backward[layer_id]);
2470+
} else {
2471+
LOG(FATAL) << "Unknown layer: " << layer_name;
2472+
}
2473+
}
2474+
}
2475+
23742476
} // namespace caffe

0 commit comments

Comments
 (0)