🚀 The feature, motivation and pitch
Today, the `PipelineStage` init method dry-runs the module with the example input (pytorch/torch/distributed/pipelining/stage.py, line 1270 in 48d18fb):

`self.outputs = self.submod(*self.inputs)`
This dry run demands extra memory and can OOM for large models, which additionally require TP/FSDP or activation checkpointing to keep the memory envelope low. Those techniques may not have been applied yet at the point where the pipeline stage is created.
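One way to obtain output shapes without allocating activations is to run the forward pass on the `meta` device, where tensors carry shape and dtype but no storage. The sketch assumes every op in the submodule has a meta implementation; `infer_outputs_on_meta` is a hypothetical helper, not a `PipelineStage` API and not what stage.py currently does. It uses `torch.func.functional_call` so the real weights are neither copied nor moved.

```python
import torch
import torch.nn as nn
from torch.func import functional_call


def infer_outputs_on_meta(submod: nn.Module, *example_inputs: torch.Tensor):
    """Hypothetical helper: run submod's forward on meta tensors.

    Returns outputs that have the real shapes and dtypes but no data, so no
    activation memory (and no second copy of the weights) is allocated.
    Assumes every op used by submod supports the meta device.
    """
    # Meta stand-ins for parameters and buffers; the real ones stay untouched.
    named = list(submod.named_parameters()) + list(submod.named_buffers())
    meta_state = {name: torch.empty_like(t, device="meta") for name, t in named}
    # Meta stand-ins for the example inputs.
    meta_inputs = tuple(torch.empty_like(x, device="meta") for x in example_inputs)
    with torch.no_grad():
        return functional_call(submod, meta_state, meta_inputs)


# Usage: a toy stage submodule and a small example input.
submod = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
out = infer_outputs_on_meta(submod, torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 8])
```

The trade-off is coverage: custom ops without a meta kernel would fail here, whereas the current real dry run works for any module that fits in memory.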
Alternatives
The dry run exists to generate `output_args`, whose shapes we rely on to create the gradient receive buffers during the backward pass.
A workaround would be for the user to pass `output_args` to the `PipelineStage` init method, but that is not ergonomic.
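Sizing the gradient receive buffers only needs each output's shape and dtype, so `output_args` could be meta tensors rather than real activations. A minimal sketch, where `make_grad_recv_buffers` is a hypothetical helper and not the actual stage.py code:

```python
import torch


def make_grad_recv_buffers(output_args, device):
    """Hypothetical helper: one gradient receive buffer per stage output.

    Only shape and dtype are read from output_args, so meta tensors work
    just as well as real tensors here.
    """
    return [torch.empty(t.shape, dtype=t.dtype, device=device) for t in output_args]


# Usage: output metadata described with meta tensors (no storage allocated).
output_args = (
    torch.empty(4, 8, device="meta"),
    torch.empty(2, dtype=torch.bfloat16, device="meta"),
)
grad_bufs = make_grad_recv_buffers(output_args, device="cpu")
```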
Also, inference runs have no backward pass to worry about.
Additional context
cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o