Allow parallelize_module to get device_mesh from ambient context #134247
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134247
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 47952c0 with merge base 1aac1ff.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
-    device_mesh: DeviceMesh,
-    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
+    device_mesh: Optional[DeviceMesh] = None,
+    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
```
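Pieced together from the hunks quoted in this PR, the new argument handling can be sketched roughly as below. This is a simplified illustration only, not the actual function body; the real plan-application logic is elided and the import locations are assumed.

```python
# Simplified sketch of the new optional-argument handling, based on the hunks shown
# in this PR; the actual plan-application logic is omitted.
import warnings
from typing import Dict, Optional, Union

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh, _mesh_resources
from torch.distributed.tensor.parallel import ParallelStyle


def parallelize_module(  # illustrative re-statement of the signature in the diff
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
) -> nn.Module:
    # When no mesh is passed, fall back to the one set by an ambient `with mesh:` block.
    device_mesh = device_mesh or _mesh_resources.get_current_mesh()

    if parallelize_plan is None:
        # Nothing to apply; warn and hand the module back untouched.
        warnings.warn(
            "No parallelize_plan is provided, so parallelize_module does nothing."
        )
        return module

    # ... existing logic (unchanged by this PR) applies the style(s) to the module ...
    return module
```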
maybe explain why you want this?
why not just skip calling parallelize_module at all?
You are so fast in reviewing it, even before I finish writing the PR desc.
Now the motivation is added :)
Per your example, how do you annotate the input for your nn.Module? So you will call ...
@fduwjj The manipulation of inputs can be written in the forward function (just like how we manipulate the output above). In fact, we can manipulate activations too.
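To make that concrete, here is a hedged sketch of what handling inputs (and activations) directly in `forward` could look like. The module, the `tp_mesh` attribute, and the `DTensor.from_local` / `redistribute` calls are illustrative choices, not something this PR prescribes.

```python
# Hedged sketch: annotate/convert inputs and activations inside forward() by hand,
# instead of describing them in a parallelize_plan. Names here are illustrative.
import torch.nn.functional as F
from torch import Tensor, nn
from torch.distributed.tensor import DTensor, Replicate


class ParallelFFN(nn.Module):
    def __init__(self, w1: nn.Module, w2: nn.Module, w3: nn.Module, tp_mesh) -> None:
        super().__init__()
        # w1/w2/w3 are assumed to already be parallelized (their params are DTensors).
        self.w1, self.w2, self.w3 = w1, w2, w3
        self.tp_mesh = tp_mesh

    def forward(self, x: Tensor) -> Tensor:
        # Input manipulation: wrap a plain local tensor as a replicated DTensor.
        if not isinstance(x, DTensor):
            x = DTensor.from_local(x, self.tp_mesh, [Replicate()])
        y = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # Activation/output manipulation: e.g. force a Replicate placement.
        return y.redistribute(placements=[Replicate()])
```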
@kwen2501 Another question I have is: why can't users directly write their module using DTensor? Is it because we don't have an nn.Linear that is directly built on DTensor, so we need this swap? Since you already do lots of DTensor operations anyway, and for the purpose of flexibility, why can't users do that directly? I mean, I don't have an answer here, but I am just thinking about what the most flexible way for users is. And this way, users just swap whatever part to DTensor, but they need to take care not to mix Tensor and DTensor together.
They can do this, it's just not that modular compared to calling `parallelize_module`.
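For reference, a rough sketch of the "write the module directly with DTensor" alternative discussed above, assuming `distribute_tensor` plus plain `nn.Parameter` registration. It is illustrative only and not what this PR implements.

```python
# Rough sketch of a column-wise-sharded linear layer built directly on DTensor,
# without parallelize_module. Illustrative only; bias and error handling omitted.
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributed.tensor import Shard, distribute_tensor


class ColwiseLinearByHand(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, mesh) -> None:
        super().__init__()
        weight = torch.empty(out_dim, in_dim)
        nn.init.kaiming_uniform_(weight)
        # Shard the (out_dim, in_dim) weight along dim 0 across the mesh -> column-wise.
        self.weight = nn.Parameter(distribute_tensor(weight, mesh, [Shard(0)]))

    def forward(self, x):
        # x must be a DTensor on the same mesh; mixing plain Tensors and DTensors
        # in one op is exactly the pitfall mentioned above.
        return F.linear(x, self.weight)
```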
wanchaol left a comment:
I think the change looks reasonable to me. Please update the documentation to reflect that these fields can be optional (and the behavior when they are optional), and add tests.
```python
    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided, so parallelize_module does nothing."
        )
```
minor: maybe include the module name in the "does nothing" message, since you are likely to have multiple modules.
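For illustration, the suggested message could look something like the snippet below, where `module` stands in for the argument of `parallelize_module`; the exact wording that landed may differ.

```python
import warnings

from torch import nn

module = nn.Linear(4, 4)  # stand-in for the `module` argument of parallelize_module

# Illustrative variant of the warning that names the module, as suggested above;
# not necessarily the wording that landed in the PR.
warnings.warn(
    "No parallelize_plan is provided, so parallelize_module does nothing "
    f"on module of type {type(module).__name__}."
)
```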
tianyu-l left a comment:
I didn't fully appreciate the change, in that:
In your example, since we know `class FeedForward(nn.Module)` will be initialized in a distributed way and needs to be called under some `device_mesh` context, why can't we put a DeviceMesh into the init args? i.e.
```python
class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs, mesh: DeviceMesh) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, mesh, Colwise)
        self.w2 = parallelize_module(w2, mesh, Rowwise)
        self.w3 = parallelize_module(w3, mesh, Colwise)
```
@tianyu-l Good question. I guess there are a couple of reasons:
(1) for the model's creation signature to remain the same as non-parallel case.
(2) for the model's code to look less heavy.
(3) avoid having to pass mesh from parent module to child modules.
tianyu-l left a comment:
I guess there are a couple reasons:
(1) for the model's creation signature to remain the same as non-parallel case.
(2) for the model's code to look less heavy.
(3) avoid having to pass mesh from parent module to child modules.
Re (1): I don't know if this is a pro or a con -- having the mesh in the signature could help the user be aware they are using a parallel module.
Re (2) and (3): I agree
Let's also consider the potential downsides:
(a) parallelize_plan should be required, but now we are sacrificing this to allow for a more flexible device_mesh. To me this is a trade-off, rather than a pure win.
(b) as commented inline: do we always get the right device_mesh from context when there is an n-dimensional root DeviceMesh?
(c) this is a one-way move: after adopting this change, going back (for other reasons we might not be able to foresee) would be BC-breaking.
I'm a bit concerned about (a) and (b). I wonder what you and other people think.
| """ | ||
| torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") | ||
|
|
||
| device_mesh = device_mesh or _mesh_resources.get_current_mesh() |
If we have an nD DeviceMesh, what will this line of code retrieve? Would it intelligently fetch the TP mesh?
Whether the code is `parallelize_module(mod, mesh, plan)` or

```python
with mesh:
    parallelize_module(mod, plan)
```

the user has the responsibility to make sure the mesh is a correct, intended one.
If mesh is wrong, the above two pieces of code would fail in the same way.
This makes sense to me.
But would this code run? I thought it at least should be `parallelize_module(mod, parallelize_plan=plan)`, which doesn't make it that much simpler compared with `parallelize_module(mod, mesh, plan)`.
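To make the n-D mesh concern concrete, here is a hedged sketch of the explicit pattern a user could follow: slice out the TP sub-mesh and use that as the ambient context. The mesh dimension names, the toy `model`, and the plan are illustrative, not part of this PR.

```python
# Illustrative: with an n-D mesh, the caller slices out the TP sub-mesh explicitly and
# uses *that* as the ambient context, so the fallback cannot pick up the wrong mesh.
# Assumes the process group has already been initialized across 8 ranks.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

# Toy model with a submodule named "w1" so the plan below has something to match.
model = nn.ModuleDict({"w1": nn.Linear(16, 16, bias=False)})

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # 1-D sub-mesh intended for tensor parallelism

with tp_mesh:
    # device_mesh is omitted; it is resolved from the ambient context, i.e. tp_mesh.
    parallelize_module(model, parallelize_plan={"w1": ColwiseParallel()})
```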
Re (a): Also want to note that APIs like ...
When I say required, I mean required by the method signature.
They look different to me. I'd be less concerned if the order of `device_mesh` and `parallelize_plan` could be swapped.
Yeah, I agree that it would have been nicer if the order of `device_mesh` and `parallelize_plan` were the other way around.
@tianyu-l Regardless of the example above, this may be the biggest motivation behind this change:
I'm OK with setting ...
I still think the API signature change sounds good to me. Please add tests and update the documentation to reflect the signature change.

@tianyu-l I think it's OK for this to default to None and throw a warning to the user when that's the case. The API's intention is to parallelize a module with a plan/style; if the user intentionally doesn't pass a plan, then we could return the original (untouched) module, which looks like reasonable behavior to me.

Unfortunately we can't switch the argument order in Python (that would be considered a BC-breaking change). I feel the ...
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

@pytorchbot revert -m "Broke lint, which one can clearly see in PR CI https://github.com/pytorch/pytorch/actions/runs/11112138513/job/30873604386 " -c ignoredsignal

@pytorchbot successfully started a revert job. Check the current status here.

@kwen2501 your PR has been successfully reverted.
Revert "Allow parallelize_module to get device_mesh from ambient context (#134247)"
This reverts commit 80e7478. Reverted #134247 on behalf of https://github.com/malfet due to "Broke lint, which one can clearly see in PR CI https://github.com/pytorch/pytorch/actions/runs/11112138513/job/30873604386" ([comment](#134247 (comment)))
Allow parallelize_module to get device_mesh from ambient context (pytorch#134247)
Pull Request resolved: pytorch#134247
Approved by: https://github.com/tianyu-l
Update on "Allow parallelize_module to get device_mesh from ambient context"
This PR is for supporting calling `parallelize_module` from within a model definition, making the model a parallel one.
Calling `parallelize_module` is an alternative to maintaining a set of `ColumnWiseLinear`, `RowWiseLinear`, etc, while still being able to directly author a parallel model.
(The motivation for authoring a parallel model is that there may be other distributed operations, which may not be easily captured by any module, see the forward function below. Alternatively speaking, the purpose is to exploit the expressiveness of DTensor -- we need to first create DTensors before calling ops on them. Having parallelized modules in model is one way of creating DTensors.)
For example:
```python
class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, Colwise)
        self.w2 = parallelize_module(w2, Rowwise)
        self.w3 = parallelize_module(w3, Colwise)

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is.
        return y
        # Or we can convert it to Replicate -- there is modeling flexibility here.
        return y.redistribute(Replicate())


with device_mesh:
    model = FeedForward(config)

# Now model is a model parallelized onto device_mesh
y = model(x)
```
The `device_mesh` actually used for `parallelize_module` would be retrieved from the ambient context.
Calling `parallelize_module` from within model hierarchy also saves the use of *FQNs* as in the out-of-model annotation case.
cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
[ghstack-poisoned]
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):

This PR is for supporting calling `parallelize_module` from within a model definition, making the model a parallel one.

Calling `parallelize_module` is an alternative to maintaining a set of `ColumnWiseLinear`, `RowWiseLinear`, etc, while still being able to directly author a parallel model.

(The motivation for authoring a parallel model is that there may be other distributed operations, which may not be easily captured by any module; see the forward function in the example above. Alternatively speaking, the purpose is to exploit the expressiveness of DTensor -- we need to first create DTensors before calling ops on them. Having parallelized modules in the model is one way of creating DTensors.)

The `device_mesh` actually used for `parallelize_module` would be retrieved from the ambient context.

Calling `parallelize_module` from within the model hierarchy also saves the use of *FQNs* as in the out-of-model annotation case.

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o