-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] fix dce over loops #22632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jit] fix dce over loops #22632
Conversation
We need re-run the marking pass over loop sub-blocks until they converge to a fixed point. Thanks to @Chillee for catching this bug!
Chillee
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise looks pretty straightforward and correct to me.
| auto node = *it; | ||
| for (auto subBlock : node->blocks()) { | ||
| mark(subBlock); | ||
| if (node->kind() == prim::Loop) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to check for c10::onnx::loop here like we do in the other place?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
onnx is purely functional and has no aliasing, so we wouldn't need the special behavior
Krovatkin
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
![]()
| // `b` are dead, even though `b[0] += 1` mutates a live memory location (since | ||
| // `b[0]` is an alias of `a`). | ||
| // | ||
| // We need to mark the loop again with the information that `a` is live, and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// We need to mark the loop again with the information that
ais live, and
we could also add
i.e. `a` is used to compute `tot` in the next iteration
| bool marked = mark(node->blocks().at(0)); | ||
| innerMarked = marked; | ||
| anyMarked |= marked; | ||
| } while (innerMarked); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool marked = false;
// Did we ever mark anything new?
bool anyMarked = false;
do {
marked = mark(node->blocks().at(0));
anyMarked |= marked;
} while (marked);| // return block. We consider all graph outputs to be "used", so just mark | ||
| // this node normally. | ||
| return mark(node); | ||
| mark(node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite like the fact that markReturnNode isn't changed to bool. I get that we only call mark for the top-level (function return) and in this case we aren't in a loop, so it doesn't matter, and if we are in a loop we make live only body outputs. When we'll get early returns we might want to consider revisiting this and make markReturnNode to actually return a bool
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems fine.
[jit] fix dce over loops We need re-run the marking pass over loop sub-blocks until they converge to a fixed point. Thanks to @Chillee for catching this bug! gh-metadata: pytorch pytorch 22632 gh/suo/79/head
Stack from ghstack:
We need re-run the marking pass over loop sub-blocks until they converge
to a fixed point. Thanks to @Chillee for catching this bug!
Differential Revision: D16184469