
Conversation

@aazzolini
Contributor

@aazzolini aazzolini commented Oct 30, 2019

Stack from ghstack:

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)
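For orientation, here is a minimal end-to-end usage sketch pieced together from the snippets quoted later in this review. It assumes the RPC agent and the worker names ("worker1", "worker2") are already set up, follows the import path used in the PR's test file, names the helpers per the renaming nit below (the test calls them `remote_meth`/`rpc_async_meth`), and uses the `dist_autograd.backward` signature from later releases; the exact `DistributedOptimizer` arguments and module layout are assumptions, not the PR's verbatim code.

```python
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.optimizer.dist_optimizer as dist_optimizer
import torch.distributed.rpc as rpc


def _call_method(method, obj_rref, *args, **kwargs):
    # Runs on the owner of obj_rref: invoke the method on the local object.
    return method(obj_rref.local_value(), *args, **kwargs)


def remote_method(method, obj_rref, *args, **kwargs):
    # Returns an RRef to the result of calling the method on the remote object.
    return rpc.remote(obj_rref.owner(), _call_method,
                      args=[method, obj_rref] + list(args), kwargs=kwargs)


def rpc_async_method(method, obj_rref, *args, **kwargs):
    # Returns a Future for the result of calling the method on the remote object.
    return rpc.rpc_async(obj_rref.owner(), _call_method,
                         args=[method, obj_rref] + list(args), kwargs=kwargs)


class MyModule:
    def __init__(self):
        self.w = torch.rand((3, 3), requires_grad=True)

    def forward(self, t):
        return torch.mm(self.w, t)

    def get_w(self):
        return self.w


# Two modules owned by two different workers, plus RRefs to their parameters.
remote_module1 = rpc.remote("worker1", MyModule)
remote_module2 = rpc.remote("worker2", MyModule)
remote_param1 = remote_method(MyModule.get_w, remote_module1)
remote_param2 = remote_method(MyModule.get_w, remote_module2)

# One FunctionalSGD is instantiated next to each parameter owner; step() on
# the distributed optimizer fans out to every remote optimizer in parallel.
dist_optim = dist_optimizer.DistributedOptimizer(
    dist_optimizer.FunctionalSGD, [remote_param1, remote_param2])

with dist_autograd.context() as context_id:
    t = torch.rand((3, 3), requires_grad=True)
    output1 = rpc_async_method(MyModule.forward, remote_module1, t)
    output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
    loss = output2.wait().sum()
    dist_autograd.backward(context_id, [loss])
    dist_optim.step(context_id)
```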

[ghstack-poisoned]
aazzolini added a commit that referenced this pull request Oct 30, 2019
Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)

ghstack-source-id: 92927594
Pull Request resolved: #28910
@pritamdamania87 pritamdamania87 self-requested a review October 30, 2019 22:54
@aazzolini
Contributor Author

@pytorchbot retest this please

raise ValueError('Error running optimizer.')


def _call_meth(meth, obj_rref, *args, **kwargs):
Contributor

nit: Can we use method/func instead of meth? :)


def rpc_async_meth(meth, obj_rref, *args, **kwargs):
"""
Call rpc.remote on a method in a remote object.
Contributor

nit: Fix the docstring, it's the same as remote_meth above. I believe it should be rpc_async here.


@dist_init()
def test_dist_optim(self):
if self.rank != 0:
Contributor

Why are we running this only on one node?

Contributor Author

No specific reason, will remove.


import unittest

@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
Contributor

We should have a test_dist_optimizer_fork.py file to ensure we still run ASAN. Or we should run opt-asan for spawn

from collections import defaultdict

class FunctionalOptimizer:
"""Base class for functional optimizers.
Contributor

nit: Start doc on next line:

"""
Base class

raise NotImplementedError


class FunctionalSGD(FunctionalOptimizer):
Contributor

Can we put this class in a separate file?

Comment on lines 88 to 89
args: arguments to pass to the optimizer constructor on each worker.
kwargs: arguments to pass to the optimizer constructor on each worker.
Contributor

Why don't we pass an instance of the optimizer instead of the class and its args? We can create an RRef of this instance on the remote node by just passing it in as a parameter to a remote call and just returning the same object back.
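In code, that suggestion would look roughly like the sketch below (hypothetical: `_identity`, the worker name, and the SGD construction are placeholders, not anything in this PR):

```python
import torch
import torch.distributed.rpc as rpc


def _identity(obj):
    return obj


# Build the optimizer locally...
params = [torch.rand(3, 3, requires_grad=True)]
local_optimizer = torch.optim.SGD(params, lr=0.05)

# ...then ship it to the owning worker; rpc.remote returns an RRef to the
# copy that now lives on "worker1".
optim_rref = rpc.remote("worker1", _identity, args=(local_optimizer,))

# As the reply below notes, the shipped optimizer would still hold copies of
# the tensors created here, not the parameters that actually live on worker1.
```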

Contributor Author

This won't be possible -- optimizer objects (both torch.optim and the FunctionalOptimizer that I introduced above) take local model parameters as input to the constructor. However, we don't have access to those parameters from here. The worker where the parameters live is the only one capable of passing those parameters to the constructor of the optimizer.

Alternatively, we could introduce an OptimizerConfig class, but it wouldn't solve the underlying issue.
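To make the constraint concrete, here is a rough sketch (hypothetical helper names, not the PR's actual internals) of how shipping the class plus constructor args lets each optimizer be built on the worker that owns its parameters:

```python
from collections import defaultdict

import torch.distributed.rpc as rpc


def _new_local_optimizer(optim_cls, param_rrefs, *args, **kwargs):
    # This runs on the worker that owns the parameters; only there can the
    # local tensors be handed to the optimizer constructor.
    local_params = [rref.local_value() for rref in param_rrefs]
    return optim_cls(local_params, *args, **kwargs)


def _create_remote_optimizers(optimizer_class, params_rref, *args, **kwargs):
    # Group the parameter RRefs by owning worker and create one optimizer
    # per worker, returned as a list of RRefs to those remote optimizers.
    per_worker = defaultdict(list)
    for param in params_rref:
        per_worker[param.owner()].append(param)
    return [
        rpc.remote(worker, _new_local_optimizer,
                   args=(optimizer_class, params) + args, kwargs=kwargs)
        for worker, params in per_worker.items()
    ]
```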

self.remote_optimizers.append(remote_optim)


def step(self, autograd_ctx_id):
Contributor

Should we keep this API in line with backward, where we implicitly use the current context id instead of passing it in?

Contributor Author

The issue is that dist_autograd._current_context() is private. Should we expose it publicly? (In a subsequent PR).

Contributor

Can we keep it private and still call it from here? There is no reason to expose this method to end users at the moment.

try:
fut.wait()
except Exception as e:
exception = e
Contributor

What is the reason that we continue waiting for other futures if we see an exception? Shouldn't we just exit as soon as possible and not wait for other futures?

Contributor Author

I'm just not sure what would happen if one of these operations is still executing remotely and we exit the current autograd context.



def _wait_for_all(rpc_futs):
# TODO: improve error propagation
Contributor

What sort of improvements do we have in mind here? Also, can we create a github issue for this so we can keep track of this for the future?

Contributor Author

Ideally we want to gather all exceptions in a list. I'll open an issue for it.
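A sketch of what that improvement could look like (hypothetical, not the PR's code): wait on every future, collect every exception, and surface them together instead of only the last one.

```python
def _wait_for_all(rpc_futs):
    exceptions = []
    for fut in rpc_futs:
        try:
            fut.wait()
        except Exception as e:
            exceptions.append(e)
    if exceptions:
        msg = "\n".join(str(e) for e in exceptions)
        raise ValueError("Error running optimizer:\n{}".format(msg))
```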

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)

[ghstack-poisoned]
@kostmo
Member

kostmo commented Oct 31, 2019

CircleCI build failures summary

As of commit a8acb05:

  • 0/1 flaky
  • 1/1 failures introduced in this PR

This comment was automatically generated by Dr. CI.

from dist_utils import INIT_METHOD_TEMPLATE, dist_init
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.optimizer.dist_optimizer as dist_optimizer
Contributor

The corresponding package and class is torch.optim.Optimizer; do we want to keep the same convention here, i.e., torch.distributed.optim.Optimizer?



@unittest.skipIf(
not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
Contributor

Two unnecessary double quotes here, or maybe you intended to break the string into shorter lines.
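For reference, dropping the implicit string concatenation quoted above would look roughly like this (decorator only; the decorated test class is omitted):

```python
@unittest.skipIf(
    not torch._six.PY3,
    "Pytorch distributed autograd package does not support python2",
)
```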

remote_param1 = remote_meth(MyModule.get_w, remote_module1)
remote_param2 = remote_meth(MyModule.get_w, remote_module2)

dst_optim = dist_optimizer.DistributedOptimizer(
Contributor

dst_optim -> dist_optim?

t2 = torch.rand((3, 3), requires_grad=True)
output1 = rpc_async_meth(MyModule.forward, remote_module1, t2)
output2 = rpc_async_meth(
MyModule.forward, remote_module2, output1.wait())
Contributor

In follow-up PRs, we also want to test passing an RRef of output1 to remote_module2, right?

Contributor Author

Correct, I was waiting for the .to_here() autograd propagation PR to land. If it has already landed, I can add the test to this PR.

Contributor

It's landed, but we can still add new tests in a PR on top of this. There is some flakiness that we are investigating, which should not block us from landing this PR in its current form.

def __init__(self, params):
self.params = params

def step(self, gradients):
Contributor

Any reason for directly taking the gradients here instead of a distributed autograd context id, and using that id to retrieve the gradients?

Contributor Author

Notice that this class is a local optimizer; it doesn't know anything about the distributed setup.
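For context, a minimal sketch of such a local, functional optimizer (the learning-rate argument is an illustrative assumption; it is not shown in the excerpts above):

```python
import torch


class FunctionalOptimizer:
    """Base class: parameters are fixed at construction time and gradients
    are passed explicitly to step()."""

    def __init__(self, params):
        self.params = params

    def step(self, gradients):
        raise NotImplementedError


class FunctionalSGD(FunctionalOptimizer):
    """Simplistic stochastic gradient descent."""

    def __init__(self, params, lr=0.01):  # lr is an assumed extra argument
        super().__init__(params)
        self.lr = lr

    def step(self, gradients):
        # gradients must be aligned with self.params (same order and length).
        with torch.no_grad():
            for param, grad in zip(self.params, gradients):
                if grad is not None:
                    param.add_(grad, alpha=-self.lr)
```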


from collections import defaultdict

class FunctionalOptimizer:
Contributor

Would I be correct if I assume there is no easy way to reuse existing optimizers without modifying them, since they read directly from param.grad?

Contributor Author

I am thinking of writing an adapter that allows using an existing optim.Optimizer. In order to avoid race conditions I'll have to create multiple instances of optim.Optimizer that take the same parameters but don't share the gradients. I can open an issue explaining the idea in more detail.
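A rough sketch of that adapter idea (hypothetical class name; it omits the gradient-unsharing and locking concerns discussed below): copy the explicitly passed gradients into param.grad, then delegate to the wrapped optim.Optimizer.

```python
import torch


class LocalOptimizerAdapter:
    """Expose a stock torch.optim optimizer behind the step(gradients) API."""

    def __init__(self, optim_cls, params, *args, **kwargs):
        self.params = list(params)
        self.optim = optim_cls(self.params, *args, **kwargs)

    def step(self, gradients):
        # Write the gradients into .grad so the wrapped optimizer can read
        # them exactly as it would in the non-distributed case.
        for param, grad in zip(self.params, gradients):
            param.grad = grad
        self.optim.step()


# Usage, e.g.: LocalOptimizerAdapter(torch.optim.SGD, params, lr=0.05)
```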

specific parameters.

Args:
optimizer_class (FunctionalOptimizer): the class of optimizer to
Contributor

Can the constructor take a torch.optim class as input instead?

I also would prefer that way.

> …hogwild case there could be a race in the proposal above (I'm not sure if that matters though).

This probably won't work in hogwild without a lock, as grad might be overwritten by other threads. But since hogwild is a more advanced use case, it seems reasonable to ask hogwild applications to implement their own local optimizers that do not read directly from param.grad. We could have a mode flag here to toggle whether we want to use a lock.

Args:
optimizer_class (FunctionalOptimizer): the class of optimizer to
instantiate on each worker.
params_rref (list[RRef]): list of RRefs to local or remote parameters
Contributor

Would I be correct if I assume that the RemoteModule we discussed offline should have a parameters() API that returns a list of RRefs?

Contributor Author

Correct. I didn't want to introduce a full-fledged RemoteModule in this diff; also, we'll need a prep PR before we can return a list of RRefs.
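For illustration only, the discussed API might look roughly like this; RemoteModule did not exist at the time, so everything here is a hypothetical sketch:

```python
import torch.distributed.rpc as rpc
import torch.nn as nn


class RemoteModule:
    """Hypothetical wrapper around an nn.Module that lives on one worker."""

    def __init__(self, module: nn.Module):
        self._module = module

    def parameters(self):
        # One RRef per parameter, owned by the worker where the module lives;
        # the list could be fed directly into DistributedOptimizer.
        return [rpc.RRef(p) for p in self._module.parameters()]
```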

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)

[ghstack-poisoned]
@ezyang
Contributor

ezyang commented Oct 31, 2019

Hi, I was tagged as a reviewer here. It looks like there are already reviews. Is there anything specific you want me to look at?

@pritamdamania87
Contributor

@ezyang I tagged you as a reviewer since this is somewhat related to distributed autograd, although it looks like @vincentqb and @soumith are the owners for torch.optim based on this: https://pytorch.org/docs/stable/community/persons_of_interest.html#torch-optim

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)

[ghstack-poisoned]
@ezyang
Contributor

ezyang commented Nov 1, 2019

Yes, why don't you tag @vincentqb for review here, for the optimizer notes.

@mrshenli mrshenli requested a review from vincentqb November 1, 2019 14:50
@mrshenli
Contributor

mrshenli commented Nov 1, 2019

Hey @vincentqb, could you please help to take a look at the optimizer API in this PR?

@vincentqb
Contributor

> Can the constructor take a torch.optim class as input instead?

What was the conclusion of this discussion?

As mentioned above, it would be really nice to have a mechanism to re-use the current optimizers. Wrappers to offer {,non-}locking could be nice, but a little heavy to use. Offering a default through a toggle may be a solution if we don't expect more toggles to appear with time.

Also, to keep the syntax as close as possible, it'd be nice to have a way to simply replace opt = torch.optim.SGD(...) with opt = torch.distributed.optim.FunctionalSGD(...) to get a FunctionalOptimizer version of an Optimizer. I highlighted some differences in the parameters passed.

Thoughts?



class FunctionalSGD(FunctionalOptimizer):
"""Simplistic implementation of Stocastic Gradient Descent optimizer.
Contributor

nit: Stochastic

def __init__(self, params):
self.params = params

def step(self, gradients):
Contributor

The current Optimizer takes and returns the loss function. What would be an equivalent here?

Contributor Author

Unfortunately there's no straightforward way of porting the closure() functionality to the distributed setting, as it would potentially require algorithm-specific synchronization across workers. This would be true for L-BFGS, for example, where closure() and parameter updates are interleaved in a loop inside step().

matters as the list of gradients passed to the step
function must be aligned with this list.
"""
def __init__(self, params):
Contributor

Would having "defaults" make sense here to mimic Optimizer?

self.remote_optimizers.append(remote_optim)


def step(self, autograd_ctx_id):
Contributor

Same comment as FunctionalOptimizer.step


from collections import defaultdict

class FunctionalOptimizer:
Contributor

Is this precise enough as a name? Or could we name it differently, say LocalOptimizer or NonLockingOptimizer?

@aazzolini
Contributor Author

Proposal:

  • I'll rename FunctionalOptimizer to LocalOptimizer
  • I'll keep LocalOptimizer as simple as it is right now, in particular:
    • I won't add "defaults" as a constructor argument because 1) it would require code to be added to the base class, but I'd prefer to keep it interface-only; 2) it would be unclear how to wrap an existing Optimizer as a LocalOptimizer, since we'd have two instances of the "defaults" field. Keeping LocalOptimizer interface-only makes it cleaner to implement as a wrapper;
  • I won't add the "closure" argument to step() because there's no sensible default way of implementing this functionality for distributed optimizers.
  • I'll implement a thin "LockingOptimizer" wrapper that wraps an optim.Optimizer as a LocalOptimizer. If an optim.Optimizer is passed to DistributedOptimizer, it will be wrapped automatically.
  • I won't re-implement any specific Optimizer on top of LocalOptimizer (except in test cases).

Let me know if that works.

@aazzolini
Contributor Author

@vincentqb could you comment on this proposal, since we'd like to have this PR landed soon for the 1.4 release?

@aazzolini aazzolini closed this Nov 5, 2019
@aazzolini
Contributor Author

Synced up offline with @mrshenli -- instead I'll go the simpler route of integrating with "optim.Optimizer" directly to avoid introducing new APIs.

@vincentqb
Contributor

@aazzolini -- did you mean to close this PR?

> @vincentqb could you comment on this proposal, since we'd like to have this PR landed soon for the 1.4 release?

The proposal, with the amendment, sounds good to me. Thanks for working on this!

aazzolini added a commit that referenced this pull request Nov 6, 2019
Pull Request resolved: #28910

Implements a simple Python distributed optimizer that takes RRefs to the parameters to be optimized.
It keeps optimizer instances on remote workers, and calling step on the distributed optimizer calls step on each of the remote optimizers in parallel.
ghstack-source-id: 93377263

Differential Revision: [D18230877](https://our.internmc.facebook.com/intern/diff/D18230877/)
@aazzolini
Contributor Author

This is weird, I must have closed this by mistake. It won't let me re-open it, saying the commits have already been merged, but that doesn't seem to be the case.
I'll open a new PR.

@aazzolini
Contributor Author

@vincentqb @pritamdamania87 @mrshenli I opened #29304 to continue the discussion.
