
Commit 27afac2

pbelevich authored and facebook-github-bot committed
C++ API parity: Dropout, Dropout2d, Dropout3d
Summary: Pull Request resolved: #29761
Test Plan: Imported from OSS
Differential Revision: D18530820
Pulled By: pbelevich
fbshipit-source-id: 9d351561692f7de099d7c6aaf2ecb930b5c867e9
1 parent fbabf72 commit 27afac2

12 files changed, +393 additions, −68 deletions

test/cpp/api/functional.cpp

Lines changed: 45 additions & 0 deletions
@@ -1871,3 +1871,48 @@ TEST_F(FunctionalTest, MarginRankingLoss) {
     ));
   }
 }
+
+TEST_F(FunctionalTest, Dropout) {
+  auto input = torch::randn(5000);
+  auto input_mean = input.mean();
+  auto input_std = input.std();
+
+  for (const auto rate : {0.2, 0.5, 0.8}) {
+    auto output = F::dropout(input, F::DropoutFuncOptions().p(rate));
+    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+    ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+  }
+  auto output = F::dropout(input);
+  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+  ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+}
+
+TEST_F(FunctionalTest, Dropout2d) {
+  auto input = torch::randn({50, 100});
+  auto input_mean = input.mean();
+  auto input_std = input.std();
+
+  for (const auto rate : {0.2, 0.5, 0.8}) {
+    auto output = F::dropout2d(input, F::Dropout2dFuncOptions().p(rate));
+    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+    ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+  }
+  auto output = F::dropout2d(input);
+  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+  ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+}
+
+TEST_F(FunctionalTest, Dropout3d) {
+  auto input = torch::randn({50, 10, 10});
+  auto input_mean = input.mean();
+  auto input_std = input.std();
+
+  for (const auto rate : {0.2, 0.5, 0.8}) {
+    auto output = F::dropout3d(input, F::Dropout3dFuncOptions().p(rate));
+    ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+    ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+  }
+  auto output = F::dropout3d(input);
+  ASSERT_TRUE(torch::allclose(input_mean, output.mean(), 0.01, 0.05));
+  ASSERT_TRUE((input_std <= output.std()).all().item<bool>());
+}
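For context, a minimal standalone sketch (not part of this commit) of how the functional API exercised by these tests might be called from user code. The namespace F = torch::nn::functional alias mirrors the convention used in the test file, and the default option values (p = 0.5, training mode, out-of-place) are assumptions based on the test expectations above.

// Usage sketch (hypothetical call sites; API names taken from the tests above).
#include <torch/torch.h>

namespace F = torch::nn::functional;

int main() {
  auto x = torch::randn({8, 16});

  // Default options: dropout with p = 0.5, applied in training mode, out-of-place.
  auto y = F::dropout(x);

  // Explicit options, as in the tests above.
  auto z = F::dropout2d(x, F::Dropout2dFuncOptions().p(0.2));

  return 0;
}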

test/cpp/api/integration.cpp

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@ TEST_F(IntegrationTest, MNIST_CUDA) {
   auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
   auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
   auto drop = Dropout(0.3);
-  auto drop2d = FeatureDropout(0.3);
+  auto drop2d = Dropout2d(0.3);
   auto linear1 = model->add(Linear(320, 50), "linear1");
   auto linear2 = model->add(Linear(50, 10), "linear2");

test/cpp/api/modulelist.cpp

Lines changed: 3 additions & 3 deletions
@@ -240,12 +240,12 @@ TEST_F(ModuleListTest, IsCloneable) {
 }
 
 TEST_F(ModuleListTest, RegistersElementsAsSubmodules) {
-  ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
+  ModuleList list(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));
 
   auto modules = list->children();
   ASSERT_TRUE(modules[0]->as<Linear>());
   ASSERT_TRUE(modules[1]->as<Conv2d>());
-  ASSERT_TRUE(modules[2]->as<FeatureDropout>());
+  ASSERT_TRUE(modules[2]->as<Dropout2d>());
 }
 
 TEST_F(ModuleListTest, NestingIsPossible) {
@@ -280,7 +280,7 @@ TEST_F(ModuleListTest, PrettyPrintModuleList) {
       "torch::nn::ModuleList(\n"
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
-      " (2): torch::nn::Dropout(rate=0.5)\n"
+      " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
       " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"

test/cpp/api/modules.cpp

Lines changed: 87 additions & 3 deletions
@@ -1069,6 +1069,72 @@ TEST_F(ModulesTest, Dropout) {
   ASSERT_EQ(y.sum().item<float>(), 100);
 }
 
+TEST_F(ModulesTest, Dropout2d) {
+  Dropout2d dropout(0.5);
+  torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
+  torch::Tensor y = dropout(x);
+
+  y.backward(torch::ones_like(y));
+  ASSERT_EQ(y.ndimension(), 2);
+  ASSERT_EQ(y.size(0), 10);
+  ASSERT_EQ(y.size(1), 10);
+  ASSERT_LT(y.sum().item<float>(), 130); // Probably
+  ASSERT_GT(y.sum().item<float>(), 70); // Probably
+
+  dropout->eval();
+  y = dropout(x);
+  ASSERT_EQ(y.sum().item<float>(), 100);
+}
+
+TEST_F(ModulesTest, Dropout3d) {
+  Dropout3d dropout(0.5);
+  torch::Tensor x = torch::ones({4, 5, 5}, torch::requires_grad());
+  torch::Tensor y = dropout(x);
+
+  y.backward(torch::ones_like(y));
+  ASSERT_EQ(y.ndimension(), 3);
+  ASSERT_EQ(y.size(0), 4);
+  ASSERT_EQ(y.size(1), 5);
+  ASSERT_EQ(y.size(1), 5);
+  ASSERT_LT(y.sum().item<float>(), 130); // Probably
+  ASSERT_GT(y.sum().item<float>(), 70); // Probably
+
+  dropout->eval();
+  y = dropout(x);
+  ASSERT_EQ(y.sum().item<float>(), 100);
+}
+
+TEST_F(ModulesTest, FeatureDropout) {
+  FeatureDropout dropout(0.5);
+  torch::Tensor x = torch::ones({10, 10}, torch::requires_grad());
+  torch::Tensor y = dropout(x);
+
+  y.backward(torch::ones_like(y));
+  ASSERT_EQ(y.ndimension(), 2);
+  ASSERT_EQ(y.size(0), 10);
+  ASSERT_EQ(y.size(1), 10);
+  ASSERT_LT(y.sum().item<float>(), 130); // Probably
+  ASSERT_GT(y.sum().item<float>(), 70); // Probably
+
+  dropout->eval();
+  y = dropout(x);
+  ASSERT_EQ(y.sum().item<float>(), 100);
+}
+
+TEST_F(ModulesTest, FeatureDropoutLegacyWarning) {
+  std::stringstream buffer;
+  torch::test::CerrRedirect cerr_redirect(buffer.rdbuf());
+
+  FeatureDropout bn(0.5);
+
+  ASSERT_EQ(
+    count_substr_occurrences(
+      buffer.str(),
+      "torch::nn::FeatureDropout module is deprecated"
+    ),
+    1);
+}
+
 TEST_F(ModulesTest, Parameters) {
   auto model = std::make_shared<NestedModel>();
   auto parameters = model->named_parameters();
@@ -2780,9 +2846,27 @@ TEST_F(ModulesTest, PrettyPrintMaxUnpool) {
 }
 
 TEST_F(ModulesTest, PrettyPrintDropout) {
-  ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
-  ASSERT_EQ(
-      c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
+  ASSERT_EQ(c10::str(Dropout()), "torch::nn::Dropout(p=0.5, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout(0.42)), "torch::nn::Dropout(p=0.42, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout(DropoutOptions().p(0.42).inplace(true))), "torch::nn::Dropout(p=0.42, inplace=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintDropout2d) {
+  ASSERT_EQ(c10::str(Dropout2d()), "torch::nn::Dropout2d(p=0.5, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout2d(0.42)), "torch::nn::Dropout2d(p=0.42, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout2d(Dropout2dOptions().p(0.42).inplace(true))), "torch::nn::Dropout2d(p=0.42, inplace=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintDropout3d) {
+  ASSERT_EQ(c10::str(Dropout3d()), "torch::nn::Dropout3d(p=0.5, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout3d(0.42)), "torch::nn::Dropout3d(p=0.42, inplace=false)");
+  ASSERT_EQ(c10::str(Dropout3d(Dropout3dOptions().p(0.42).inplace(true))), "torch::nn::Dropout3d(p=0.42, inplace=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintFeatureDropout) {
+  ASSERT_EQ(c10::str(FeatureDropout()), "torch::nn::FeatureDropout(p=0.5, inplace=false)");
+  ASSERT_EQ(c10::str(FeatureDropout(0.42)), "torch::nn::FeatureDropout(p=0.42, inplace=false)");
+  ASSERT_EQ(c10::str(FeatureDropout(FeatureDropoutOptions().p(0.42).inplace(true))), "torch::nn::FeatureDropout(p=0.42, inplace=true)");
 }
 
 TEST_F(ModulesTest, PrettyPrintFunctional) {
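As a companion to the ModulesTest cases above, a minimal sketch (not part of the commit) of the module-level API: construct a Dropout2d, run it in training mode, then switch to eval mode, where it acts as the identity (which is what the ASSERT_EQ(y.sum().item<float>(), 100) checks rely on).

// Module usage sketch (hypothetical), mirroring the Dropout2d test above.
#include <torch/torch.h>

int main() {
  torch::nn::Dropout2d dropout(torch::nn::Dropout2dOptions().p(0.5));

  auto x = torch::ones({10, 10});
  auto y = dropout(x);   // training mode: some features zeroed, the rest rescaled

  dropout->eval();
  auto z = dropout(x);   // eval mode: identity, so z.sum() == 100 as in the test

  return 0;
}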

test/cpp/api/sequential.cpp

Lines changed: 4 additions & 4 deletions
@@ -377,12 +377,12 @@ TEST_F(SequentialTest, IsCloneable) {
 }
 
 TEST_F(SequentialTest, RegistersElementsAsSubmodules) {
-  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), FeatureDropout(0.5));
+  Sequential sequential(Linear(10, 3), Conv2d(1, 2, 3), Dropout2d(0.5));
 
   auto modules = sequential->children();
   ASSERT_TRUE(modules[0]->as<Linear>());
   ASSERT_TRUE(modules[1]->as<Conv2d>());
-  ASSERT_TRUE(modules[2]->as<FeatureDropout>());
+  ASSERT_TRUE(modules[2]->as<Dropout2d>());
 }
 
 TEST_F(SequentialTest, CloneToDevice_CUDA) {
@@ -411,7 +411,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       "torch::nn::Sequential(\n"
       " (0): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (1): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
-      " (2): torch::nn::Dropout(rate=0.5)\n"
+      " (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
      " (3): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
@@ -430,7 +430,7 @@ TEST_F(SequentialTest, PrettyPrintSequential) {
       "torch::nn::Sequential(\n"
       " (linear): torch::nn::Linear(in_features=10, out_features=3, bias=true)\n"
       " (conv2d): torch::nn::Conv2d(1, 2, kernel_size=[3, 3], stride=[1, 1])\n"
-      " (dropout): torch::nn::Dropout(rate=0.5)\n"
+      " (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
       " (batchnorm): torch::nn::BatchNorm(num_features=5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       " (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
       " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"

test/cpp_api_parity/parity-tracker.md

Lines changed: 3 additions & 3 deletions
@@ -96,9 +96,9 @@ torch.nn.Identity|Yes|No
 torch.nn.Linear|Yes|No
 torch.nn.Bilinear|Yes|No
 torch.nn.Flatten|Yes|No
-torch.nn.Dropout|No|No
-torch.nn.Dropout2d|No|No
-torch.nn.Dropout3d|No|No
+torch.nn.Dropout|Yes|No
+torch.nn.Dropout2d|Yes|No
+torch.nn.Dropout3d|Yes|No
 torch.nn.AlphaDropout|No|No
 torch.nn.Embedding|Yes|No
 torch.nn.EmbeddingBag|Yes|No

torch/csrc/api/include/torch/nn/functional.h

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 #include <torch/nn/functional/batchnorm.h>
 #include <torch/nn/functional/conv.h>
 #include <torch/nn/functional/distance.h>
+#include <torch/nn/functional/dropout.h>
 #include <torch/nn/functional/embedding.h>
 #include <torch/nn/functional/fold.h>
 #include <torch/nn/functional/linear.h>
torch/csrc/api/include/torch/nn/functional/dropout.h

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+#pragma once
+
+#include <torch/nn/options/dropout.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+namespace detail {
+
+inline Tensor dropout(Tensor& input, double p, bool training, bool inplace) {
+  TORCH_CHECK(
+      p >= 0. && p <= 1.,
+      "dropout probability has to be between 0 and 1, but got ",
+      p);
+  if (inplace) {
+    return torch::dropout_(input, p, training);
+  } else {
+    return torch::dropout(input, p, training);
+  }
+}
+
+} // namespace detail
+
+inline Tensor dropout(Tensor& input,
+    const DropoutFuncOptions& options = {}) {
+  return detail::dropout(
+      input, options.p(), options.training(), options.inplace());
+}
+
+// ============================================================================
+
+namespace detail {
+
+inline Tensor dropout2d(Tensor& input, double p, bool training, bool inplace) {
+  TORCH_CHECK(
+      p >= 0. && p <= 1.,
+      "dropout probability has to be between 0 and 1, but got ",
+      p);
+  if (inplace) {
+    return torch::feature_dropout_(input, p, training);
+  } else {
+    return torch::feature_dropout(input, p, training);
+  }
+}
+
+} // namespace detail
+
+inline Tensor dropout2d(Tensor& input,
+    const Dropout2dFuncOptions& options = {}) {
+  return detail::dropout2d(
+      input, options.p(), options.training(), options.inplace());
+}
+
+// ============================================================================
+
+namespace detail {
+
+inline Tensor dropout3d(Tensor& input, double p, bool training, bool inplace) {
+  TORCH_CHECK(
+      p >= 0. && p <= 1.,
+      "dropout probability has to be between 0 and 1, but got ",
+      p);
+  if (inplace) {
+    return torch::feature_dropout_(input, p, training);
+  } else {
+    return torch::feature_dropout(input, p, training);
+  }
+}
+
+} // namespace detail
+
+inline Tensor dropout3d(Tensor& input,
+    const Dropout3dFuncOptions& options = {}) {
+  return detail::dropout3d(
+      input, options.p(), options.training(), options.inplace());
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
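To illustrate the layering in this header, a sketch (not part of the commit): the public overload unpacks an options struct and forwards to the detail:: overload, which takes the raw p/training/inplace arguments and can also be called directly, for example from a module's forward(). The call sites below are hypothetical; the signatures come from the header above.

// Forwarding-pattern sketch (hypothetical call sites, real signatures from above).
#include <torch/torch.h>

namespace F = torch::nn::functional;

int main() {
  auto x = torch::randn({4, 3, 8, 8});

  // Options-based entry point.
  auto a = F::dropout3d(x, F::Dropout3dFuncOptions().p(0.25));

  // Roughly equivalent direct call into the detail overload
  // (training appears to default to true, judging by the tests above).
  auto b = F::detail::dropout3d(x, /*p=*/0.25, /*training=*/true, /*inplace=*/false);

  return 0;
}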
