Description
🚀 The feature, motivation and pitch
I was trying to build two PipelineStage runtimes out of one copy of a stage module. The two runtimes differ in the inputs they expect (e.g., one for prefill, the other for decode).
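For concreteness, here is a minimal sketch of the setup. The module, shapes, and ranks are made-up stand-ins, and the exact PipelineStage constructor arguments may differ across torch versions; the key point is that the same module object is passed to two PipelineStage constructors, once with prefill-shaped example inputs and once with decode-shaped ones.

```python
import torch
import torch.nn as nn
from torch.distributed.pipelining import PipelineStage

# Hypothetical stand-in for the real stage module (assumes torch.distributed
# is already initialized and this runs inside a pipeline-parallel job).
stage_module = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
device = torch.device("cuda", 0)
pp_rank, pp_size = 0, 2

# First runtime: prefill-shaped example inputs (long sequence).
prefill_stage = PipelineStage(
    stage_module, pp_rank, pp_size, device,
    input_args=(torch.empty(1, 2048, 512, device=device),),
)

# Second runtime built from the *same* module object: decode-shaped inputs
# (sequence length 1). This second construction is what fails.
decode_stage = PipelineStage(
    stage_module, pp_rank, pp_size, device,
    input_args=(torch.empty(1, 1, 512, device=device),),
)
```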
Constructing the second PipelineStage hit the following error:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
[rank0]: torch.utils.swap_tensors(param, param_applied)
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/__init__.py", line 91, in swap_tensors
[rank0]: check_use_count(t1, "t1")
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/__init__.py", line 89, in check_use_count
[rank0]: raise RuntimeError(error_str)
[rank0]: RuntimeError: Expected use_count of t1 to be 1 or 2 with an AccumulateGrad node but got 4 make sure you are not holding references to the tensor in other places.
[rank0]: The above exception was the direct cause of the following exception:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/kw2501/local/torchchat/dist_run.py", line 531, in <module>
[rank0]: main(args)
[rank0]: File "/home/kw2501/local/torchchat/dist_run.py", line 444, in main
[rank0]: decode_stage = PipelineStage(
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 1160, in __init__
[rank0]: self.submod.to(self.device)
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
[rank0]: return self._apply(convert)
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]: module._apply(fn)
[rank0]: [Previous line repeated 1 more time]
[rank0]: File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
[rank0]: raise RuntimeError(
[rank0]: RuntimeError: _apply(): Couldn't swap Linear.weight
The related line of code in pipelining is pytorch/torch/distributed/pipelining/stage.py, line 1261 in 48d18fb:

self.submod.to(self.device)
Shall we not do this move? My understanding is that it is a convenience feature, and it does not seem like a big ask for the user to do the move themselves (i.e. call `module.to(device)` before constructing the stage).
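Continuing the hypothetical sketch above, the user-side flow without the internal `.to()` would simply be:

```python
# User moves the module to the target device once, up front...
stage_module = stage_module.to(device)

# ...and both runtimes are then built from the already-placed module,
# with no device move inside PipelineStage.__init__.
prefill_stage = PipelineStage(
    stage_module, pp_rank, pp_size, device,
    input_args=(torch.empty(1, 2048, 512, device=device),),
)
decode_stage = PipelineStage(
    stage_module, pp_rank, pp_size, device,
    input_args=(torch.empty(1, 1, 512, device=device),),
)
```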
Alternatives
An alternative would be for PipelineStage to support dynamic input sizes. In this case (prefill vs decode), it is the sequence length that varies.
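Purely as a strawman (this is not an existing PipelineStage API), such support might look like declaring which input dimension is allowed to vary when the stage is built:

```python
# Strawman only: `dynamic_dims` is NOT an existing PipelineStage argument.
# The idea is to declare that dim 1 (sequence length) of input 0 may vary
# between calls, so one stage runtime could serve both prefill and decode.
shared_stage = PipelineStage(
    stage_module, pp_rank, pp_size, device,
    input_args=(torch.empty(1, 2048, 512, device=device),),
    dynamic_dims={0: {1}},  # hypothetical: input index -> set of dynamic dims
)
```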
Additional context
cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o