-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[dynamo] Use ExecutionRecorder only in root frame InstructionTranslator
#140152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140152
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit 6681de8 with merge base f98c601 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Rebase |
|
This is mainly to separate out and test this behavior change, which is needed to enable patches down the line. |
…content id (#140436) In `match_nested_cell`, Dynamo tried to identify pre-existing captured cells by `(cell_name, id(cell_contents))`. This works in most cases, but as the test added in this patch shows, it's not a complete solution. This patch 1. changes `match_nested_cell` to `lookup_variable_for_captured_cell`, and does the lookup based on id of cell objects, not their contents. This requires plumbing a tuple of captured cell objects from different CPython versions all the way to `InstructionTranslator.__init__`, where we store a mapping from the ids of these cell objects, and use it later in `UserFunctionVariable.bind_args` to look for these unboxed cells. 2. builds off (1) -- rather than using a `VariableTracker` that represents the content of the unboxed cells, use `ClosureVariable`, which enables codegen in case these cells escape as closure of a `NestedUserFunctionVariable`. The patch adds a regression test for each of the scenarios above: 1. `test_write_to_cells_with_name_shadowing` where Dynamo mistakenly thought the program is writing to a cell captured by root frame (which it doesn't support atm), which resulted in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF unimplemented("write to __closure__ while inlining") File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented raise Unsupported(msg, case_name=case_name) torch._dynamo.exc.Unsupported: write to __closure__ while inlining ``` 2. `test_existing_func_that_creates_capturing_nested_func` where Dynamo ended up trying to codegen a `NestedUserFunctionVariable` that captures a cell which was also captured by the root frame, so it was unboxed and ends up emitting `LOAD_DEREF` rather than `LOAD_FAST/LOAD_CLOSURE` during codegen, resulting in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn func = FunctionType(code, f_globals, name, defaults, closure) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: arg 5 (closure) expected cell, found int ``` Pull Request resolved: #140436 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: #140330, #140152
pytorch#140435) Registed tensor hooks contain `NestedUserFunctionVariable` which might capture a `NewCellVariable` for cell objects created during Dynamo tracing, so we must make sure it doesn't get pruned away. Pull Request resolved: pytorch#140435 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects: 1. For cells captured and created by the root frame, represent them as their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and `STORE_DEREF` update directly, without going through `SideEffects`. 2. `ClosureVariable`: this is created when cells from (1) are captured by a newly created function Dynamo is about to inline. It's a handle with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1), to make `root_tx.symbolic_locals` up-to-date. 3. For cells that are captured by both the root frame and some pre-existing function Dynamo is about to inline, represent those cells as contents, and do not allow writes to them. Note that (2) and (3) are mainly to conform with (1) -- to make sure Dynamo has a consistent modeling of cells for the same cell objects. In this patch, we represent all of these cells as `NewCellVariable`. The main new code paths introduced are: - using `NewCellVariable` to model cell objects created by the root frame (the cells are passed in as input to `InstructionTranslator`), this is what allows us to get rid of all 3 legacy paths above. - adding a new `AutoDerefLocalSource` to deal with the python-code level (guards) and bytecode level (codegen) auto-dereferencing behavior, when accessing pre-existing python cells. This also involves a tiny update to guard manager generation. - plumbing some extra info into `LocalSource` and `CellVariable` so that we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`), which is important for readability, performance, and some assumptions `bytecode_transformation.py` makes. As a result, this patch removes a lot of the now-dead code paths and TODOs. Notably, it significantly simplified the `prune_dead_locals` function, which was duplicating a lot of the logic from `prune_dead_object_new`; this conveniently closes pytorch#137123. Pull Request resolved: pytorch#140153 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
…140154) Now that all cells are modeled as `NewCellVariable` in Dynamo, we no longer need to put cell variables into this special `closure_cells`, rather we just merge `closure_cells` with `symbolic_locals`. This allows us to merge and remove some code paths, notably make `LOAD_CLOSURE` the same as `LOAD_FAST`, and `LOAD_DEREF` & `STORE_DEREF` the same for inlining or regular `InstructionTranslator`. Pull Request resolved: pytorch#140154 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
…ytorch#140155) This is no longer needed now that we've replaced `ClosureVariable` with `NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as `LOAD_FAST`. Pull Request resolved: pytorch#140155 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
…lator` (pytorch#140152) As title. This is effectively what ended up happening anyways since we always overwrite the record with the current frame's while propagating the exception upward in `InstructionTranslatorBase.run`. Pull Request resolved: pytorch#140152 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: pytorch#140330
…content id (pytorch#140436) In `match_nested_cell`, Dynamo tried to identify pre-existing captured cells by `(cell_name, id(cell_contents))`. This works in most cases, but as the test added in this patch shows, it's not a complete solution. This patch 1. changes `match_nested_cell` to `lookup_variable_for_captured_cell`, and does the lookup based on id of cell objects, not their contents. This requires plumbing a tuple of captured cell objects from different CPython versions all the way to `InstructionTranslator.__init__`, where we store a mapping from the ids of these cell objects, and use it later in `UserFunctionVariable.bind_args` to look for these unboxed cells. 2. builds off (1) -- rather than using a `VariableTracker` that represents the content of the unboxed cells, use `ClosureVariable`, which enables codegen in case these cells escape as closure of a `NestedUserFunctionVariable`. The patch adds a regression test for each of the scenarios above: 1. `test_write_to_cells_with_name_shadowing` where Dynamo mistakenly thought the program is writing to a cell captured by root frame (which it doesn't support atm), which resulted in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF unimplemented("write to __closure__ while inlining") File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented raise Unsupported(msg, case_name=case_name) torch._dynamo.exc.Unsupported: write to __closure__ while inlining ``` 2. `test_existing_func_that_creates_capturing_nested_func` where Dynamo ended up trying to codegen a `NestedUserFunctionVariable` that captures a cell which was also captured by the root frame, so it was unboxed and ends up emitting `LOAD_DEREF` rather than `LOAD_FAST/LOAD_CLOSURE` during codegen, resulting in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn func = FunctionType(code, f_globals, name, defaults, closure) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: arg 5 (closure) expected cell, found int ``` Pull Request resolved: pytorch#140436 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152
pytorch#140435) Registed tensor hooks contain `NestedUserFunctionVariable` which might capture a `NewCellVariable` for cell objects created during Dynamo tracing, so we must make sure it doesn't get pruned away. Pull Request resolved: pytorch#140435 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects: 1. For cells captured and created by the root frame, represent them as their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and `STORE_DEREF` update directly, without going through `SideEffects`. 2. `ClosureVariable`: this is created when cells from (1) are captured by a newly created function Dynamo is about to inline. It's a handle with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1), to make `root_tx.symbolic_locals` up-to-date. 3. For cells that are captured by both the root frame and some pre-existing function Dynamo is about to inline, represent those cells as contents, and do not allow writes to them. Note that (2) and (3) are mainly to conform with (1) -- to make sure Dynamo has a consistent modeling of cells for the same cell objects. In this patch, we represent all of these cells as `NewCellVariable`. The main new code paths introduced are: - using `NewCellVariable` to model cell objects created by the root frame (the cells are passed in as input to `InstructionTranslator`), this is what allows us to get rid of all 3 legacy paths above. - adding a new `AutoDerefLocalSource` to deal with the python-code level (guards) and bytecode level (codegen) auto-dereferencing behavior, when accessing pre-existing python cells. This also involves a tiny update to guard manager generation. - plumbing some extra info into `LocalSource` and `CellVariable` so that we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`), which is important for readability, performance, and some assumptions `bytecode_transformation.py` makes. As a result, this patch removes a lot of the now-dead code paths and TODOs. Notably, it significantly simplified the `prune_dead_locals` function, which was duplicating a lot of the logic from `prune_dead_object_new`; this conveniently closes pytorch#137123. Pull Request resolved: pytorch#140153 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
…140154) Now that all cells are modeled as `NewCellVariable` in Dynamo, we no longer need to put cell variables into this special `closure_cells`, rather we just merge `closure_cells` with `symbolic_locals`. This allows us to merge and remove some code paths, notably make `LOAD_CLOSURE` the same as `LOAD_FAST`, and `LOAD_DEREF` & `STORE_DEREF` the same for inlining or regular `InstructionTranslator`. Pull Request resolved: pytorch#140154 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
…ytorch#140155) This is no longer needed now that we've replaced `ClosureVariable` with `NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as `LOAD_FAST`. Pull Request resolved: pytorch#140155 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
…lator` (pytorch#140152) As title. This is effectively what ended up happening anyways since we always overwrite the record with the current frame's while propagating the exception upward in `InstructionTranslatorBase.run`. Pull Request resolved: pytorch#140152 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: pytorch#140330
…content id (pytorch#140436) In `match_nested_cell`, Dynamo tried to identify pre-existing captured cells by `(cell_name, id(cell_contents))`. This works in most cases, but as the test added in this patch shows, it's not a complete solution. This patch 1. changes `match_nested_cell` to `lookup_variable_for_captured_cell`, and does the lookup based on id of cell objects, not their contents. This requires plumbing a tuple of captured cell objects from different CPython versions all the way to `InstructionTranslator.__init__`, where we store a mapping from the ids of these cell objects, and use it later in `UserFunctionVariable.bind_args` to look for these unboxed cells. 2. builds off (1) -- rather than using a `VariableTracker` that represents the content of the unboxed cells, use `ClosureVariable`, which enables codegen in case these cells escape as closure of a `NestedUserFunctionVariable`. The patch adds a regression test for each of the scenarios above: 1. `test_write_to_cells_with_name_shadowing` where Dynamo mistakenly thought the program is writing to a cell captured by root frame (which it doesn't support atm), which resulted in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF unimplemented("write to __closure__ while inlining") File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented raise Unsupported(msg, case_name=case_name) torch._dynamo.exc.Unsupported: write to __closure__ while inlining ``` 2. `test_existing_func_that_creates_capturing_nested_func` where Dynamo ended up trying to codegen a `NestedUserFunctionVariable` that captures a cell which was also captured by the root frame, so it was unboxed and ends up emitting `LOAD_DEREF` rather than `LOAD_FAST/LOAD_CLOSURE` during codegen, resulting in ``` File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn func = FunctionType(code, f_globals, name, defaults, closure) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: arg 5 (closure) expected cell, found int ``` Pull Request resolved: pytorch#140436 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152
pytorch#140435) Registed tensor hooks contain `NestedUserFunctionVariable` which might capture a `NewCellVariable` for cell objects created during Dynamo tracing, so we must make sure it doesn't get pruned away. Pull Request resolved: pytorch#140435 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436
In addition to `NewCellVariable`, Dynamo has 3 ways of modeling cell objects: 1. For cells captured and created by the root frame, represent them as their contents in `root_tx.symbolic_locals`, which `LOAD_DEREF` and `STORE_DEREF` update directly, without going through `SideEffects`. 2. `ClosureVariable`: this is created when cells from (1) are captured by a newly created function Dynamo is about to inline. It's a handle with a name that redirects `LOAD_DEREF` and `STORE_DEREF` back (1), to make `root_tx.symbolic_locals` up-to-date. 3. For cells that are captured by both the root frame and some pre-existing function Dynamo is about to inline, represent those cells as contents, and do not allow writes to them. Note that (2) and (3) are mainly to conform with (1) -- to make sure Dynamo has a consistent modeling of cells for the same cell objects. In this patch, we represent all of these cells as `NewCellVariable`. The main new code paths introduced are: - using `NewCellVariable` to model cell objects created by the root frame (the cells are passed in as input to `InstructionTranslator`), this is what allows us to get rid of all 3 legacy paths above. - adding a new `AutoDerefLocalSource` to deal with the python-code level (guards) and bytecode level (codegen) auto-dereferencing behavior, when accessing pre-existing python cells. This also involves a tiny update to guard manager generation. - plumbing some extra info into `LocalSource` and `CellVariable` so that we can still emit `LOAD_DEREF`, `STORE_DEREF`, `LOAD_CLOSURE` (instead of `make_cell`, `cell_contents` attribute access, and `LOAD_FAST`), which is important for readability, performance, and some assumptions `bytecode_transformation.py` makes. As a result, this patch removes a lot of the now-dead code paths and TODOs. Notably, it significantly simplified the `prune_dead_locals` function, which was duplicating a lot of the logic from `prune_dead_object_new`; this conveniently closes pytorch#137123. Pull Request resolved: pytorch#140153 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435
…140154) Now that all cells are modeled as `NewCellVariable` in Dynamo, we no longer need to put cell variables into this special `closure_cells`, rather we just merge `closure_cells` with `symbolic_locals`. This allows us to merge and remove some code paths, notably make `LOAD_CLOSURE` the same as `LOAD_FAST`, and `LOAD_DEREF` & `STORE_DEREF` the same for inlining or regular `InstructionTranslator`. Pull Request resolved: pytorch#140154 Approved by: https://github.com/jansel ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153
…ytorch#140155) This is no longer needed now that we've replaced `ClosureVariable` with `NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as `LOAD_FAST`. Pull Request resolved: pytorch#140155 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
…ytorch#140155) This is no longer needed now that we've replaced `ClosureVariable` with `NewCellVariable`, i.e., Dynamo now treats `LOAD_CLOSURE` the same as `LOAD_FAST`. Pull Request resolved: pytorch#140155 Approved by: https://github.com/jansel, https://github.com/williamwen42 ghstack dependencies: pytorch#140330, pytorch#140152, pytorch#140436, pytorch#140435, pytorch#140153, pytorch#140154
Stack from ghstack (oldest at bottom):
name_stackcode paths insymbolic_convert.py#140155closure_cellsand merge/remove code paths #140154NewCellVariable#140153prune_dead_object_new#140435ExecutionRecorderonly in root frameInstructionTranslator#140152DynamoFrameTypetype above Python frame object #140330As title. This is effectively what ended up happening anyways since we
always overwrite the record with the current frame's while propagating
the exception upward in
InstructionTranslatorBase.run.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames