Skip to content

Commit 62554a3

Browse files
rohan-varma authored and facebook-github-bot committed
Prioritize raising error message about unused parameters when rebuild_buckets fails (#45933)
Summary: Pull Request resolved: #45933 Occasionally users run DDP with models with unused params, in this case we would like to surface an error message telling them to run with find_unused_params=True. However, a recent change to rebuild_buckets logic (#44798) made it so that we raise a size mismatch error when this happens, but the information about unused parameters is likely to be more useful and likely to be the most common case of failure. Prefer raising this error over the subsequent size mismatch errors. ghstack-source-id: 113914759 Test Plan: Added unittest Reviewed By: mrshenli Differential Revision: D24151256 fbshipit-source-id: 5d349a988b4aac7d3e0ef7b3cd84dfdcbe9db675
1 parent 9fb8e33 commit 62554a3

File tree

3 files changed

+68
-25
lines changed

3 files changed

+68
-25
lines changed

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -946,31 +946,6 @@ void Reducer::prepare_for_backward(
946946
std::unordered_set<torch::autograd::Node*> seen;
947947
std::vector<torch::autograd::Node*> queue;
948948

949-
// Check that any prior reduction has finished.
950-
// The variable `require_finalize_` is true until all gradients
951-
// have been computed and reduction of all buckets has been kicked off.
952-
if (require_finalize_) {
953-
TORCH_CHECK(
954-
false,
955-
"Expected to have finished reduction in the prior iteration before ",
956-
"starting a new one. ",
957-
"",
958-
"This error indicates that your module has parameters that were ",
959-
"not used in producing loss. ",
960-
"",
961-
"You can enable unused parameter detection by (1) passing the keyword "
962-
"argument `find_unused_parameters=True` to ",
963-
"`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
964-
"`forward` function outputs participate in calculating loss. "
965-
"",
966-
"If you already have done the above two steps, then the distributed ",
967-
"data parallel module wasn't able to locate the output tensors in the ",
968-
"return value of your module's `forward` function. ",
969-
"Please include the loss function and the structure of the return ",
970-
"value of `forward` of your module when reporting this issue (e.g. ",
971-
"list, dict, iterable).");
972-
}
973-
974949
// Reset accounting.
975950
expect_autograd_hooks_ = true;
976951
next_bucket_ = 0;
@@ -1325,6 +1300,11 @@ void Reducer::sync_bucket_indices(
13251300
}
13261301

13271302
bool Reducer::rebuild_buckets() {
1303+
// Ensure reduction for previous backwards pass is finished. If user's model
1304+
// has unused parameters for example, this will raise an error recommending to
1305+
// run with find_unused_parameters=True, instead of the size mismatch
1306+
// exception below.
1307+
ensure_prior_reduction_finished();
13281308
std::lock_guard<std::mutex> lock(mutex_);
13291309
if (!should_rebuild_buckets() || rebuilt_params_.empty()) {
13301310
return false;
@@ -1381,6 +1361,34 @@ void Reducer::register_comm_hook(std::unique_ptr<CommHookInterface> iface) {
13811361
comm_hook_ = std::move(iface);
13821362
}
13831363

1364+
void Reducer::ensure_prior_reduction_finished() {
1365+
// Check that any prior reduction has finished.
1366+
// The variable `require_finalize_` is true until all gradients
1367+
// have been computed and reduction of all buckets has been kicked off.
1368+
if (require_finalize_) {
1369+
TORCH_CHECK(
1370+
false,
1371+
"Expected to have finished reduction in the prior iteration before ",
1372+
"starting a new one. ",
1373+
"",
1374+
"This error indicates that your module has parameters that were ",
1375+
"not used in producing loss. ",
1376+
"",
1377+
"You can enable unused parameter detection by (1) passing the keyword "
1378+
"argument `find_unused_parameters=True` to ",
1379+
"`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
1380+
"`forward` function outputs participate in calculating loss. "
1381+
"",
1382+
"If you already have done the above two steps, then the distributed ",
1383+
"data parallel module wasn't able to locate the output tensors in the ",
1384+
"return value of your module's `forward` function. ",
1385+
"Please include the loss function and the structure of the return ",
1386+
"value of `forward` of your module when reporting this issue (e.g. ",
1387+
"list, dict, iterable).");
1388+
}
1389+
1390+
}
1391+
13841392
namespace {
13851393

13861394
// Tensors may be coalesced into buckets. Buckets must contain tensors of

torch/csrc/distributed/c10d/reducer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ class Reducer {
170170

171171
void finalize_backward();
172172

173+
// Asserts that the reduction for the previous iteration has finished before
174+
// rebuilding buckets or kicking off the next one.
175+
void ensure_prior_reduction_finished();
176+
173177
// Broadcast rebuilt buckets from rank 0 to other ranks before initializing
174178
// the buckets
175179
void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3662,6 +3662,37 @@ def forward(self, x):
36623662
# isolate failure hangs.
36633663
torch.cuda.synchronize(device=self.rank)
36643664

3665+
@require_backend({"gloo", "nccl"})
3666+
@require_backends_available({"gloo", "nccl"})
3667+
@skip_if_lt_x_gpu(2)
3668+
@skip_if_rocm
3669+
def test_ddp_unused_params_rebuild_buckets_exception(self):
3670+
class ToyModel(nn.Module):
3671+
def __init__(self):
3672+
super(ToyModel, self).__init__()
3673+
self.net1 = nn.Linear(10, 10, bias=False)
3674+
self.net2 = nn.Linear(10, 10, bias=False)
3675+
3676+
def forward(self, x):
3677+
return self.net1(x)
3678+
3679+
ddp = torch.nn.parallel.DistributedDataParallel(
3680+
ToyModel().cuda(self.rank), device_ids=[self.rank]
3681+
)
3682+
for i in range(2):
3683+
inp = torch.rand(1, 10)
3684+
if i > 0:
3685+
# On 2nd iteration, this will fail during rebuild_buckets,
3686+
# but we should report an error regarding unused parameters
3687+
# since that is the underlying root cause.
3688+
with self.assertRaisesRegex(
3689+
RuntimeError,
3690+
"Expected to have finished reduction in the prior iteration",
3691+
):
3692+
ddp(inp).sum().backward()
3693+
else:
3694+
ddp(inp).sum().backward()
3695+
36653696
@require_backend({"gloo", "nccl"})
36663697
@require_backends_available({"gloo", "nccl"})
36673698
@skip_if_lt_x_gpu(2)

0 commit comments

Comments (0)