Conversation

@StrongerXi StrongerXi commented Nov 8, 2024

Stack from ghstack (oldest at bottom):

In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects:

  1. For cells captured and created by the root frame, represent them as
    their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and
    `STORE_DEREF` update directly, without going through `SideEffects`.
  2. `ClosureVariable`: this is created when cells from (1) are captured
    by a newly created function Dynamo is about to inline. It's a named
    handle that redirects `LOAD_DEREF` and `STORE_DEREF` back to (1),
    keeping `root_tx.symbolic_locals` up to date.
  3. For cells that are captured both by the root frame and by some
    pre-existing function Dynamo is about to inline, represent those
    cells as their contents, and do not allow writes to them.
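The three situations above can be sketched in plain Python (a hypothetical illustration, not Dynamo code; `root_frame`, `new_fn`, and `pre_existing` are made-up names):

```python
def root_frame():
    # (1) `x` is created by the root frame and captured below, so CPython
    #     stores it in a cell rather than a plain local.
    x = 1

    def new_fn():
        # (2) a function created inside the traced frame that captures the
        #     root frame's cell for `x`.
        return x + 1

    return new_fn

# (3) a pre-existing function whose closure cell was created in another
#     frame; Dynamo may later inline calls to it.
pre_existing = root_frame()
```

Here `root_frame.__code__.co_cellvars` is `('x',)`, confirming CPython allocated a cell for `x`.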

Note that (2) and (3) are mainly to conform with (1) -- to make sure
Dynamo has a consistent modeling of cells for the same cell objects.

In this patch, we represent all of these cells as `NewCellVariable`. The
main new code paths introduced are:

  • using `NewCellVariable` to model cell objects created by the root
    frame (the cells are passed in as input to `InstructionTranslator`);
    this is what allows us to get rid of all 3 legacy paths above.
  • adding a new `AutoDerefLocalSource` to deal with the Python-code
    level (guards) and bytecode-level (codegen) auto-dereferencing
    behavior when accessing pre-existing Python cells. This also
    involves a tiny update to guard manager generation.
  • plumbing some extra info into `LocalSource` and `CellVariable` so that
    we can still emit `LOAD_DEREF`, `STORE_DEREF`, and `LOAD_CLOSURE`
    (instead of `make_cell`, `cell_contents` attribute access, and
    `LOAD_FAST`), which is important for readability, performance, and
    some assumptions `bytecode_transformation.py` makes.

As a result, this patch removes a lot of now-dead code paths and
TODOs. Notably, it significantly simplifies the `prune_dead_locals`
function, which was duplicating a lot of the logic from
`prune_dead_object_new`; this conveniently closes #137123.
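For reference, the cell operations named in the last bullet have plain-CPython analogues (a sketch; `make_cell` here is a hypothetical stand-in helper, not Dynamo's actual one):

```python
from types import CellType


def make_cell(val):
    # Hypothetical stand-in for a make_cell-style helper: build a real
    # CPython cell object holding `val` (CellType is constructible
    # since Python 3.8).
    return CellType(val)


c = make_cell(10)
value = c.cell_contents  # `cell_contents` attribute access (read)
c.cell_contents = 20     # cells are mutable in place (Python 3.7+)
```

Emitting `LOAD_DEREF`/`STORE_DEREF` lets the generated bytecode do these reads and writes directly, instead of going through attribute access like the above.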

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

[ghstack-poisoned]

pytorch-bot bot commented Nov 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140153

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit ec4e620 with merge base f98c601 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


StrongerXi commented Nov 8, 2024

Tests are failing; here's what I've found so far:

  1. Python 3.9 and 3.10 fail because the frame object we pass to Dynamo's eval frame callback doesn't have `f_func: FunctionType` (this patch uses it to retrieve the cells captured by the root frame).
  2. This patch exposed a bug in Dynamo's handling of the `out=` keyword for torch operators like `sort` -- the Python semantics are in-place mutation of the underlying tensor object, but in Dynamo we create a new `TensorVariable` (through `wrap_fx_proxy`) and try to replace instances of the old `TensorVariable` in `symbolic_locals` with the new one. The replacement never accounted for variables that are not in `symbolic_locals` but are reachable from it (e.g., within a `TupleVariable`).
  3. A `torch._dynamo.exc.Unsupported: reconstruct: NewCellVariable()` failure that's 3.11-specific. I have some ideas and will look more. This patch exposed a small bug in `OutputGraph` codegen order -- in general, `codegen_save_tempvars` should run first, to allocate and cache sources for newly created objects.
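The aliasing problem in (2) can be illustrated in plain Python, with no torch dependency: rebinding a name does not update other containers that still reference the old object, whereas in-place mutation is visible through every alias.

```python
x = [1, 2, 3]
t = (x,)               # the same list is also reachable from a tuple

x = [4, 5, 6]          # rebinding the name: `t` still holds the old object
rebound_view = t[0]    # -> still [1, 2, 3]

y = [1, 2, 3]
t2 = (y,)
y.sort(reverse=True)   # in-place mutation (the real `out=` semantics):
mutated_view = t2[0]   # visible through every alias
```

Replacing the old `TensorVariable` only in `symbolic_locals` corresponds to the rebinding case above, which is why values reachable only through containers were missed.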

@StrongerXi
Contributor Author

I'll try to create separate issues for (2) and (3) above, and fix them. (2) feels a little annoying.

[ghstack-poisoned]
[ghstack-poisoned]
@StrongerXi StrongerXi added the topic: not user facing topic category label Nov 11, 2024
[ghstack-poisoned]

StrongerXi commented Nov 11, 2024

> Plumb `closure: Tuple[CellType]` from different versions of CPython all the way to `InstructionTranslator`.

Moved to #140436 as a somewhat orthogonal change.
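For context, the closure tuple being plumbed here can be inspected on any CPython function object (a plain-Python sketch, not Dynamo code):

```python
from types import CellType


def outer():
    x = 41

    def inner():
        return x + 1

    return inner


fn = outer()
# A function's captured cells live in __closure__ as a tuple of CellType,
# one cell per free variable of the function.
closure = fn.__closure__
```

Each cell's current value is readable via `cell_contents`, which is how pre-existing cells can be recovered without frame-level support like `f_func`.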

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@StrongerXi
Contributor Author

Rebase.

[ghstack-poisoned]
[ghstack-poisoned]
@StrongerXi StrongerXi changed the title [dynamo] Represent all cells as CellVariable [dynamo] Represent all cells as NewCellVariable Nov 13, 2024
Comment on lines +236 to +237
@dataclasses.dataclass(frozen=True)
class AutoDerefLocalSource(ChainedSource):
Contributor Author


This is mainly to make root-frame cells (the ones Python generates `LOAD_DEREF` and `STORE_DEREF` for) play well with guards. I think we can remove it if we make our `f_locals` (or whatever it becomes) contain cell objects without dereferencing them, which might be doable as part of #140063 (comment)? @jansel @williamwen42
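The auto-dereferencing behavior in question can be seen in plain CPython: `locals()` exposes the value of a captured variable, not the cell object backing it (minimal sketch):

```python
def outer():
    x = 1

    def inner():
        return x

    # `locals()` exposes the *dereferenced* value of `x` (an int),
    # not the cell object that actually stores it.
    deref_value = locals()["x"]
    return deref_value, inner


deref_value, inner = outer()
```

Since guard checks run against this dereferenced view, a source for such locals has to account for the implicit dereference.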

Comment on lines +889 to +896
elif istype(source, AutoDerefLocalSource):
    # Guard checks run on f_locals, in which the python-level
    # auto-dereferenced cell objects are also dereferenced (e.g., rather
    # than `f_locals` being `{ 'cell' : <cell object of int> }`, it'll
    # be `{ 'cell' : <int> }`). So the guard manager is the same as the
    # base guard manager.
    assert isinstance(base_guard_manager, GuardManager)  # tame mypy
    out = base_guard_manager
Comment on lines +974 to +992
def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", L_y_: "f32[s0, s1]", s2: "Sym(s2)", L_x_: "f32[s2, s0]"):
    l_y_ = L_y_
    l_x_ = L_x_
    wrap_body_1 = self.wrap_body_1
-   wrap = torch.ops.higher_order.wrap(wrap_body_1, s0, s1, l_x_, s2, l_y_); wrap_body_1 = s0 = s1 = l_x_ = s2 = l_y_ = None
-   getitem: "f32[s0, s2]" = wrap[0]; wrap = None
+   wrap = torch.ops.higher_order.wrap(wrap_body_1, s2, s0, l_x_, s1, l_y_); wrap_body_1 = s2 = s0 = l_x_ = s1 = l_y_ = None
+   getitem: "f32[s2, s1]" = wrap[0]; wrap = None
    return (getitem,)

class wrap_body_1(torch.nn.Module):
-   def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
+   def forward(self, s2: "Sym(s2)", s0: "Sym(s0)", l_x_: "f32[s2, s0]", s1: "Sym(s1)", l_y_: "f32[s0, s1]"):
        wrap_body_0 = self.wrap_body_0
-       wrap = torch.ops.higher_order.wrap(wrap_body_0, s0, s1, l_x_, s2, l_y_); wrap_body_0 = s0 = s1 = l_x_ = s2 = l_y_ = None
-       getitem: "f32[s0, s2]" = wrap[0]; wrap = None
+       wrap = torch.ops.higher_order.wrap(wrap_body_0, s2, s0, l_x_, s1, l_y_); wrap_body_0 = s2 = s0 = l_x_ = s1 = l_y_ = None
+       getitem: "f32[s2, s1]" = wrap[0]; wrap = None
        return (getitem,)

class wrap_body_0(torch.nn.Module):
-   def forward(self, s0: "Sym(s0)", s1: "Sym(s1)", l_x_: "f32[s0, s1]", s2: "Sym(s2)", l_y_: "f32[s1, s2]"):
-       matmul: "f32[s0, s2]" = l_x_ @ l_y_; l_x_ = l_y_ = None
+   def forward(self, s2: "Sym(s2)", s0: "Sym(s0)", l_x_: "f32[s2, s0]", s1: "Sym(s1)", l_y_: "f32[s0, s1]"):
+       matmul: "f32[s2, s1]" = l_x_ @ l_y_; l_x_ = l_y_ = None
Contributor Author


@ydwu4 does this change matter?

Contributor


The order doesn't matter for this hop as long as it's deterministic.

[ghstack-poisoned]
[ghstack-poisoned]
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
…140154)

Now that all cells are modeled as `NewCellVariable` in Dynamo, we no
longer need to put cell variables into the special `closure_cells`;
instead, we just merge `closure_cells` into `symbolic_locals`.

This allows us to merge and remove some code paths, notably making
`LOAD_CLOSURE` the same as `LOAD_FAST`, and making `LOAD_DEREF` and
`STORE_DEREF` behave the same for the inlining and regular
`InstructionTranslator`.

Pull Request resolved: pytorch#140154
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
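The cell-related opcodes discussed above can be observed directly with the standard `dis` module (a version-portable sketch in plain Python):

```python
import dis


def outer():
    x = 1

    def inner():
        return x

    return inner


inner = outer()
# `inner` reads `x` through its closure cell, so its bytecode uses
# LOAD_DEREF rather than LOAD_FAST.
inner_ops = {instr.opname for instr in dis.get_instructions(inner)}
```

The code-object metadata tells the same story: `x` is a cell variable of `outer` and a free variable of `inner`.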
smalltalkman pushed a commit to smalltalkman/pytorch that referenced this pull request Nov 15, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
Pull Request resolved: pytorch#140153
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
pobin6 pushed a commit to pobin6/pytorch that referenced this pull request Dec 5, 2024
@github-actions github-actions bot deleted the gh/StrongerXi/28/head branch December 19, 2024 02:09
Development

Successfully merging this pull request may close these issues.

Investigate making prune_dead_locals more aggressive by tracing liveness from variables in SideEffects

5 participants