Skip to content

Conversation

@eellison
Copy link
Contributor

Add support for breaks and continues in the jit. We do with a Graph transform pre-SSA.

A graph of the form

def test():
    while i < 5:
        if i == 3:
            break
        i += 1
        print(i)

has the body of the loop transformed to

if i == 3:
    did_break = True
else:
    did_break = False
if did_break:
    loop_exit = True
else:
    i += 1
    print(i)
    loop_exit = i < 5

I am going to add more tests but I think it is ready for review now.

@eellison eellison requested review from suo and zdevito June 12, 2019 18:14
@pytorchbot pytorchbot added caffe2 oncall: jit Add this issue/PR to JIT oncall triage queue module: build Build system issues module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries labels Jun 12, 2019
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't finish the break_continue_transform.cpp but I won't be free for a bit so here is what I have so far.


for (auto& v : true_vars->definedVariables()) {
if (false_vars->findInAnyFrame(v)) {
if (false_vars->findInAnyFrame(v) || false_escape) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is inserting extra uninitialized nodes/loop carried nodes for variables used in blocks edited by the break/continue pass that are not otherwise used outside the block.

Copy link
Contributor Author

@eellison eellison Jun 13, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but we run DCE, and this is the case generally with emitting if outputs that they may not be used otherwise. I don't think it can be emitting extra loop carried nodes because in this case the variable is not defined outside of the loop

LoopTransformer(std::shared_ptr<Graph> graph_, Transform transform_)
: graph(std::move(graph_)) {
WithInsertPoint guard(graph->block()->nodes().front());
true_val = graph->insertConstant(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use _ for class members.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am having trouble following the logic in the pass. Part of the problem is that it isn't following a consistent representation for what a prim::Loop is. This pass starts with the loop having a separate guard block, and then ends without one, but also having modified the block. I'd expect instead that it works on the Loops pre-lowering and let ssa conversion handle the lowering.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this was a mistake, it's accessing maxTripCount

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest moving the cleanup/refactoring stuff into a separate PR.

@eellison eellison force-pushed the transform_breaks_continues branch from 1cda8e7 to 73cea49 Compare June 14, 2019 22:06
@eellison eellison requested review from zdevito June 17, 2019 15:19
@suo suo removed their request for review June 17, 2019 20:16
@eellison eellison requested review from Krovatkin and ZolotukhinM and removed request for Krovatkin June 19, 2019 21:18
@eellison
Copy link
Contributor Author

failures are unrelated...

Copy link

@ZolotukhinM ZolotukhinM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a really big PR :) I tried to walk through it and I think eventually I understood most of it, however, I really think we need to try to simplify it. I'm quite terrified to think of what would happen if there is a bug in this code and you are not around to debug it.

Some more specific suggestions/remarks from me are inline.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest moving the cleanup/refactoring stuff into a separate PR.

InlineLoopCondition(graph);
EraseLoadsStores erase_loads_stores;
erase_loads_stores.run(graph);
TransformExits(graph);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably have discussed this with @zdevito, but are we sure that we want to introduce BlockExit and LoopExit as a primary IR nodes? It looks like they are just used as a metadata to pass some info between these passes - could we just perform a book-keeping in some map and pass it instead? IR transformations to keep these nodes are responsibly for a bulk part of complexity in this PR in my opinion.

NB: Load and Store are in somewhat in similar position, but I think they are fine because 1) they have much clearer semantics, 2) they don't need such sophisticated IR manipulations.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What book keeping would you like exactly? I think passing stateful maps around will be more complicated than something that is explicit in the IR & human readable from one pass to the next.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The book-keeping I'm referring to is basically arguments of BlockExit/LoopExit that you encode in IR. The problem with the current approach is that even though it's supposed to be human readable, it really is not. I will have hard time understanding what a specific BlockExit means if I see it in IR. Does it depend on the context? What do arguments mean? What does it actually do (is it a jump?)?

Having stateful maps is generally bad, but if this is just between a couple of passes that are already dependent on each other and always run together, that might be not the worst.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BlockExit jumps to the most recent block, LoopExit jumps to the most recent loop block. I'm happy to consider other options that would be more clear. Maybe other there other names you think would be better ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we just set the exit behavior as an attribute on the prim::Return node that already exists at the end of every block. It seems redundant to have both outputs of the block, and separate node temporarily the block to express those outputs. The attribute can just be a string specifying one of attr::exit_level = {"block", "loop", "function"}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ATM, there needs to be a distinction between the exits that are added in convert_to_ssa.cpp and the existing prim::Return for a couple reasons.

  1. We emit ifs with block outputs already added in compiler.cpp because of and and or. After convert_to_ssa and before exit_transforms is run it would look like.
  %z : bool = prim::If(%5)
    block0():
      %z.1 : bool = aten::__is__(%2, %20)
      = prim::Return[exit_level="block"]() # empty, added from convert_load_to_stores
      = prim::Return[exit_level="block"](%z.1) # added from compiler
    block1():
      %z.2 : bool = aten::__is__(%3, %4)
      = prim::Return[exit_level="block"]() # empty, added from convert_load_to_stores
      = prim::Return[exit_level="block"](%z.2) # added from compiler

I'm not sure exactly how to handle this case because you have a block where two exits have a different number of returns, targeting the same block. This breaks invariants in exit_transform.

  1. This is a smaller thing, but as seen above, this would break invariants invariants in ir.h that prim::Return does not have any nodes preceding it for LoopExits because will have nodes that are emitted after. You could special case LoopExits, maybe its a reason not to merge it with prim::Return.

Neither of these are insurmountable it would just be added complexity to an already complex diff. IMO these changes would make this PR more complicated as a whole, (although you may disagree). I would prefer to unify things in a follow up.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are already ssa form if statements getting transformed at all? Shouldn't the load/store additions just add outputs to the existing if statement outputs? It also seems weird that there can be extra prim::Return nodes. There should only be a single prim return (return_node()) for each block. Nothing should be adding a prim::Return node. The second concern seems related to this point because we shouldn't be adding prim::Return nodes, just editing the existing prim::Return that serves as an output.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My concern with "make the PR more complicated as a whole" is that it would still make the code complicated even if were across two PRs, and we should be looking for ways to handle all cases consistently so we do not incur this complexity. I think the root issue here is that already-ssa and needs to be converted to ssa if statements are not merged together in the ssa conversion pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's currently possible to unify the LoopExit to prim::Return node and to have one prim::Return per block because you need to have both a prim::Return for the block outputs and separately a Return for the LoopExit.
E.g.

if i == 1:
    break 
    a = 1
else:
    a = None

We unify a = 1 and a = None to Optional[int]. Erasing the block output in the true branch would erase the original type captured, and set a to None in the resulting scope.

Also, things get a little messy in the transform pass if block outputs can either exit to block or loop. In an if statement like above, where the true block exits to loop and the false block exits to block, when & how do you align the values that are escaping to loop scope and those that are escaping to block scope?

You'd probably first transform LoopExits, then transform BlockExits, but then you're in state after transforming LoopExits where the BlockExits contain both the loop exit values and block exit values, without clearly distinguishing which is which.

Copy link
Contributor Author

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your thorough comments!! will update the PR to address them. Not sure about the analysis / vs transformation comment, so i responded to you there.

InlineLoopCondition(graph);
EraseLoadsStores erase_loads_stores;
erase_loads_stores.run(graph);
TransformExits(graph);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What book keeping would you like exactly? I think passing stateful maps around will be more complicated than something that is explicit in the IR & human readable from one pass to the next.

destroyNodeAfterExit(*iter);
}

ExitPair transformExits(Block* block) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that restructuring it form a taking a Block to Node will be all that helpful. When we restructure returns for Graphs & Closures, we have to operate on the block b/c a Graph is not a node.

For the second part: 1) figuring out if the node can break/continue, 2) transforming the rest of the block correspondingly it seems like we'd just be duplicating work without all that much benefit. The current pass already inserts the minimal number of outputs as it is computed, so it's not like calculating ahead of time will really gain us much.
That is: we don't add any outputs if it WONT_EXIT, and we add exactly the outputs we need otherwise (no extra boolean output if it is WILL)

Copy link
Contributor

@Krovatkin Krovatkin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wondering if you could add a paragraph or two to our documentation that gives a higher-level explanation of how all these transformations work together?

auto pre_header = n->blocks().at(1);
auto header_block = n->addBlock();
header_block->cloneFrom(pre_header, [](Value* v) { return v; });
moveBlockBeforeNode(n, header_block);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: maybe, introduce a function cloneBlockBeforeNode which could rely on moveBlockBeforeNode


} // namespace

enum ExitStatus { WILL, MIGHT, WONT };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: yes, no, maybe might be a more conventional choice

// has been hit or not, and conditionalize further execution.
// First we remove block exits of if nodes, then we replace Loop Block exits
// with LoopExits. Then we remove LoopExits.
// block0(%i.2 : int, %i.12 : int):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, Michael and I tried to follow this example and we had to give up. Providing descriptive names or at least the definitions for all used values would be very helpful. Also, showing the same transformation in pseudo python might be helpful as well.

@zdevito
Copy link
Contributor

zdevito commented Jun 21, 2019

My comments are mostly in-line with @ZolotukhinM so Ill let you address those first.

Copy link

@ZolotukhinM ZolotukhinM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation looks complicated but often it is because there are many corner cases it needs to cover. These cases are not always obvious, so some sort of document highlighting them would be helpful to accompany this change (also you might want to describe what alternative approaches have been considered and why they've been rejected). I suggest adding such a description and following up on existing comments - probably after that the PR would be easier for review.

@eellison
Copy link
Contributor Author

eellison commented Jul 9, 2019

Okay I added two commits.
The second-most recent addresses the cleanups @ZolotukhinM had in his review.

The most-recent commit removes block & loop exits as a concept, so that the SSA pass just emits block outputs. The transformer pass now just removes LoopContinuations which as a name map much more closely to the equivalent python concept, continue. I talked with @ZolotukhinM about these changes but it's possible some of what we talked about got lost in translation to this diff.

@eellison eellison requested a review from ZolotukhinM July 10, 2019 17:45
Copy link

@ZolotukhinM ZolotukhinM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the remarks, I think we're getting close!

I added some more remarks - mostly about adding comments to the code. If you don't see a way how to improve the existing comments, fell free to ignore my remarks, but try to approach them as if you are reading the code from the first time without any context. With that mindset you would probably see what I meant.

There is also a couple of unaddressed old remarks - please follow up on them before landing too.

enum ExitStatus { WILL, MIGHT, WONT };

// hasExited() indicates whether or not an exit has been hit.
// if hasExited() == true_val_ then we have exited, if == false_val_ we have

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couple of minor remarks about this comment:

  1. Imagine a user finds this comment and does not know about the context (i.e. that this struct is used in SSA construction phase). I think it would be helpful to mention what exits we're examining here and what we'd like to do with them.
  2. The comment mentions what happens if hasExited() is true_val or false_val, but it doesn't tell how these two values should be defined and what happens if hasExited is not a constant (and if it is allowed).


auto status = getExitStatus(exit_pair);
// once we transform returns, this will no longer be true
TORCH_INTERNAL_ASSERT(status == WILL);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why doesn't it fail for usual loops with no break/continue? I thought the exit status would be WONT for them - am I missing something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

commented in code

default:
break;
}
ExitStatus status = getExitStatus(exit_pair);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comment about what we're going to do here might be helpful. Basically, at this point we're done with the recursion and only looking at the current block - stating the rules of transformations we plan to do would help to understand the code.


// Recurses on the if node and returns its return status
// If status != WONT_RETURN, sets the block_return_val and has returned val
// of its parent block before exit

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add comments about the main goal of the function - namely about how it deals with ifs where one branch exits and the other doesn't (and all other combinations).

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@eellison merged this pull request in cf2889a.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Jul 12, 2019
Summary:
Add support for breaks and continues in the jit. We do with a Graph transform pre-SSA.

A graph of the form
```
def test():
    while i < 5:
        if i == 3:
            break
        i += 1
        print(i)
```
has the body of the loop transformed to
```
if i == 3:
    did_break = True
else:
    did_break = False
if did_break:
    loop_exit = True
else:
    i += 1
    print(i)
    loop_exit = i < 5
```

I am going to add more tests but I think it is ready for review now.
Pull Request resolved: pytorch/pytorch#21692

Differential Revision: D16215807

Pulled By: eellison

fbshipit-source-id: 365102f42de4861d9323caaeb39a96de7619a667
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

caffe2 Merged module: build Build system issues module: internals Related to internal abstractions in c10 and ATen module: pybind Related to our Python bindings / interactions with other Python libraries oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants