[WIP] Add end-to-end test for RNN Module #29543
Conversation
The test puts the encoder and decoder on a remote worker and keeps the LSTM module local. The forward pass first looks up the embedding remotely, then fetches the embedding result and runs it through the local LSTM module, and finally sends the output to the remote decoder. The backward pass should automatically traverse all involved parties. The optimizer takes a list of parameter RRefs and reaches each owner to update the parameters. [ghstack-poisoned]
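To make the described flow concrete, here is a minimal sketch of what such a test could look like. The helper functions, module choices, worker name `"ps"`, and hyper-parameters are illustrative assumptions, not the code in this PR; the `dist_autograd.backward([...])` and `opt.step()` calls follow the API as it appears in the diffs further down.

```python
# Sketch only: assumes rpc.init_rpc(...) has already been called on every
# worker and that a worker named "ps" exists. All names are made up for
# illustration; see the actual test diff in this PR for the real code.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer


def _call_method(method, rref, *args, **kwargs):
    # Runs `method` on the object owned by `rref`; executes on the owner.
    return method(rref.local_value(), *args, **kwargs)


def _remote_method(method, rref, *args, **kwargs):
    # Synchronously invokes `method` on the owner of `rref`.
    return rpc.rpc_sync(rref.owner(), _call_method,
                        args=[method, rref] + list(args), kwargs=kwargs)


def _param_rrefs(module_rref):
    # Executes on the owner: wraps each parameter in an RRef for the optimizer.
    return [rpc.RRef(p) for p in module_rref.local_value().parameters()]


class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers):
        super().__init__()
        # encoder (embedding) and decoder live on the remote worker `ps`
        self.emb_rref = rpc.remote(ps, nn.Embedding, args=(ntoken, ninp))
        self.dec_rref = rpc.remote(ps, nn.Linear, args=(nhid, ntoken))
        # the LSTM stays on the local trainer
        self.rnn = nn.LSTM(ninp, nhid, nlayers)

    def forward(self, inp, hidden):
        emb = _remote_method(nn.Embedding.forward, self.emb_rref, inp)
        out, hidden = self.rnn(emb, hidden)
        return _remote_method(nn.Linear.forward, self.dec_rref, out), hidden

    def parameter_rrefs(self):
        rrefs = rpc.rpc_sync(self.emb_rref.owner(), _param_rrefs, args=(self.emb_rref,))
        rrefs += rpc.rpc_sync(self.dec_rref.owner(), _param_rrefs, args=(self.dec_rref,))
        rrefs += [rpc.RRef(p) for p in self.rnn.parameters()]
        return rrefs


# Trainer-side loop: the distributed autograd context ties the forward pass,
# the cross-worker backward pass, and the optimizer step together.
ntoken, ninp, nhid, nlayers, seq_len, batch = 10, 2, 3, 4, 6, 5
rnn = RNNModel("ps", ntoken, ninp, nhid, nlayers)
opt = DistributedOptimizer(optim.SGD, rnn.parameter_rrefs(), lr=0.05)
hidden = (torch.zeros(nlayers, batch, nhid), torch.zeros(nlayers, batch, nhid))
with dist_autograd.context() as ctx_id:
    inp = torch.randint(0, ntoken, (seq_len, batch))
    output, hidden = rnn(inp, hidden)
    dist_autograd.backward([output.sum()])  # API as used in this PR's diff
    opt.step()                              # reaches each parameter's owner
```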
    agent,
    dst,
    std::move(*pythonCall).toMessage(),
    true /*forceGradRecording*/);
Adding forceGradRecording because a UDF's arguments might not carry any requires_grad tensors, while its return value does require grad.
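A minimal local illustration of this situation (a hypothetical example, not the RPC code path itself): the call's inputs carry no tensor that requires grad, yet the value the function returns does, so grad recording cannot be inferred from the arguments alone.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)       # the module's weight requires grad
idx = torch.tensor([1, 2, 4])   # integer indices: requires_grad is False
out = emb(idx)                  # what a remote UDF body might compute

print(idx.requires_grad, out.requires_grad)  # False True
```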
IIUC, this fixes #28819. If so, can we have a simple unit test for the problem this fixes?
Yes, actually let me split this PR into two, with the first one focusing on fixing #28819 with a unit test.
test/dist_model_parallel_test.py
Outdated
rnn = RNNModel(ps, ntoken, ninp, nhid, nlayers)
# Depends on #29304 and #28948
# opt = DistributedOptimizer(
#     optim.SGD,
#     rnn.remote_parameters(),
#     lr=0.05,
# )
with dist_autograd.context() as ctx_id:
    inp = torch.LongTensor(batch, nindices) % ntoken
    output, hidden = rnn(inp, hidden)
    dist_autograd.backward([output.sum()])
    # opt.step()
test/dist_model_parallel_test.py
Outdated
    args=[RemoteModule.remote_parameters, remote_module_rref]
)


class Encoder(RemoteModule):
nit, should this be called something like EmbeddingLayer or RemoteEmbeddingLayer? Usually in seq2seq models encoder refers to the RNN that creates the encoded representation, which in this case would be the LSTM.
# from torch.distributed.optim import DistributedOptimizer
# from torch.distributed.rpc import RRef


from dist_utils import INIT_METHOD_TEMPLATE, dist_init, TEST_CONFIG
Looks like we have a bunch of lint failures.
@@ -0,0 +1,138 @@
import torch
Wasn't the plan to have this in the examples repository? Why is this a unit test instead?
We need RPC module tests as well. I would first like to check with @soumith on whether this is a reasonable example and whether there are any APIs we need to revise. We are also waiting for the dist optimizer to land (it's landing now).
@@ -0,0 +1,18 @@
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
Why do we have two separate files if we only have a spawn mode?
Let me merge this into one file.
  TORCH_INTERNAL_ASSERT(
      torch::autograd::compute_requires_grad(tensors),
      "Received tensors do not require grad, addRecvRpcBackward should not be called");
  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
Were the changes here and in python_functions.cpp bugs that were discovered by the unit test in this PR?
Yes, I hit errors when running the unit test in this PR, but that's expected since we already have #28819 to track this.
aazzolini left a comment
I believe if we do use RemoteModule, it should be a wrapper that lives on the "client" side, not on the "server" side. The way the example is presented here, I'd rather just not introduce the concept of RemoteModule and just call into nn.Module directly.
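For illustration, a rough sketch of what such a client-side wrapper could look like (class and helper names here are hypothetical, not an API proposed in this PR): the wrapped module stays an ordinary nn.Module on its owner, and only the client-side wrapper knows it is remote.

```python
import torch.nn as nn
import torch.distributed.rpc as rpc


def _remote_forward(module_rref, *args, **kwargs):
    # Executes on the owner of module_rref.
    return module_rref.local_value()(*args, **kwargs)


def _remote_param_rrefs(module_rref):
    # Executes on the owner; returns one RRef per parameter.
    return [rpc.RRef(p) for p in module_rref.local_value().parameters()]


class ClientSideRemoteModule(nn.Module):
    """Client-side wrapper around a plain nn.Module that lives on `on_worker`."""

    def __init__(self, on_worker, module_cls, *args, **kwargs):
        super().__init__()
        self.module_rref = rpc.remote(on_worker, module_cls, args=args, kwargs=kwargs)

    def forward(self, *args, **kwargs):
        return rpc.rpc_sync(self.module_rref.owner(), _remote_forward,
                            args=[self.module_rref] + list(args), kwargs=kwargs)

    def remote_parameters(self):
        return rpc.rpc_sync(self.module_rref.owner(), _remote_param_rrefs,
                            args=(self.module_rref,))


# e.g. the embedding layer could then be created as
#   encoder = ClientSideRemoteModule("ps", nn.Embedding, ntoken, ninp)
# without the server-side module inheriting from anything special.
```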
test/dist_model_parallel_test.py
Outdated
    kwargs=kwargs
)


class RemoteModule(nn.Module):
What is the reason for RemoteModule in here? I believe RemoteModule only makes sense as a wrapper module that wraps regular modules. The way it's exposed here you'll still have to implement a special module that inherits from RemoteModule, so there's not much utility for this class.
If the only functionality is exposing remote_parameters(), then simply make it a free function, e.g.:
def get_module_param_rrefs(module: nn.Module):
    return [rpc.RRef(param) for param in module.parameters()]
If the goal is to discover remote modules, you could have an RRef wrapper on the client side instead (not on the remote side).
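A hedged sketch of how the client could then collect parameter RRefs with such a free function, mirroring the rpc_sync call in the test diff below; it assumes the `get_module_param_rrefs` function above and the `remote_module_rref` from the diff, and the owner-side helper is hypothetical.

```python
import torch.optim as optim
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer


def _get_param_rrefs_on_owner(module_rref):
    # Runs on the module's owner and unwraps the RRef locally.
    return get_module_param_rrefs(module_rref.local_value())


param_rrefs = rpc.rpc_sync(remote_module_rref.owner(), _get_param_rrefs_on_owner,
                           args=(remote_module_rref,))
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)
```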
test/dist_model_parallel_test.py
Outdated
    args=[RemoteModule.remote_parameters, remote_module_rref]
)


class Encoder(RemoteModule):
RemoteModule here is a misnomer, this is a perfectly well defined local module actually. It just happens to be used as a remote module somewhere. Nothing in the implementation of this module implies remote.
test/dist_model_parallel_test.py
Outdated
        return self.drop(self.encoder(input))


class RNN(RemoteModule):
Same here
import unittest


@unittest.skipIf(TEST_WITH_ASAN, "Skip ASAN as torch + multiprocessing spawn have known issues")
class DistModelParallelSpawn(MultiProcessTestCase, DistModelParallelTest):
curious, is there a reason we don't have a fork mode test for this?
Are we still planning to merge this eventually? It would be super useful to have an end-to-end test to sanity check changes we make in the RPC layer.
Stack from ghstack:
The test puts the encoder and decoder on a remote worker and keeps the LSTM module local. The forward pass first looks up the embedding remotely, then fetches the embedding result and runs it through the local LSTM module, and finally sends the output to the remote decoder. The backward pass should automatically traverse all involved parties. The optimizer takes a list of parameter RRefs and reaches each owner to update the parameters.
Differential Revision: D18482428