Simple distributed optimizer #29304
Conversation
Implements a simple Python distributed optimizer that takes RRefs to the parameters that will be optimized. It keeps instances of the optimizer on the remote workers, and calling step() on the distributed optimizer calls step() on each of the remote optimizers in parallel. Differential Revision: [D18354586](https://our.internmc.facebook.com/intern/diff/D18354586/)
This PR replaces #28910, which was closed by mistake and couldn't be reopened. For this version, I got rid of FunctionalOptimizer and made torch.optim.Optimizer work directly with DistributedOptimizer, simplifying the implementation and reducing the new API footprint.
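For readers landing on this thread, here is a rough sketch of how a trainer might drive the optimizer described above. This is only an illustration, not this PR's API: `param_rrefs` and `loss_fn` are placeholders, the `dist_autograd` signatures follow later torch.distributed releases, and the exact signature of `DistributedOptimizer.step()` is not shown in the diff snippets in this thread.

```python
import torch.distributed.autograd as dist_autograd
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer


def make_dist_optimizer(param_rrefs):
    # Matches the constructor shown in the diff below:
    # DistributedOptimizer(optimizer_class, params_rref, *args, **kwargs).
    # Extra args/kwargs (here lr=0.05) are forwarded to every per-worker
    # optim.SGD instance.
    return DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)


def train_step(dist_optim, loss_fn):
    # Run forward/backward inside a distributed autograd context, then let the
    # distributed optimizer fan out step() to each remote _LocalOptimizer in
    # parallel, as described in the PR summary.
    with dist_autograd.context() as context_id:
        loss = loss_fn()
        dist_autograd.backward(context_id, [loss])
        dist_optim.step(context_id)
```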
torch/distributed/optim/optimizer.py
        kwargs: arguments to pass to the optimizer constructor on each worker.
    """
    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        per_worker_params_rref = defaultdict(lambda: [])
nit: defaultdict(list)
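For readers unfamiliar with the nit, a quick illustration of the equivalence (the worker name and the string standing in for an RRef are just placeholders):

```python
from collections import defaultdict

# defaultdict(list) is the idiomatic spelling of defaultdict(lambda: []):
# `list` itself is a zero-argument factory that produces a new empty list.
per_worker_params_rref = defaultdict(list)
per_worker_params_rref["worker1"].append("some_param_rref")
assert per_worker_params_rref["worker2"] == []  # missing keys default to []
```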
pietern left a comment
Nice! This is very simple indeed :)
test/dist_optimizer_test.py
if __name__ == '__main__':
    unittest.main()
If I'm not mistaken, this test cannot run by itself, so this should go.
            [rref.local_value().wait() for rref in local_params_rref],
            *args,
            **kwargs)
        self.lock = Lock()
This is the lock mentioned here, right? How would we modify this to get Hogwild?
Yes. There's no simple way of modifying the current implementation to get Hogwild. We'd have to either 1) change the interface (FunctionalOptimizer), or 2) use some alternative such as keeping an object pool or thread-local instances of optim.Optimizer to avoid gradient sharing across threads. I can write a comment on why the lock is there.
Cool, just wanted to verify. A quick comment would be a good idea indeed. Thanks!
Wouldn't removing this lock give us some form of Hogwild? I'm not sure how accurate this example is: https://pytorch.org/docs/stable/notes/multiprocessing.html#hogwild, but it does something similar: accumulating gradients on the tensors and running the optimizer can interleave across different processes.
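A note on why the lock is needed in the current design: `param.grad` is a single shared slot per parameter, so without the lock two RPC threads can interleave their `param.grad = grad` assignments and `self.optim.step()` calls, losing one gradient and applying the other twice. The interface-change option (1) in the reply above would sidestep this because gradients would never go through the shared `.grad` slot. A minimal sketch of what that could look like, assuming nothing beyond plain tensors (`FunctionalSGD` is a made-up name, not part of this PR):

```python
import torch


class FunctionalSGD:
    """Hypothetical sketch of the FunctionalOptimizer-style interface
    discussed above (not part of this PR).  Gradients are passed to step()
    explicitly instead of being read from the shared param.grad slot, so
    there is no per-parameter state to protect with a lock and concurrent,
    Hogwild-style updates become possible."""

    def __init__(self, params, lr=0.01):
        self.params = list(params)
        self.lr = lr

    def step(self, grads):
        # `grads` is a list of tensors aligned with self.params and owned by
        # the caller; updates may interleave across threads, which is exactly
        # what Hogwild tolerates.
        with torch.no_grad():
            for param, grad in zip(self.params, grads):
                param.add_(grad, alpha=-self.lr)
```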
What is the plan for adding a tutorial / example?
We need thorough docstrings for this one; then we can have a single tutorial for full training, I think.
        with self.lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()
this is nice!
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.
Shall we add an example here?
I think that a convincing example will need a way to 1) create the remote module; 2) call it; and 3) get a list of RRef params for it. We can add this after we have settled on a way of doing these.
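To make that concrete, here is one hypothetical shape those three pieces could take, using the `_remote_method` helper pattern from the RPC tutorials. `_call_method`, `_remote_method`, and `_param_rrefs` are made-up helpers, `rpc.init_rpc` is assumed to have been called on every worker, and the RRef calls used here (`local_value()` returning the value directly, the `rpc.RRef` constructor) reflect later releases rather than this PR.

```python
import torch
import torch.distributed.rpc as rpc

# 1) Create the remote module: hold an nn.Module on another worker via an RRef.
net_rref = rpc.remote("worker1", torch.nn.Linear, args=(4, 1))

# Generic helpers: run a method on the object behind an RRef, on its owner.
def _call_method(method, rref, *args, **kwargs):
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    return rpc.rpc_sync(rref.owner(), _call_method,
                        args=(method, rref) + args, kwargs=kwargs)

# 2) Call it: run forward() on the module's owner.
out = _remote_method(torch.nn.Linear.forward, net_rref, torch.randn(2, 4))

# 3) Get a list of RRefs to its parameters, suitable for DistributedOptimizer.
def _param_rrefs(module):
    return [rpc.RRef(p) for p in module.parameters()]

param_rrefs = _remote_method(_param_rrefs, net_rref)
```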
torch/distributed/optim/optimizer.py
    specific parameters.

    Args:
        optimizer_class (FunctionalOptimizer): the class of optimizer to
There is no FunctionalOptimizer any more, do you mean optim.Optimizer?
mrshenli left a comment
This LGTM! My comments are mostly on docs.
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.
    """
Do we want to add a warning that we only make sure concurrent step() calls will not modify the same param at the same time, but they could still modify params on different owners in an interleaved way? This means that when a dist optimizer tries to apply a grad x to a param, the param might already be different from when grad x was computed. This behavior is by design. If the application needs a globally exclusive dist optimizer step(), it will have to synchronize that on its own.
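As a strawman for the "synchronize it on their own" path, a minimal sketch that funnels every distributed step through one coordinator process. All names here (the "coordinator" worker, `exclusive_step`, the module-level `dist_optim`) are illustrative and not part of this PR.

```python
import threading

import torch.distributed.rpc as rpc

_step_lock = threading.Lock()
dist_optim = None  # a single DistributedOptimizer, set up once on the coordinator


def exclusive_step(*step_args):
    # Runs on the coordinator: serializes every distributed step, so two
    # steps can never interleave across parameter owners.
    with _step_lock:
        dist_optim.step(*step_args)

# On each trainer, instead of calling dist_optim.step(...) directly:
# rpc.rpc_sync("coordinator", exclusive_step, args=(context_id,))
```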
mrshenli left a comment
Internal test failure is real
mrshenli left a comment
Tests all passed, let's land this!
This pull request has been merged in b0cf43b.