[JIT] add support for breaks and continues #21692
zdevito
left a comment
I didn't finish the break_continue_transform.cpp but I won't be free for a bit so here is what I have so far.
  for (auto& v : true_vars->definedVariables()) {
-   if (false_vars->findInAnyFrame(v)) {
+   if (false_vars->findInAnyFrame(v) || false_escape) {
I think this is inserting extra uninitialized nodes/loop carried nodes for variables used in blocks edited by the break/continue pass that are not otherwise used outside the block.
Yes, but we run DCE, and it is generally the case when emitting if outputs that they may not be used otherwise. I don't think it can be emitting extra loop-carried nodes, because in this case the variable is not defined outside of the loop.
LoopTransformer(std::shared_ptr<Graph> graph_, Transform transform_)
    : graph(std::move(graph_)) {
  WithInsertPoint guard(graph->block()->nodes().front());
  true_val = graph->insertConstant(true);
Please use a trailing _ for class members.
zdevito
left a comment
I am having trouble following the logic in the pass. Part of the problem is that it isn't following a consistent representation for what a prim::Loop is. This pass starts with the loop having a separate guard block and ends without one, but it has also modified the block. I'd expect instead that it works on the Loops pre-lowering and lets SSA conversion handle the lowering.
torch/csrc/jit/ir_views.h
Outdated
I think this was a mistake; it's accessing maxTripCount.
I'd suggest moving the cleanup/refactoring stuff into a separate PR.
1cda8e7 to 73cea49
Failures are unrelated...
ZolotukhinM
left a comment
That's a really big PR :) I tried to walk through it and I think I eventually understood most of it; however, I really think we need to try to simplify it. I'm quite terrified to think of what would happen if there is a bug in this code and you are not around to debug it.
Some more specific suggestions/remarks from me are inline.
InlineLoopCondition(graph);
EraseLoadsStores erase_loads_stores;
erase_loads_stores.run(graph);
TransformExits(graph);
You have probably discussed this with @zdevito, but are we sure that we want to introduce BlockExit and LoopExit as primary IR nodes? It looks like they are just used as metadata to pass some info between these passes. Could we just do the bookkeeping in some map and pass that instead? The IR transformations needed to maintain these nodes are responsible for the bulk of the complexity in this PR, in my opinion.
NB: Load and Store are in a somewhat similar position, but I think they are fine because 1) they have much clearer semantics, and 2) they don't need such sophisticated IR manipulations.
What book keeping would you like exactly? I think passing stateful maps around will be more complicated than something that is explicit in the IR & human readable from one pass to the next.
The bookkeeping I'm referring to is basically the arguments of BlockExit/LoopExit that you encode in the IR. The problem with the current approach is that even though it's supposed to be human readable, it really is not. I will have a hard time understanding what a specific BlockExit means if I see it in the IR. Does it depend on the context? What do the arguments mean? What does it actually do (is it a jump?)?
Having stateful maps is generally bad, but if this is just between a couple of passes that are already dependent on each other and always run together, that might not be the worst.
BlockExit jumps to the most recent block; LoopExit jumps to the most recent loop block. I'm happy to consider other options that would be clearer. Maybe there are other names you think would be better?
Can we just set the exit behavior as an attribute on the prim::Return node that already exists at the end of every block? It seems redundant to have both the outputs of the block and a separate node temporarily in the block to express those outputs. The attribute can just be a string specifying one of attr::exit_level = {"block", "loop", "function"}.
ATM, there needs to be a distinction between the exits that are added in convert_to_ssa.cpp and the existing prim::Return for a couple reasons.
- We emit ifs with block outputs already added in compiler.cpp because of the 'and' and 'or' operators. After convert_to_ssa and before exit_transforms is run, it would look like:
    %z : bool = prim::If(%5)
      block0():
        %z.1 : bool = aten::__is__(%2, %20)
         = prim::Return[exit_level="block"]() # empty, added from convert_load_to_stores
         = prim::Return[exit_level="block"](%z.1) # added from compiler
      block1():
        %z.2 : bool = aten::__is__(%3, %4)
         = prim::Return[exit_level="block"]() # empty, added from convert_load_to_stores
         = prim::Return[exit_level="block"](%z.2) # added from compiler
I'm not sure exactly how to handle this case because you have a block where two exits have a different number of returns, targeting the same block. This breaks invariants in exit_transform.
- This is a smaller thing, but as seen above, this would break the invariant in ir.h that prim::Return does not have any nodes preceding it for LoopExits, because they will have nodes that are emitted after. You could special-case LoopExits; maybe it's a reason not to merge it with prim::Return.
Neither of these is insurmountable; it would just add complexity to an already complex diff. IMO these changes would make this PR more complicated as a whole (although you may disagree). I would prefer to unify things in a follow-up.
Why are if statements that are already in SSA form getting transformed at all? Shouldn't the load/store additions just add outputs to the existing if statement outputs? It also seems weird that there can be extra prim::Return nodes. There should only be a single prim::Return (return_node()) for each block. Nothing should be adding a prim::Return node. The second concern seems related to this point, because we shouldn't be adding prim::Return nodes, just editing the existing prim::Return that serves as an output.
My concern with "make the PR more complicated as a whole" is that it would still make the code complicated even if it were split across two PRs, and we should be looking for ways to handle all cases consistently so we do not incur this complexity. I think the root issue here is that if statements that are already in SSA form and those that still need to be converted to SSA are not merged together in the SSA conversion pass.
I don't think it's currently possible to unify LoopExit with the prim::Return node and have one prim::Return per block, because you need both a prim::Return for the block outputs and, separately, a Return for the LoopExit.
E.g.
    if i == 1:
        break
        a = 1
    else:
        a = None
We unify a = 1 and a = None to Optional[int]. Erasing the block output in the true branch would erase the original type captured, and set a to None in the resulting scope.
Also, things get a little messy in the transform pass if block outputs can either exit to block or loop. In an if statement like above, where the true block exits to loop and the false block exits to block, when & how do you align the values that are escaping to loop scope and those that are escaping to block scope?
You'd probably first transform LoopExits, then transform BlockExits, but then you're in a state after transforming LoopExits where the BlockExits contain both the loop exit values and the block exit values, without clearly distinguishing which is which.
eellison
left a comment
Thank you for your thorough comments!! I will update the PR to address them. I'm not sure about the analysis vs. transformation comment, so I responded to you there.
    destroyNodeAfterExit(*iter);
  }

ExitPair transformExits(Block* block) {
I'm not sure that restructuring it from taking a Block to taking a Node will be all that helpful. When we restructure returns for Graphs & Closures, we have to operate on the block because a Graph is not a node.
For the second part (1) figuring out if the node can break/continue, 2) transforming the rest of the block correspondingly), it seems like we'd just be duplicating work without all that much benefit. The current pass already inserts the minimal number of outputs as it is computed, so calculating ahead of time would not really gain us much.
That is: we don't add any outputs if the status is WONT_EXIT, and we add exactly the outputs we need otherwise (no extra boolean output if it is WILL).
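To make the three-way status concrete, here is a rough Python sketch of how the exit statuses of the two branches of an if node might be combined, and which extra outputs each status would call for. The names merge_if_status, extra_outputs, and has_exited are hypothetical, and the merge rule (both WILL gives WILL, both WONT gives WONT, anything else gives MIGHT) is an assumption drawn from this discussion, not the C++ code in the PR.

```python
from enum import Enum


class ExitStatus(Enum):
    # Mirrors the WILL / MIGHT / WONT enum discussed in this review.
    WILL = "will"
    MIGHT = "might"
    WONT = "wont"


def merge_if_status(true_status, false_status):
    """Combine the exit statuses of the two branches of an if node.

    Assumed rule: the if node definitely exits only when both branches do,
    definitely does not exit only when neither does, and otherwise might.
    """
    if true_status is ExitStatus.WILL and false_status is ExitStatus.WILL:
        return ExitStatus.WILL
    if true_status is ExitStatus.WONT and false_status is ExitStatus.WONT:
        return ExitStatus.WONT
    return ExitStatus.MIGHT


def extra_outputs(status, exit_values):
    """Return the names of the outputs an if node would need for a status.

    WONT: nothing is added.
    WILL: only the exit values; no boolean is needed because the exit is certain.
    MIGHT: a runtime boolean (hypothetically named has_exited) plus the exit values.
    """
    if status is ExitStatus.WONT:
        return []
    if status is ExitStatus.WILL:
        return list(exit_values)
    return ["has_exited"] + list(exit_values)
```

For example, an if whose true branch breaks and whose false branch falls through would merge to MIGHT, and only that case pays for the extra boolean output.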
Krovatkin
left a comment
I was wondering if you could add a paragraph or two to our documentation that gives a higher-level explanation of how all these transformations work together?
auto pre_header = n->blocks().at(1);
auto header_block = n->addBlock();
header_block->cloneFrom(pre_header, [](Value* v) { return v; });
moveBlockBeforeNode(n, header_block);
nitpick: maybe, introduce a function cloneBlockBeforeNode which could rely on moveBlockBeforeNode
} // namespace

enum ExitStatus { WILL, MIGHT, WONT };
nitpick: yes, no, maybe might be a more conventional choice
// has been hit or not, and conditionalize further execution.
// First we remove block exits of if nodes, then we replace Loop Block exits
// with LoopExits. Then we remove LoopExits.
// block0(%i.2 : int, %i.12 : int):
Agreed. Michael and I tried to follow this example and had to give up. Providing descriptive names, or at least the definitions of all the values used, would be very helpful. Also, showing the same transformation in pseudo-Python might be helpful as well.
My comments are mostly in line with @ZolotukhinM's, so I'll let you address those first.
ZolotukhinM
left a comment
The current implementation looks complicated, but often that is because there are many corner cases it needs to cover. These cases are not always obvious, so some sort of document highlighting them would be helpful to accompany this change (you might also want to describe what alternative approaches have been considered and why they were rejected). I suggest adding such a description and following up on the existing comments; after that the PR would probably be easier to review.
…ust emit LoopContinuations
Okay I added two commits. The most-recent commit removes block & loop exits as a concept, so that the SSA pass just emits block outputs. The transformer pass now just removes
ZolotukhinM
left a comment
Thanks for addressing the remarks, I think we're getting close!
I added some more remarks, mostly about adding comments to the code. If you don't see a way to improve the existing comments, feel free to ignore my remarks, but try to approach them as if you are reading the code for the first time without any context. With that mindset you would probably see what I meant.
There are also a couple of unaddressed old remarks; please follow up on them before landing too.
enum ExitStatus { WILL, MIGHT, WONT };

// hasExited() indicates whether or not an exit has been hit.
// if hasExited() == true_val_ then we have exited, if == false_val_ we have
A couple of minor remarks about this comment:
- Imagine a user finds this comment and does not know about the context (i.e. that this struct is used in SSA construction phase). I think it would be helpful to mention what exits we're examining here and what we'd like to do with them.
- The comment mentions what happens if hasExited() is true_val or false_val, but it doesn't tell how these two values should be defined and what happens if hasExited is not a constant (and if it is allowed).
auto status = getExitStatus(exit_pair);
// once we transform returns, this will no longer be true
TORCH_INTERNAL_ASSERT(status == WILL);
Why doesn't it fail for usual loops with no break/continue? I thought the exit status would be WONT for them - am I missing something?
commented in code
  default:
    break;
}
ExitStatus status = getExitStatus(exit_pair);
Some comment about what we're going to do here might be helpful. Basically, at this point we're done with the recursion and only looking at the current block - stating the rules of transformations we plan to do would help to understand the code.
// Recurses on the if node and returns its return status
// If status != WONT_RETURN, sets the block_return_val and has returned val
// of its parent block before exit
It would be helpful to add comments about the main goal of the function - namely about how it deals with ifs where one branch exits and the other doesn't (and all other combinations).
facebook-github-bot
left a comment
@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Add support for breaks and continues in the JIT. We do this with a graph transform pre-SSA.
A graph of the form
```
def test():
    i = 0  # initial value added so the example runs
    while i < 5:
        if i == 3:
            break
        i += 1
        print(i)
```
has the body of the loop transformed to
```
if i == 3:
    did_break = True
else:
    did_break = False
if did_break:
    loop_exit = True
else:
    i += 1
    print(i)
    # exit once the original loop condition no longer holds
    loop_exit = not (i < 5)
```
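As a rough check of the idea, the following plain-Python sketch runs a loop with break next to a hand-transformed version that uses a flag instead, and does the same for continue. It only illustrates the rewriting at the source level; it is not the JIT pass itself. The function names, the keep_going flag, and the use of a list in place of print are assumptions made for the example, and the continue variant is an assumed analogue of the break case rather than something spelled out in this summary.

```python
def loop_with_break(limit=5, stop_at=3):
    """Reference version: an ordinary while loop with break."""
    i = 0
    seen = []
    while i < limit:
        if i == stop_at:
            break
        i += 1
        seen.append(i)  # stands in for print(i)
    return i, seen


def loop_without_break(limit=5, stop_at=3):
    """Hand-transformed version: break becomes a flag guarding the rest of the body."""
    i = 0
    seen = []
    keep_going = i < limit
    while keep_going:
        did_break = i == stop_at
        if did_break:
            keep_going = False
        else:
            i += 1
            seen.append(i)
            keep_going = i < limit  # re-evaluate the original loop condition
    return i, seen


def loop_with_continue(limit=5, skip=3):
    """Reference version: an ordinary while loop with continue."""
    i = 0
    seen = []
    while i < limit:
        i += 1
        if i == skip:
            continue
        seen.append(i)
    return seen


def loop_without_continue(limit=5, skip=3):
    """Hand-transformed version: continue skips the rest of the body via a flag."""
    i = 0
    seen = []
    keep_going = i < limit
    while keep_going:
        i += 1
        did_continue = i == skip
        if not did_continue:
            seen.append(i)
        keep_going = i < limit  # the loop condition is still checked after a continue
    return seen
```

Calling each pair with the same arguments should give identical results, which is the property the transform needs to preserve.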
I am going to add more tests but I think it is ready for review now.
Pull Request resolved: pytorch/pytorch#21692
Differential Revision: D16215807
Pulled By: eellison
fbshipit-source-id: 365102f42de4861d9323caaeb39a96de7619a667