@@ -1069,6 +1069,72 @@ TEST_F(ModulesTest, Dropout) {
10691069 ASSERT_EQ (y.sum ().item <float >(), 100 );
10701070}
10711071
1072+ TEST_F (ModulesTest, Dropout2d) {
1073+ Dropout2d dropout (0.5 );
1074+ torch::Tensor x = torch::ones ({10 , 10 }, torch::requires_grad ());
1075+ torch::Tensor y = dropout (x);
1076+
1077+ y.backward (torch::ones_like (y));
1078+ ASSERT_EQ (y.ndimension (), 2 );
1079+ ASSERT_EQ (y.size (0 ), 10 );
1080+ ASSERT_EQ (y.size (1 ), 10 );
1081+ ASSERT_LT (y.sum ().item <float >(), 130 ); // Probably
1082+ ASSERT_GT (y.sum ().item <float >(), 70 ); // Probably
1083+
1084+ dropout->eval ();
1085+ y = dropout (x);
1086+ ASSERT_EQ (y.sum ().item <float >(), 100 );
1087+ }
1088+
1089+ TEST_F (ModulesTest, Dropout3d) {
1090+ Dropout3d dropout (0.5 );
1091+ torch::Tensor x = torch::ones ({4 , 5 , 5 }, torch::requires_grad ());
1092+ torch::Tensor y = dropout (x);
1093+
1094+ y.backward (torch::ones_like (y));
1095+ ASSERT_EQ (y.ndimension (), 3 );
1096+ ASSERT_EQ (y.size (0 ), 4 );
1097+ ASSERT_EQ (y.size (1 ), 5 );
1098+ ASSERT_EQ (y.size (1 ), 5 );
1099+ ASSERT_LT (y.sum ().item <float >(), 130 ); // Probably
1100+ ASSERT_GT (y.sum ().item <float >(), 70 ); // Probably
1101+
1102+ dropout->eval ();
1103+ y = dropout (x);
1104+ ASSERT_EQ (y.sum ().item <float >(), 100 );
1105+ }
1106+
1107+ TEST_F (ModulesTest, FeatureDropout) {
1108+ FeatureDropout dropout (0.5 );
1109+ torch::Tensor x = torch::ones ({10 , 10 }, torch::requires_grad ());
1110+ torch::Tensor y = dropout (x);
1111+
1112+ y.backward (torch::ones_like (y));
1113+ ASSERT_EQ (y.ndimension (), 2 );
1114+ ASSERT_EQ (y.size (0 ), 10 );
1115+ ASSERT_EQ (y.size (1 ), 10 );
1116+ ASSERT_LT (y.sum ().item <float >(), 130 ); // Probably
1117+ ASSERT_GT (y.sum ().item <float >(), 70 ); // Probably
1118+
1119+ dropout->eval ();
1120+ y = dropout (x);
1121+ ASSERT_EQ (y.sum ().item <float >(), 100 );
1122+ }
1123+
1124+ TEST_F (ModulesTest, FeatureDropoutLegacyWarning) {
1125+ std::stringstream buffer;
1126+ torch::test::CerrRedirect cerr_redirect (buffer.rdbuf ());
1127+
1128+ FeatureDropout bn (0.5 );
1129+
1130+ ASSERT_EQ (
1131+ count_substr_occurrences (
1132+ buffer.str (),
1133+ " torch::nn::FeatureDropout module is deprecated"
1134+ ),
1135+ 1 );
1136+ }
1137+
10721138TEST_F (ModulesTest, Parameters) {
10731139 auto model = std::make_shared<NestedModel>();
10741140 auto parameters = model->named_parameters ();
@@ -2780,9 +2846,27 @@ TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
27802846}
27812847
27822848TEST_F (ModulesTest, PrettyPrintDropout) {
2783- ASSERT_EQ (c10::str (Dropout (0.5 )), " torch::nn::Dropout(rate=0.5)" );
2784- ASSERT_EQ (
2785- c10::str (FeatureDropout (0.5 )), " torch::nn::FeatureDropout(rate=0.5)" );
2849+ ASSERT_EQ (c10::str (Dropout ()), " torch::nn::Dropout(p=0.5, inplace=false)" );
2850+ ASSERT_EQ (c10::str (Dropout (0.42 )), " torch::nn::Dropout(p=0.42, inplace=false)" );
2851+ ASSERT_EQ (c10::str (Dropout (DropoutOptions ().p (0.42 ).inplace (true ))), " torch::nn::Dropout(p=0.42, inplace=true)" );
2852+ }
2853+
2854+ TEST_F (ModulesTest, PrettyPrintDropout2d) {
2855+ ASSERT_EQ (c10::str (Dropout2d ()), " torch::nn::Dropout2d(p=0.5, inplace=false)" );
2856+ ASSERT_EQ (c10::str (Dropout2d (0.42 )), " torch::nn::Dropout2d(p=0.42, inplace=false)" );
2857+ ASSERT_EQ (c10::str (Dropout2d (Dropout2dOptions ().p (0.42 ).inplace (true ))), " torch::nn::Dropout2d(p=0.42, inplace=true)" );
2858+ }
2859+
2860+ TEST_F (ModulesTest, PrettyPrintDropout3d) {
2861+ ASSERT_EQ (c10::str (Dropout3d ()), " torch::nn::Dropout3d(p=0.5, inplace=false)" );
2862+ ASSERT_EQ (c10::str (Dropout3d (0.42 )), " torch::nn::Dropout3d(p=0.42, inplace=false)" );
2863+ ASSERT_EQ (c10::str (Dropout3d (Dropout3dOptions ().p (0.42 ).inplace (true ))), " torch::nn::Dropout3d(p=0.42, inplace=true)" );
2864+ }
2865+
2866+ TEST_F (ModulesTest, PrettyPrintFeatureDropout) {
2867+ ASSERT_EQ (c10::str (FeatureDropout ()), " torch::nn::FeatureDropout(p=0.5, inplace=false)" );
2868+ ASSERT_EQ (c10::str (FeatureDropout (0.42 )), " torch::nn::FeatureDropout(p=0.42, inplace=false)" );
2869+ ASSERT_EQ (c10::str (FeatureDropout (FeatureDropoutOptions ().p (0.42 ).inplace (true ))), " torch::nn::FeatureDropout(p=0.42, inplace=true)" );
27862870}
27872871
27882872TEST_F (ModulesTest, PrettyPrintFunctional) {
0 commit comments