Closed
Labels
module: rpc — Related to RPC, distributed autograd, RRef, and distributed optimizer
triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
🐛 Bug
The script below runs into RRef leaks and pickler issues when running the DistributedOptimizer.
To Reproduce
import torch
from torch.distributed import rpc
import torch.distributed.autograd as dist_autograd
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from tempfile import NamedTemporaryFile
import multiprocessing as mp


def random_tensor():
    return torch.rand((3, 3), requires_grad=True)


def _run_process(self_rank, dst_rank, file_name):
    self_name = "worker{}".format(self_rank)
    dst_name = "worker{}".format(dst_rank)
    rpc.init_model_parallel(
        self_name=self_name,
        self_rank=self_rank,
        worker_name_to_id={"worker0": 0, "worker1": 1},
        init_method="file://{}".format(file_name),
    )

    with dist_autograd.context() as context_id:
        # Forward pass.
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass.
        dist_autograd.backward([loss.sum()])

        # Optimizer.
        try:
            dist_optim = DistributedOptimizer(
                optim.SGD,
                [rref1, rref2],
                lr=0.05,
            )
        except Exception as e:
            print(e)
            raise e

        dist_optim.step()

    rpc.join_rpc()


file_name = NamedTemporaryFile().name
processes = []
for i in range(2):
    p = mp.Process(target=_run_process, args=(i, (i + 1) % 2, file_name))
    p.start()
    processes.append(p)

for p in processes:
    p.join()
The code above fails with: https://gist.github.com/pritamdamania87/08cf8de0aa95de03ba1e50da89f96996
It looks like when we create RRefs in the DistributedOptimizer constructor, we don't "wait" for these RRefs. Aside from the issue above, we should make sure we wait for the RRefs so that any errors are discovered earlier (a rough sketch of what that could look like is below). However, if I change the DistributedOptimizer code to wait for the RRefs, I run into a different error with the same script above.
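For reference, here is a minimal sketch of one way the constructor could block until the per-worker optimizers are created, so that creation errors surface eagerly. This is not the actual DistributedOptimizer source; the helper names (build_remote_optimizers, _new_local_optimizer) are made up for illustration, and it assumes the remote optimizers can be created via an RPC that returns an RRef.

# Hedged sketch only, not the actual DistributedOptimizer implementation;
# helper names here are hypothetical.
from collections import defaultdict

import torch.distributed.rpc as rpc


def _new_local_optimizer(optim_cls, param_rrefs, *args, **kwargs):
    # Runs on the owner worker: builds a regular local optimizer over the
    # parameters it owns and hands back an RRef to it.
    optimizer = optim_cls([r.local_value() for r in param_rrefs], *args, **kwargs)
    return rpc.RRef(optimizer)


def build_remote_optimizers(optim_cls, params_rref, *args, **kwargs):
    # Group parameter RRefs by the worker that owns them.
    per_worker = defaultdict(list)
    for param in params_rref:
        per_worker[param.owner()].append(param)

    # Kick off creation of one local optimizer per owner, then block on the
    # returned futures. Using rpc_async + wait (rather than fire-and-forget
    # rpc.remote) means any exception raised while constructing the remote
    # optimizer is re-raised here, in the constructor, instead of surfacing
    # later during step().
    futs = []
    for owner, rrefs in per_worker.items():
        futs.append(
            rpc.rpc_async(owner, _new_local_optimizer,
                          args=(optim_cls, rrefs) + args, kwargs=kwargs)
        )
    return [fut.wait() for fut in futs]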
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528