
Conversation

@kwen2501
Collaborator

@kwen2501 kwen2501 commented Aug 22, 2024

Stack from ghstack (oldest at bottom):

This PR supports calling parallelize_module from within a model definition, making the model a parallel one.

Calling parallelize_module is an alternative to maintaining a set of ColumnWiseLinear, RowWiseLinear, etc., while still being able to directly author a parallel model.

(The motivation for authoring a parallel model is that there may be other distributed operations that are not easily captured by any module; see the forward function below. Put differently, the purpose is to exploit the expressiveness of DTensor -- we need to first create DTensors before calling ops on them, and having parallelized modules in the model is one way of creating DTensors.)

For example:

class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, Colwise)
        self.w2 = parallelize_module(w2, Rowwise)
        self.w3 = parallelize_module(w3, Colwise)

    def forward(self, x: Tensor) -> Tensor:
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        # y is a DTensor with Partial placement; we can return it as is:
        return y
        # Or we could convert it to Replicate -- there is modeling flexibility here:
        # return y.redistribute(Replicate())


with device_mesh:
    model = FeedForward(config)
    # Now model is a model parallelized onto device_mesh

y = model(x)

The device_mesh actually used for parallelize_module would be retrieved from the ambient context.

Calling parallelize_module from within the model hierarchy also avoids the use of FQNs that the out-of-model annotation case requires.
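
(For contrast, a sketch of the out-of-model annotation referred to above, applied to a plain, non-parallelized version of FeedForward. ColwiseParallel and RowwiseParallel are the full class names of the existing parallel styles; the exact plan below is illustrative, not something prescribed by this PR.)

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

model = FeedForward(config)  # a plain version, without the in-model parallelize_module calls
model = parallelize_module(
    model,
    device_mesh,
    # The plan addresses submodules by their FQNs.
    {"w1": ColwiseParallel(), "w2": RowwiseParallel(), "w3": ColwiseParallel()},
)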

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Aug 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134247

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 47952c0 with merge base 1aac1ff:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed label Aug 22, 2024
kwen2501 added a commit that referenced this pull request Aug 22, 2024
-    device_mesh: DeviceMesh,
-    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
+    device_mesh: Optional[DeviceMesh] = None,
+    parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None,
Collaborator

maybe explain why you want this?
why not just skip calling parallelize_module at all?

Collaborator Author

You are so fast in reviewing it -- even before I finished writing the PR description.
The motivation is now added :)

@kwen2501 kwen2501 added the suppress-bc-linter label Aug 22, 2024
@kwen2501 kwen2501 requested review from awgu, fduwjj and wz337 August 22, 2024 18:39
@fduwjj
Contributor

fduwjj commented Aug 22, 2024

Per your example, how do you annotate the input for your nn.Module? Would you call something like parallelize_module(FeedForward, input_plan)? Aside from avoiding a long FQN list, what other benefits are there?

@kwen2501
Collaborator Author

@fduwjj The manipulation of inputs can be written in the forward function (just like how we manipulate the output above). In fact, we can manipulate activations too:

    def forward(self, x: Tensor) -> Tensor:
        # Manipulate inputs
        x = x.redistribute(...)
        y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x))
        ....

@kwen2501 kwen2501 changed the title Make device_mesh an optional argument of parallelize_module Allow parallelize_module to get device_mesh from ambient context Aug 22, 2024
@fduwjj
Contributor

fduwjj commented Aug 22, 2024

@kwen2501 Another question I have is: why can't users directly write their module using DTensor? Is it because we don't have an nn.Linear that is directly built on DTensor, so we need this swap?

Since you already do lots of DTensor operations, for the purpose of flexibility, why can't users do:

self.w1 = nn.Linear(...)
self.w1.weight = nn.Parameter(DTensor_shard(self.w1.weight))

I don't have an answer here; I am just thinking about what the most flexible way for users would be. This way, users just swap whatever parts they want to DTensor, but they need to take care not to mix Tensor and DTensor together.

@kwen2501
Collaborator Author

kwen2501 commented Aug 22, 2024

self.w1 = nn.Linear(...)
self.w1.weight = nn.Parameter(DTensor_shard(self.w1.weight))

They can do this; it is just not as modular as calling parallelize_module or using a ColumnWiseLinear class.
The assumption is that authors still want some abstraction rather than writing programs from the ground up using DTensors.
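
(For reference, a minimal sketch of the manual swap discussed above, using distribute_tensor in place of the hypothetical DTensor_shard. The mesh construction, layer sizes, and Shard(0) placement are illustrative assumptions, and import paths may differ across PyTorch versions.)

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Assumes the default process group is already initialized.
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

linear = nn.Linear(16, 32, bias=False)
# Swap the plain Tensor parameter for a DTensor sharded on dim 0 (column-wise).
linear.weight = nn.Parameter(distribute_tensor(linear.weight, mesh, [Shard(0)]))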

@wanchaol wanchaol requested review from tianyu-l and yifuwang August 22, 2024 20:54
Collaborator

@wanchaol wanchaol left a comment

I think the change looks reasonable to me. Please update the documentation to reflect that these fields can be optional (and the behavior when they are), and add tests.


    if parallelize_plan is None:
        warnings.warn(
            "No parallelize_plan is provided, so parallelize_module does nothing."
Contributor

minor: maybe add the module name to the "does nothing" warning, since you are likely to have multiple modules.

Contributor

@tianyu-l tianyu-l left a comment

I didn't fully appreciate the change, in that:
In your example, since we know class FeedForward(nn.Module) will be initialized in a distributed way and needs to be constructed under some device_mesh context, why can't we put a DeviceMesh into the init args? I.e.:

class FeedForward(nn.Module):
    def __init__(self, config: TransformerArgs, mesh: DeviceMesh) -> None:
        super().__init__()
        w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
        w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
        self.w1 = parallelize_module(w1, mesh, Colwise)
        self.w2 = parallelize_module(w2, mesh, Rowwise)
        self.w3 = parallelize_module(w3, mesh, Colwise)

@kwen2501
Collaborator Author

@tianyu-l Good question. I guess there are a couple of reasons:
(1) to keep the model's creation signature the same as in the non-parallel case;
(2) to make the model's code look less heavy;
(3) to avoid having to pass the mesh from parent modules to child modules.

Contributor

@tianyu-l tianyu-l left a comment

I guess there are a couple of reasons:
(1) to keep the model's creation signature the same as in the non-parallel case;
(2) to make the model's code look less heavy;
(3) to avoid having to pass the mesh from parent modules to child modules.

Re (1): I don't know if this is a pro or a con -- having the mesh in the signature could help the user be aware they are using a parallelized module.
Re (2) and (3): I agree.

Let's also consider the potential downsides:
(a) parallelize_plan should be required, but now we are sacrificing this to allow for a more flexible device_mesh. To me this is a trade-off, rather than a pure win.
(b) inline commented: do we always get the right device_mesh from context when there is an n-dimensional root DeviceMesh?
(c) this is a one-way move: after adopting this change, going back (for other reasons we might not be able to foresee) would be BC-breaking.

I'm a bit concerned about (a) and (b). I wonder what you and other people think.

"""
torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module")

device_mesh = device_mesh or _mesh_resources.get_current_mesh()
Contributor

If we have an nD DeviceMesh, what will this line of code retrieve? Would it intelligently fetch the TP mesh?

Collaborator Author

Whether the code is:

parallelize_module(mod, mesh, plan)

or

with mesh:
    parallelize_module(mod, plan)

the user has the responsibility to make sure mesh is the correct, intended one.
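
(For concreteness, a sketch of what picking the intended mesh could look like with a 2D mesh. The (2, 4) shape and the "dp"/"tp" dim names are illustrative assumptions, not part of this PR.)

from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # slice out the 1D submesh intended for tensor parallelism

with tp_mesh:
    model = FeedForward(config)  # parallelize_module calls inside pick up tp_mesh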

Collaborator Author

@kwen2501 kwen2501 Aug 23, 2024

If mesh is wrong, the above two pieces of code would fail in the same way.

Contributor

This makes sense to me.

But would this code run? I thought it would at least have to be parallelize_module(mod, parallelize_plan=plan), which doesn't make it that much simpler compared with parallelize_module(mod, mesh, plan).

@kwen2501
Collaborator Author

Re (a):
mesh is still in fact required. If the result of
device_mesh or _mesh_resources.get_current_mesh() is None, I think an error will be returned: "No active device mesh is found."

Also want to note that APIs like distribute_tensor or distribute_module also support such usage (which to me seems like a helpful feature).
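
(A sketch of the ambient-mesh usage of distribute_tensor mentioned above. The mesh construction, tensor shape, and Shard(0) placement are illustrative assumptions, and import paths may differ across PyTorch versions.)

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

device_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

with device_mesh:
    # device_mesh argument omitted; it is picked up from the context above.
    dt = distribute_tensor(torch.randn(8, 8), placements=[Shard(0)])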

@tianyu-l
Contributor

Re (a): mesh is still in fact required. If the result of device_mesh or _mesh_resources.get_current_mesh() is None, I think an error will be returned: "No active device mesh is found."

When I say required, I mean required by the method signature.

Also want to note that APIs like distribute_tensor or distribute_module also support such usage (which to me seems like a helpful feature).

They look different to me. For distribute_tensor and distribute_module, each optional arg has a meaningful default behavior. But parallelize_plan is by nature not optional. Here we are only marking it optional because we want to mark another, earlier arg (device_mesh) optional, which is taking shortcuts to bypass Python constraints.

I'd be less concerned if the order of device_mesh and parallelize_plan were switched and only device_mesh were marked optional.

@kwen2501
Collaborator Author

Yeah, I agree that it would have been nicer if the order of device_mesh and parallelize_plan were switched.
In the PR, I supported parallelize_plan=None by making it a no-op. Would that lessen your concern?

@kwen2501
Collaborator Author

kwen2501 commented Aug 26, 2024

@tianyu-l Regardless of the example above, this may be the biggest motivation behind this change:
a separation between the distributed runtime and the model layout description.

device_mesh represents a runtime, in that it is bound to concrete devices and a communicator among them.

parallelize_plan represents how the data of a model will be sharded -- column-wise, row-wise, etc. It is a style of data layout that does not necessarily have to be bound to a concrete device mesh at the time of description.

@kwen2501 kwen2501 requested a review from tianyu-l August 26, 2024 18:15
@tianyu-l
Contributor

In the PR, I supported parallelize_plan=None by making it a no-op.

parallelize_plan is the core part of this method, and cannot be None -- calling the function should fail without a parallelize_plan.

I'm OK with setting device_mesh = None and getting it from the ambient context -- I think that's a good change. Do you think we can mark device_mesh: DeviceMesh | None for this purpose? (Yeah, I hope we could switch the order of these two.)

@kwen2501
Collaborator Author

@wanchaol Sorry to bother you.
@tianyu-l would like a sign-off on the API signature change.
Could you please sign it off (or reject it)? Thanks!

@wanchaol
Collaborator

I still think the API signature change sounds good. Please add tests and update the documentation to reflect the signature change.

parallelize_plan is the core part of this method, and cannot be None -- calling the function should fail without a parallelize_plan.

@tianyu-l I think it's OK for this to default to None and warn the user when that is the case. The API's intention is to parallelize a module with a plan/style; if the user intentionally doesn't pass a plan, then we could return the original (untouched) module, which looks like reasonable behavior to me.

I hope we could switch the order of these two

Unfortunately we can't switch the argument order in Python (that is considered a BC-breaking change). I feel that device_mesh being the second argument also matches the distribute_tensor/distribute_module signatures. The downside is mainly that it forces parallelize_plan to be optional, for which I feel we have a reasonable default, as mentioned above.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@malfet
Contributor

malfet commented Sep 30, 2024

@pytorchbot revert -m "Broke lint, which one can clearly see in PR CI https://github.com/pytorch/pytorch/actions/runs/11112138513/job/30873604386 " -c ignoredsignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@kwen2501 your PR has been successfully reverted.

AnantGulati pushed a commit to AnantGulati/pytorch that referenced this pull request Oct 2, 2024
@kwen2501
Collaborator Author

kwen2501 commented Oct 8, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

kwen2501 added a commit that referenced this pull request Oct 8, 2024
@kwen2501
Collaborator Author

kwen2501 commented Oct 9, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@github-actions github-actions bot deleted the gh/kwen2501/47/head branch November 8, 2024 02:07

Labels

ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor), Reverted, suppress-bc-linter
