
RRef Leak and pickler issues with Distributed Optimizer #30022

@pritamdamania87

Description


🐛 Bug

The script shared below runs into RRef leaks and pickler issues when using DistributedOptimizer.

To Reproduce

import torch
from torch.distributed import rpc
import torch.distributed.autograd as dist_autograd
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from tempfile import NamedTemporaryFile
import multiprocessing as mp

def random_tensor():
    return torch.rand((3, 3), requires_grad=True)

def _run_process(self_rank, dst_rank, file_name):
    self_name = "worker{}".format(self_rank)
    dst_name = "worker{}".format(dst_rank)

    rpc.init_model_parallel(
        self_name=self_name,
        self_rank=self_rank,
        worker_name_to_id={"worker0": 0, "worker1": 1},
        init_method="file://{}".format(file_name),
    )

    with dist_autograd.context() as context_id:
        # Forward pass.
        rref1 = rpc.remote(dst_name, random_tensor)
        rref2 = rpc.remote(dst_name, random_tensor)
        loss = rref1.to_here() + rref2.to_here()

        # Backward pass.
        dist_autograd.backward([loss.sum()])

        # Optimizer.
        try:
            dist_optim = DistributedOptimizer(
                optim.SGD,
                [rref1, rref2],
                lr=0.05,
            )
        except Exception as e:
            print(e)
            raise e
        dist_optim.step()

    rpc.join_rpc()

if __name__ == "__main__":
    file_name = NamedTemporaryFile().name
    processes = []
    for i in range(2):
        p = mp.Process(target=_run_process, args=(i, (i + 1) % 2, file_name))
        p.start()
        processes.append(p)  # Track each process so it can be joined below.

    for p in processes:
        p.join()

The code above fails with: https://gist.github.com/pritamdamania87/08cf8de0aa95de03ba1e50da89f96996

It looks like when we create RRefs in the DistributedOptimizer constructor, we never "wait" for these RRefs. Independent of the failure above, we should wait for these RRefs so that any errors surface earlier. However, if I change the DistributedOptimizer code to wait for the RRefs, the same script runs into the following error:

https://gist.github.com/pritamdamania87/e637f5f8a5fede8a8102c584191c28b3
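For context, here is a minimal sketch of the "wait" behavior proposed above, assuming the rpc.rpc_sync / rpc.RRef APIs; the helpers _LocalOptimizer, _new_local_optimizer, and build_remote_optimizers are hypothetical names, not the actual DistributedOptimizer internals. The idea is to construct each remote optimizer synchronously and send back only an RRef, so constructor errors (including pickling failures) surface at construction time:

from collections import defaultdict
import torch.distributed.rpc as rpc

class _LocalOptimizer:
    # Hypothetical wrapper that lives on the worker owning the parameters.
    def __init__(self, optim_cls, param_rrefs, *args, **kwargs):
        self.optim = optim_cls(
            [rref.local_value() for rref in param_rrefs], *args, **kwargs
        )

def _new_local_optimizer(optim_cls, param_rrefs, *args, **kwargs):
    # Hand back an RRef instead of the optimizer object itself, so only a
    # lightweight reference is pickled over the wire.
    return rpc.RRef(_LocalOptimizer(optim_cls, param_rrefs, *args, **kwargs))

def build_remote_optimizers(optim_cls, param_rrefs, *args, **kwargs):
    # Group parameter RRefs by the name of the worker that owns them.
    per_worker = defaultdict(list)
    for rref in param_rrefs:
        per_worker[rref.owner().name].append(rref)

    # rpc_sync blocks until the remote constructor has finished, so any
    # failure is raised here, at construction time, rather than leaking
    # RRefs and surfacing later in step().
    return [
        rpc.rpc_sync(
            owner_name,
            _new_local_optimizer,
            args=(optim_cls, rrefs) + args,
            kwargs=kwargs,
        )
        for owner_name, rrefs in per_worker.items()
    ]

With this shape, the only object crossing the wire back to the caller is an RRef, which would also make the pickler failure easier to localize.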

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528

Labels

module: rpc (Related to RPC, distributed autograd, RRef, and distributed optimizer); triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
