@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
716716 InitNetFromProtoString (proto);
717717 }
718718
719+ virtual void InitForcePropNet (bool test_force_true) {
720+ string proto =
721+ " name: 'ForcePropTestNetwork' "
722+ " layer { "
723+ " name: 'data' "
724+ " type: 'DummyData' "
725+ " dummy_data_param { "
726+ " shape { "
727+ " dim: 5 "
728+ " dim: 2 "
729+ " dim: 3 "
730+ " dim: 4 "
731+ " } "
732+ " data_filler { "
733+ " type: 'gaussian' "
734+ " std: 0.01 "
735+ " } "
736+ " shape { "
737+ " dim: 5 "
738+ " } "
739+ " data_filler { "
740+ " type: 'constant' "
741+ " value: 0 "
742+ " } "
743+ " } "
744+ " top: 'data' "
745+ " top: 'label' "
746+ " } "
747+ " layer { "
748+ " name: 'innerproduct' "
749+ " type: 'InnerProduct' "
750+ " inner_product_param { "
751+ " num_output: 1 "
752+ " weight_filler { "
753+ " type: 'gaussian' "
754+ " std: 0.01 "
755+ " } "
756+ " } "
757+ " bottom: 'data' "
758+ " top: 'innerproduct' " ;
759+ if (test_force_true) {
760+ proto += " propagate_down: true " ;
761+ }
762+ proto +=
763+ " } "
764+ " layer { "
765+ " name: 'loss' "
766+ " bottom: 'innerproduct' "
767+ " bottom: 'label' "
768+ " top: 'cross_entropy_loss' "
769+ " type: 'SigmoidCrossEntropyLoss' "
770+ " } " ;
771+ InitNetFromProtoString (proto);
772+ }
773+
719774 int seed_;
720775 shared_ptr<Net<Dtype> > net_;
721776};
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
23712426 }
23722427}
23732428
2429+ TYPED_TEST (NetTest, TestForcePropagateDown) {
2430+ this ->InitForcePropNet (false );
2431+ vector<bool > layer_need_backward = this ->net_ ->layer_need_backward ();
2432+ for (int layer_id = 0 ; layer_id < this ->net_ ->layers ().size (); ++layer_id) {
2433+ const string& layer_name = this ->net_ ->layer_names ()[layer_id];
2434+ const vector<bool > need_backward =
2435+ this ->net_ ->bottom_need_backward ()[layer_id];
2436+ if (layer_name == " data" ) {
2437+ ASSERT_EQ (need_backward.size (), 0 );
2438+ EXPECT_FALSE (layer_need_backward[layer_id]);
2439+ } else if (layer_name == " innerproduct" ) {
2440+ ASSERT_EQ (need_backward.size (), 1 );
2441+ EXPECT_FALSE (need_backward[0 ]); // data
2442+ EXPECT_TRUE (layer_need_backward[layer_id]);
2443+ } else if (layer_name == " loss" ) {
2444+ ASSERT_EQ (need_backward.size (), 2 );
2445+ EXPECT_TRUE (need_backward[0 ]); // innerproduct
2446+ EXPECT_FALSE (need_backward[1 ]); // label
2447+ EXPECT_TRUE (layer_need_backward[layer_id]);
2448+ } else {
2449+ LOG (FATAL) << " Unknown layer: " << layer_name;
2450+ }
2451+ }
2452+ this ->InitForcePropNet (true );
2453+ layer_need_backward = this ->net_ ->layer_need_backward ();
2454+ for (int layer_id = 0 ; layer_id < this ->net_ ->layers ().size (); ++layer_id) {
2455+ const string& layer_name = this ->net_ ->layer_names ()[layer_id];
2456+ const vector<bool > need_backward =
2457+ this ->net_ ->bottom_need_backward ()[layer_id];
2458+ if (layer_name == " data" ) {
2459+ ASSERT_EQ (need_backward.size (), 0 );
2460+ EXPECT_FALSE (layer_need_backward[layer_id]);
2461+ } else if (layer_name == " innerproduct" ) {
2462+ ASSERT_EQ (need_backward.size (), 1 );
2463+ EXPECT_TRUE (need_backward[0 ]); // data
2464+ EXPECT_TRUE (layer_need_backward[layer_id]);
2465+ } else if (layer_name == " loss" ) {
2466+ ASSERT_EQ (need_backward.size (), 2 );
2467+ EXPECT_TRUE (need_backward[0 ]); // innerproduct
2468+ EXPECT_FALSE (need_backward[1 ]); // label
2469+ EXPECT_TRUE (layer_need_backward[layer_id]);
2470+ } else {
2471+ LOG (FATAL) << " Unknown layer: " << layer_name;
2472+ }
2473+ }
2474+ }
2475+
23742476} // namespace caffe
0 commit comments