
Conversation

@xush6528
Contributor

@xush6528 xush6528 commented Dec 4, 2019

Stack:
    :black_circle:  #30710 Implement backend-agnostic rpc._wait_all_workers() utility  💚

We need a backend-agnostic mechanism to perform a barrier-like operation before we locally destroy the RRef context and shut down the RPC agent. The approach (a rough sketch follows the list below):

  • Sort worker names.
  • Elect the first name in the sorted order as the leader.
  • Followers report their intent to synchronize to the leader.
  • The leader also reports to itself when _wait_all_workers() is called.
  • Once all workers report their intent to proceed, the leader sends the command to everyone to proceed.
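For illustration, a minimal Python sketch of that barrier is below. The helper names (`_report_intent`, `_set_proceed_signal`, `_ALL_WORKER_NAMES`) and the threading details are assumptions for this sketch, not necessarily the identifiers used in the actual patch.

```
import threading

from torch.distributed import rpc

_ALL_WORKER_NAMES = set()            # populated during RPC initialization
_intent_lock = threading.Condition()
_intent_count = 0
_proceed_signal = threading.Event()


def _report_intent():
    # Runs on the leader: one call per worker that wants to synchronize.
    global _intent_count
    with _intent_lock:
        _intent_count += 1
        _intent_lock.notify_all()


def _set_proceed_signal():
    # Runs on each follower when the leader broadcasts "proceed".
    _proceed_signal.set()


def _wait_all_workers(self_name):
    leader_name = sorted(_ALL_WORKER_NAMES)[0]   # first name in sorted order
    is_leader = self_name == leader_name

    if is_leader:
        _report_intent()                          # the leader reports to itself
    else:
        rpc.rpc_sync(leader_name, _report_intent, args=())

    if is_leader:
        # Wait until every worker has reported, then tell everyone to proceed.
        with _intent_lock:
            _intent_lock.wait_for(
                lambda: _intent_count == len(_ALL_WORKER_NAMES))
        for follower in _ALL_WORKER_NAMES - {leader_name}:
            rpc.rpc_async(follower, _set_proceed_signal, args=())
        _proceed_signal.set()

    _proceed_signal.wait()
```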

Differential Revision: D18643137
Differential Version: 94902955
@mrshenli
Contributor

mrshenli commented Dec 4, 2019

Comments were added in #30693. Are we abandoning that one?

Differential Revision: D18643137
Differential Version: 94957803
Differential Revision: D18643137
Differential Version: 94960164
if graceful:
    _wait_all_workers()
    _destroy_rref_context(_ignore_rref_leak)
    _agent.shutdown()
Contributor

So, I guess, with the current implementation, we won't wait for the very last RPC sent by the leader worker to be processed, since _agent.shutdown() does not guarantee that it will wait for outstanding RPCs?

Contributor Author

@rohan-varma For now, we are using a sleep.

Differential Revision: D18643137
Differential Version: 94961619
Differential Revision: D18643137
Differential Version: 94961876
if is_leader_worker:
    # The leader sends out proceed signals to all followers.
    for follower_worker_name in _ALL_WORKER_NAMES - {leader_worker_name}:
        rpc_async(follower_worker_name, _set_proceed_signal, args=())
Contributor

After this line, the leader will do a local shutdown, which may lead to _set_proceed_signal not going out to the nodes, since the local agent may already be killed. What can we do here to avoid this, other than making the leader wait for an arbitrary time?
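One mitigation (just a sketch of the direction discussed here, reusing the names from the snippet above) is to keep the futures returned by rpc_async and block on them before the leader shuts down its agent, so the proceed signals are at least acknowledged or time out:

```
if is_leader_worker:
    proceed_futures = {}
    for follower_worker_name in _ALL_WORKER_NAMES - {leader_worker_name}:
        proceed_futures[follower_worker_name] = rpc_async(
            follower_worker_name, _set_proceed_signal, args=())

    # Block until each follower acknowledges the proceed signal (or the RPC
    # times out), so the message is on the wire before the local shutdown.
    for follower_worker_name, fut in proceed_futures.items():
        try:
            fut.wait()
        except RuntimeError:
            # The follower may already be shutting down; nothing more to do.
            pass
```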

Differential Revision: D18643137
Differential Version: 95031963
Differential Revision: D18643137
Differential Version: 95067616
Differential Revision: D18643137
Differential Version: 95109553
Differential Revision: D18643137
Differential Version: 95201140
Differential Revision: D18643137
Differential Version: 95231363
Contributor

@pritamdamania87 pritamdamania87 left a comment

Thanks for looking into this!

rpc.shutdown()

@dist_init(clean_shutdown=False)
def test_wait_all_workers(self):
Contributor

Can we add a unit test where we don't set up RPC as part of dist_init and instead do it manually in the unit test (a rough sketch follows after the list):

  1. Call init_rpc.
  2. Do some work.
  3. Call rpc.shutdown()

Also, I'd recommend running the tests we add here about 100 times locally to ensure we don't have any flakiness.
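A rough sketch of such a test, assuming the harness already sets up the usual MASTER_ADDR/MASTER_PORT environment (the test name and the exact init_rpc arguments are assumptions here):

```
def test_manual_init_and_shutdown(self):
    # 1. Set up RPC manually instead of relying on the dist_init decorator.
    rpc.init_rpc(
        name="worker%d" % self.rank,
        rank=self.rank,
        world_size=self.world_size,
    )

    # 2. Do some work: a trivial remote call to the next worker.
    dst = "worker%d" % ((self.rank + 1) % self.world_size)
    ret = rpc.rpc_sync(dst, torch.add, args=(torch.ones(2), torch.ones(2)))
    self.assertEqual(ret, torch.ones(2) * 2)

    # 3. Graceful shutdown, which runs the _wait_all_workers() barrier.
    rpc.shutdown()
```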

Contributor Author

@pritamdamania87

If I call init_rpc() repeatedly, I get this error:

RuntimeError: Container is already initialized! Cannot initialize it twice! (init at caffe2/torch/csrc/distributed/autograd/context/container.cpp:38)

The dist autograd container is not destroyed after shutting down RPC.

Can we just run once for now?

Contributor

@xush6528 Oh I didn't mean to run the test in a loop here. I meant just run it 100 times locally in your console to ensure there is no flakiness.

Contributor Author

@pritamdamania87

Tested locally. It's passing and pretty stable.


if graceful:
    _wait_all_workers()
    _agent.join()
Contributor

Do we need to do this? This isn't implemented for all agents, is it?

Contributor Author

There is an RRef leak issue we need to fix before we can remove this line.

Two steps:

  • Make all nested RPC calls wait on the futures they create so that they are chained.
  • Add an API to RRefContext for proactively cleaning up local forks. (I know how to do this, I can help here. Will coordinate with @mrshenli.)

Contributor

Could you elaborate on what the RRef leak issue is and why we don't see it in other agents, where join() is pretty much a no-op?

Contributor Author

@xush6528 xush6528 Dec 11, 2019

@pritamdamania
The problem is that we currently depend on the Python GC, whose timing we don't control, so delete-fork messages could be sent out after RPC has been shut down.

It looks fine on the master branch now only because, in the other agent, we sleep for 2 seconds, which is enough time for 1) the Python interpreter to trigger GC and 2) the delete-fork messages to be flushed onto the wire.
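To make that dependence concrete, here is a hedged illustration (not the planned fix): the delete-fork message for a user RRef is only sent once its Python object is actually reclaimed, which may not happen before shutdown if the object sits in a reference cycle.

```
import gc

import torch
from torch.distributed import rpc

rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
rref.to_here()

# Dropping the last reference lets the RRef fork be cleaned up and a
# delete-fork message be sent; if the object is caught in a cycle, that
# only happens when the GC runs, whose timing we do not control.
del rref

# Forcing a collection before shutdown is effectively what the 2-second
# sleep in the other agent was papering over.
gc.collect()
rpc.shutdown()
```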

Contributor

Hmm, can we create a github issue describing the problem and the plan to fix it? We shouldn't have this 2 second sleep long term.


Differential Revision: D18643137
Differential Version: 95378112
Differential Revision: D18643137
Differential Version: 95395160
Differential Revision: D18643137
Differential Version: 95403917
if is_leader_worker:
    # The leader sends out proceed signals to all followers.
    timeout = timedelta(seconds=5)
    _set_rpc_timeout(timeout)
Contributor

Not sure why this is needed? The default timeout should be 60s, correct?

Contributor Author

Nearly half the time, at least one follower fails to respond, so the leader always hits a timeout.

I think this is to keep shutdown from taking 60s every time.
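In other words, the leader temporarily caps the per-RPC wait at ~5s so a dead follower does not make shutdown hang for the full default timeout. A sketch, assuming a hypothetical _get_rpc_timeout() counterpart to the _set_rpc_timeout() helper shown in the diff:

```
from datetime import timedelta

if is_leader_worker:
    original_timeout = _get_rpc_timeout()        # hypothetical getter
    _set_rpc_timeout(timedelta(seconds=5))       # cap per-RPC wait during shutdown
    try:
        # ... send proceed signals and wait on their futures ...
        pass
    finally:
        _set_rpc_timeout(original_timeout)       # restore the default (60s)
```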

for follower_worker_name, fut in worker_name_to_response_future_dict:
    try:
        fut.wait()
    except RuntimeError as ex:
Contributor

Could the future fail in other ways besides timeout while resolving? If so, are we interested in catching those?

Contributor Author

@xush6528 xush6528 Dec 16, 2019

@rohan-varma
In most cases, followers don't send a response because they call local ::shutdown() upon receiving the shutdown request from the leader.
Followers could also fail to respond to the shutdown for other reasons.

The best the leader can do is send the shutdown request again, but if a follower always fails to respond, there is nothing more the leader can do.
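A sketch of that best-effort retry on the leader's side; the retry bound and the logger are made up for illustration:

```
import logging

logger = logging.getLogger(__name__)

MAX_PROCEED_RETRIES = 3  # arbitrary bound for this sketch

for follower_worker_name, fut in worker_name_to_response_future_dict.items():
    for attempt in range(MAX_PROCEED_RETRIES):
        try:
            fut.wait()
            break
        except RuntimeError as ex:
            # Most likely the follower already entered its local shutdown and
            # will never answer; resend once, then eventually give up.
            logger.warning("%s did not ack (attempt %d): %s",
                           follower_worker_name, attempt + 1, ex)
            fut = rpc_async(follower_worker_name, _set_proceed_signal, args=())
```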

@kostmo
Member

kostmo commented Dec 12, 2019

💊 CircleCI build failures summary and remediations

As of commit 45aa57f:

None of the build failures appear to be your fault.

  • 1/1 recognized as flaky ❄️

Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

❄️ 1 failure recognized as flaky

The following build failures have been detected as flaky and may not be your fault:

See CircleCI build pytorch_linux_xenial_cuda9_cudnn7_py3_test (1/1)

Step: "Test" (full log | pattern match details) ❄️

Jan 05 00:35:27 AssertionError: 1024 not less than or equal to 1e-05 : __main__.TestAutogradDeviceTypeCUDA.test_logdet_1x1_cuda leaked 1024 bytes CUDA memory on device 0
Jan 05 00:35:27 ====================================================================== 
Jan 05 00:35:27 FAIL [0.124s]: test_logdet_1x1_cuda (__main__.TestAutogradDeviceTypeCUDA) 
Jan 05 00:35:27 ---------------------------------------------------------------------- 
Jan 05 00:35:27 Traceback (most recent call last): 
Jan 05 00:35:27   File "/var/lib/jenkins/workspace/test/common_utils.py", line 665, in wrapper 
Jan 05 00:35:27     method(*args, **kwargs) 
Jan 05 00:35:27   File "/var/lib/jenkins/workspace/test/common_utils.py", line 521, in __exit__ 
Jan 05 00:35:27     self.name, after - before, i)) 
Jan 05 00:35:27   File "/var/lib/jenkins/workspace/test/common_utils.py", line 877, in assertEqual 
Jan 05 00:35:27     super(TestCase, self).assertLessEqual(abs(x - y), prec, message) 
Jan 05 00:35:27 AssertionError: 1024 not less than or equal to 1e-05 : __main__.TestAutogradDeviceTypeCUDA.test_logdet_1x1_cuda leaked 1024 bytes CUDA memory on device 0 
Jan 05 00:35:27  
Jan 05 00:35:27 ---------------------------------------------------------------------- 
Jan 05 00:35:27 Ran 1884 tests in 885.054s 
Jan 05 00:35:27  
Jan 05 00:35:27 FAILED (failures=1, skipped=13, expected failures=1) 
Jan 05 00:35:27  
Jan 05 00:35:27 Generating XML reports... 
Jan 05 00:35:27 Traceback (most recent call last): 
Jan 05 00:35:27   File "test/run_test.py", line 456, in <module> 
Jan 05 00:35:27     main() 

This comment was automatically generated by Dr. CI.

Differential Revision: D18643137
Differential Version: 95738402
Differential Revision: D18643137
Differential Version: 95740293
Differential Revision: D18643137
Differential Version: 95751085
Contributor

@pritamdamania87 pritamdamania87 left a comment

Requesting changes since I'm concerned about the time.sleep(0.2). We shouldn't rely on a sleep() call like this.

Comment on lines 158 to 162
# This is a hack to make the follower linger for a while to finish
# sending out the last response message.
import time

time.sleep(0.2)
Contributor

We don't really need to do this since on the server side we catch an exception and the shutdown still proceeds without any issues? We really should not put any sort of sleeps like this in production code.

Contributor Author

True. Updating.


Differential Revision: D18643137
Differential Version: 95770571
Contributor

@pritamdamania87 pritamdamania87 left a comment

Code looks good to me, although there are a bunch of lint issues and CI failures. Could you fix those before landing?

facebook-github-bot pushed a commit that referenced this pull request Dec 19, 2019
Summary:
#30330 got rid of the need to send a `MessageType::SHUTDOWN` message, so we can now remove the logic/utils for this type of message.

I think we can also delete the enum entry in the `enum MessageType`, but we may want to keep it in case the logic in #30710 is ever moved to C++.
Pull Request resolved: #31270

Test Plan: All existing unit tests pass

Differential Revision: D19146983

Pulled By: rohan-varma

fbshipit-source-id: 35b185411f9446d7d4dfc37a6cb5477cf041e647
@jjlilley

So, an alternate approach would be a C++ barrier solution in Thrift, like the [newly] pending D19187645.

While not agent-independent, such an approach has the opportunity to be more robust. In particular, it can use onRequestSent() messages from Thrift as a signal that the bits have gone out, which we have no access to at the Python layer. Getting this sort of signal seemed to be one of the trickier areas of shutdown-barrier support.

@pritamdamania87
Contributor

@jjlilley I feel we should keep this logic agent-agnostic. We shouldn't force every backend to implement a barrier (e.g., we have TensorPipe coming soon).

Differential Revision: D18643137
Differential Version: 96298855
Differential Revision: D18643137
Differential Version: 96298893
@facebook-github-bot
Contributor

This pull request has been merged in 502533c.

@mrshenli
Contributor

mrshenli commented Jan 5, 2020

test_rref_leak failed on master. Sorry that I lost the context on this one. Do we know why calling _agent.join() in shutdown is not sufficient? cc @xush6528

https://app.circleci.com/jobs/github/pytorch/pytorch/4135881

Jan 05 04:08:40 ======================================================================
Jan 05 04:08:40 FAIL [1.653s]: test_rref_leak (__main__.RpcTestWithSpawn)
Jan 05 04:08:40 ----------------------------------------------------------------------
Jan 05 04:08:40 Traceback (most recent call last):
Jan 05 04:08:40   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 130, in wrapper
Jan 05 04:08:40     self._join_processes(fn)
Jan 05 04:08:40   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 211, in _join_processes
Jan 05 04:08:40     self._check_return_codes(elapsed_time)
Jan 05 04:08:40   File "/var/lib/jenkins/workspace/test/common_distributed.py", line 231, in _check_return_codes
Jan 05 04:08:40     self.assertEqual(p.exitcode, first_process.exitcode)
Jan 05 04:08:40   File "/var/lib/jenkins/workspace/test/common_utils.py", line 877, in assertEqual
Jan 05 04:08:40     super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
Jan 05 04:08:40 AssertionError: 11 not less than or equal to 1e-05 : 

@jjlilley

jjlilley commented Jan 6, 2020

@jjlilley I feel we should keep this logic agent-agnostic. We shouldn't force every backend to implement a barrier (e.g., we have TensorPipe coming soon).

Agreed, we shouldn't force every backend to implement a barrier; having a generic implementation has a benefit.

But for certain backends, I expect we can override it with a better (more reliable, etc.) solution that uses lower-level primitives.

@facebook-github-bot facebook-github-bot deleted the export-D18643137 branch January 8, 2020 15:17
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
Summary:
pytorch#30330 got rid of the need to send a `MessageType::SHUTDOWN` message, so we can now remove the logic/utils for this type of message.

I think we can also delete the enum entry in the `enum MessageType`, but we may want to keep it in case the logic in pytorch#30710 is ever moved to C++.
Pull Request resolved: pytorch#31270

Test Plan: All existing unit tests pass

Differential Revision: D19146983

Pulled By: rohan-varma

fbshipit-source-id: 35b185411f9446d7d4dfc37a6cb5477cf041e647
wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
…0710)

Summary:
Pull Request resolved: pytorch#30710

We need a backend-agnostic mechanism to perform a barrier-like operation before we locally destroy the RRef context and shut down the RPC agent.

- Sort worker names.
- Elect the first name in the sorted order as the leader.
- Followers report their intent to synchronize to the leader.
- The leader also reports to itself when `_wait_all_workers()` is called.
- Once all workers report their intent to proceed, the leader sends the command to everyone to proceed.

Test Plan:
# Unit tests

```
buck test mode/dev-nosan //caffe2/test:rpc_fork -- test_wait_all_workers

buck-out/gen/caffe2/test/rpc_fork\#binary.par -r test_wait_all_workers$
buck-out/gen/caffe2/test/rpc_fork\#binary.par -r test_rref_leak
buck-out/gen/caffe2/test/rpc_fork\#binary.par -r test_rref_forward_chain
```

```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift -- test_wait_all_workers

buck-out/gen/caffe2/test/rpc_fork_thrift\#binary.par -r test_wait_all_workers$
```

# Stress runs
```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift -- test_stress_light_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_spawn_thrift -- test_stress_light_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_fork_thrift -- test_stress_heavy_rpc --stress-runs 10
```

```
buck test mode/dev-nosan //caffe2/test:rpc_spawn_thrift -- test_stress_heavy_rpc --stress-runs 10
```

# Debug

```
buck test mode/dev-nosan caffe2/test:rpc_fork -- test_shutdown
```

```
buck test mode/dev-nosan //caffe2/test:dist_autograd_fork -- test_clean_context_during_backward

buck build mode/dev-nosan //caffe2/test:dist_autograd_fork

buck-out/gen/caffe2/test/dist_autograd_fork\#binary.par -r test_clean_context_during_backward
```

https://our.intern.facebook.com/intern/testinfra/diagnostics/281475127895800.844424945328750.1575664368/

```
I1206 12:27:47.491420 185619 process_group_agent.cpp:211] Shutting down ProcessGroupAgent.
I1206 12:27:47.493880 185630 process_group_agent.cpp:211] Shutting down ProcessGroupAgent.
I1206 12:27:47.494526 185625 process_group_agent.cpp:211] Shutting down ProcessGroupAgent.
I1206 12:27:47.495390 185636 process_group_agent.cpp:211] Shutting down ProcessGroupAgent.
E1206 12:27:47.544198 185627 pair.cc:642] 1 --->>> 0, read ERROR: AsyncSocketException: Network error, type = Network error, errno = 104 (Connection reset by peer)
E1206 12:27:47.544203 185633 pair.cc:642] 2 --->>> 0, read ERROR: AsyncSocketException: Network error, type = Network error, errno = 104 (Connection reset by peer)
E1206 12:27:47.544210 185639 pair.cc:642] 3 --->>> 0, read ERROR: AsyncSocketException: Network error, type = Network error, errno = 104 (Connection reset by peer)
```
This should mean the UDF in the request had been run, so Python proceeded and reached `_agent.shutdown()`.

Meanwhile, the RpcAgents on the followers wanted to send back their responses, but the leader had already closed RPC.

Need to re-trigger "pytorch_rpc-buck" to reproduce this rarely seen issue.

Differential Revision: D18643137

fbshipit-source-id: d669d4fc9ad65ed48bed1329a4eb1c32ba51323c