Conversation

@lukeyeager
Contributor

Allows you to use a bash script wrapper in-between launch and your
training script. e.g.
```
python -m torch.distributed.launch --nproc_per_node=8 --no_python --use_env \
    bash -c 'exec numactl --cpunodebind=$(( LOCAL_RANK / 4 )) "$@"' -- \
    python train.py ...
```
@pietern
Contributor

pietern commented Nov 6, 2019

Thanks, @lukeyeager. Looks good to me.

It's unfortunate that the tool is becoming a pile of backward compat hacks, though. I think a next version of this thing would 1) never prepend Python, and 2) always use the environment.

Contributor

@facebook-github-bot left a comment

@pietern is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@pietern added the oncall: distributed label (Add this issue/PR to distributed oncall triage queue), Nov 6, 2019
@lukeyeager requested a review from pietern, November 11, 2019 19:03
@lukeyeager
Contributor Author

lukeyeager commented Nov 15, 2019

Temporary workaround 😁

```
python -m torch.distributed.launch --nproc_per_node=8 --use_env -- -c \
    "import os, sys, subprocess; \
    ret = subprocess.run(['numactl', '--cpunodebind={}'.format(int(int(os.environ['LOCAL_RANK'])/4)), *sys.argv[1:]]); \
    sys.exit(ret.returncode)" python train.py ...
```
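
For readability, the inline `-c` one-liner above is roughly equivalent to the following standalone script (same logic, just unrolled; the `/ 4` assumes 4 local ranks per NUMA node):

```python
# Unrolled version of the inline one-liner above: run the wrapped command
# under numactl, binding to a NUMA node derived from LOCAL_RANK, and
# propagate the child's exit code.
import os
import subprocess
import sys

local_rank = int(os.environ["LOCAL_RANK"])
cpu_node = local_rank // 4  # assumes 4 local ranks (GPUs) per NUMA node
ret = subprocess.run(["numactl", "--cpunodebind={}".format(cpu_node), *sys.argv[1:]])
sys.exit(ret.returncode)
```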

@pietern Thoughts on the more serious proposals above?

@pietern
Contributor

pietern commented Nov 20, 2019

Thanks for investigating, @lukeyeager. Unfortunate that we can't have a bool option in argparse...

I like the second option best: a --no_python option and negation in code for readability.
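
(For illustration only, a minimal sketch of how such a flag could be wired up with argparse; the exact code in the PR may differ:)

```python
# Hypothetical sketch: add --no_python as a store_true flag and negate it
# where the launcher builds the per-worker command line.
import sys
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--no_python", default=False, action="store_true",
                    help="Skip prepending the training script with the "
                         "Python interpreter; execute it directly.")
args, script_args = parser.parse_known_args()

cmd = []
if not args.no_python:  # negation in code, as suggested above
    cmd.append(sys.executable)
cmd.extend(script_args)
# cmd is then launched once per local rank (e.g. via subprocess.Popen).
```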

Regarding a v2, we could create a parallel tool called torch.distributed.run and start from scratch. Then there won't be any BC issues and we can make it do more stuff as needed. Thoughts?

@lukeyeager
Contributor Author

> I like the second option best

Great! Done.

> Regarding a v2, we could create a parallel tool called torch.distributed.run and start from scratch. Then there won't be any BC issues and we can make it do more stuff as needed. Thoughts?

I don't really have any feelings about that. I haven't been using pytorch much so I don't have much context. I do know there's already torch.nn.parallel.DistributedDataParallel which is (was?) "better" (why?) in some circumstances than torch.distributed.launch. Adding a third thing might be even more confusing? I expect others at NVIDIA might have more well-formed thoughts. I know several people want something which plays nicely with MPI. I would hate for that to become a fourth launch option.

@pietern
Contributor

pietern commented Nov 22, 2019

> I do know there's already torch.nn.parallel.DistributedDataParallel which is (was?) "better" (why?) in some circumstances than torch.distributed.launch.

They are complementary. You might be thinking of nn.DataParallel (single process, multi GPU, not distributed) and nn.DistributedDataParallel (single process, single or multi GPU, distributed). To launch jobs that use the latter, you need either mpirun, srun, or if you DIY something, you can use torch.distributed.launch. As you see in the code, its job is really simple, and it just launches N processes with a local rank argument / environment variable.
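
(As a rough illustration of that job, and not the actual launcher code, a much-simplified single-node version might look like this:)

```python
# Simplified sketch: spawn nproc_per_node workers, each with its rank in the
# environment, then wait for them. Rendezvous variables are assumed defaults;
# train.py and the worker count are illustrative.
import os
import subprocess
import sys

nproc_per_node = 8  # e.g. one worker per GPU (illustrative assumption)
base_env = dict(os.environ,
                MASTER_ADDR="127.0.0.1", MASTER_PORT="29500",
                WORLD_SIZE=str(nproc_per_node))

workers = []
for local_rank in range(nproc_per_node):
    env = dict(base_env, RANK=str(local_rank), LOCAL_RANK=str(local_rank))
    workers.append(subprocess.Popen([sys.executable, "train.py"], env=env))

for w in workers:
    w.wait()
```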

> Adding a third thing might be even more confusing? I expect others at NVIDIA might have more well-formed thoughts. I know several people want something which plays nicely with MPI. I would hate for that to become a fourth launch option.

We could have some kind of ptrun frontend for MPI / Slurm / DIY. I know Horovod has horovodrun and does something similar with environment detection and launching a job. This is definitely an area we can improve on.

Thanks for updating the PR.

Contributor

@facebook-github-bot left a comment

@pietern is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
