[pipelining] create two stage instances out of same stage module #136225

@kwen2501

Description

🚀 The feature, motivation and pitch

I was trying to build two PipelineStage runtimes out of one copy of a stage module. The two runtimes differ in the inputs they expect (e.g., one for prefill, the other for decode).
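
For concreteness, here is a minimal sketch of the pattern (the module is a placeholder, and the `input_args` shapes and keyword are illustrative of the PipelineStage API at the time):

```python
import torch
import torch.distributed as dist
from torch.distributed.pipelining import PipelineStage

# Assumes torch.distributed is already initialized (e.g., via torchrun)
# and that this rank owns one stage of the pipeline.
rank, world = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank % torch.cuda.device_count())
stage_mod = torch.nn.Linear(512, 512)  # placeholder for the real stage submodule

# Prefill runtime: built for full-prompt-length inputs.
prefill_stage = PipelineStage(
    stage_mod, rank, world, device,
    input_args=(torch.empty(1, 128, 512, device=device),),
)

# Decode runtime: built for single-token inputs, over the same module.
# Constructing this second stage is what hits the error below, from
# self.submod.to(self.device) inside PipelineStage.__init__.
decode_stage = PipelineStage(
    stage_mod, rank, world, device,
    input_args=(torch.empty(1, 1, 512, device=device),),
)
```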

I hit the following error:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 945, in _apply
[rank0]:     torch.utils.swap_tensors(param, param_applied)
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/__init__.py", line 91, in swap_tensors
[rank0]:     check_use_count(t1, "t1")
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/utils/__init__.py", line 89, in check_use_count
[rank0]:     raise RuntimeError(error_str)
[rank0]: RuntimeError: Expected use_count of t1 to be 1 or 2 with an AccumulateGrad node but got 4 make sure you are not holding references to the tensor in other places.

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/kw2501/local/torchchat/dist_run.py", line 531, in <module>
[rank0]:     main(args)
[rank0]:   File "/home/kw2501/local/torchchat/dist_run.py", line 444, in main
[rank0]:     decode_stage = PipelineStage(
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 1160, in __init__
[rank0]:     self.submod.to(self.device)
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1340, in to
[rank0]:     return self._apply(convert)
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 900, in _apply
[rank0]:     module._apply(fn)
[rank0]:   [Previous line repeated 1 more time]
[rank0]:   File "/home/kw2501/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/nn/modules/module.py", line 949, in _apply
[rank0]:     raise RuntimeError(
[rank0]: RuntimeError: _apply(): Couldn't swap Linear.weight

The relevant line of code in pipelining is:

self.submod.to(self.device)

Shall we skip this move? My understanding is that it is a convenience feature, and it does not seem like a big ask for the user to do it themselves.
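
If the move were left to the caller, usage would look like the sketch below (proposed behavior, not the current API; the construction arguments are illustrative):

```python
# Caller moves the module to the target device once, up front.
stage_mod = stage_mod.to(device)

# PipelineStage would then skip self.submod.to(self.device) internally, so
# building several stage runtimes over the same module would not attempt to
# swap parameter tensors that already have extra references.
prefill_stage = PipelineStage(stage_mod, rank, world, device)
decode_stage = PipelineStage(stage_mod, rank, world, device)
```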

Alternatives

An alternative would be for PipelineStage to support dynamic input sizes. In this case (prefill vs. decode), it is the sequence length that varies.
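
Concretely, the per-microbatch input shapes would differ only in the sequence dimension (sizes hypothetical):

```python
import torch

batch, hidden = 1, 512
prefill_input = torch.empty(batch, 128, hidden)  # whole prompt in one pass
decode_input = torch.empty(batch, 1, hidden)     # one new token per step
```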

Additional context

cc: @H-Huang @wconstab

