dont clone symints, dont clobber symint proxies #88230
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88230
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure as of commit 7c0d484.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
albanD
left a comment
Thanks :)
As discussed in the meeting, a further BE follow-up is to delete clone entirely.
Did anyone ever explain why this is better (or what's wrong with the clone approach)? Hopefully not just that "it's easier to automatically create missing symints/proxies than to find out where they should have been created"?
I have a hypothesis for why the cloning approach is wrong. It is only safe to clone when we are setting the sizes/strides on the bottom-most fake tensor; only those tensors are guaranteed to pass through proxy tensor mode and actually get tracked. With functionalization, we also have a FunctionalizeTensorWrapper floating around, and I guess we are accidentally cloning on that (we shouldn't), resulting in untracked proxies. The cloning business seems pretty delicate, and in practice this approach seems to have been more robust; we haven't had any more partitioner disasters on the branch since we put this in to soak.
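To make that concrete, here is a minimal, hypothetical sketch (plain-Python stand-ins, not the real SymNode / proxy mode classes) of how a cloned symint ends up with an untracked proxy when tracking is keyed on object identity:

import weakref

class FakeSymNode:
    """Stand-in for a symbolic int node; not the real torch SymNode."""
    def __init__(self, expr):
        self.expr = expr
    def clone(self):
        # Old behavior: produce a brand-new node carrying the same expression.
        return FakeSymNode(self.expr)

# Proxy tracking keyed on node identity (weakly, so dead nodes drop out).
tracked_proxies = weakref.WeakKeyDictionary()

s = FakeSymNode("s0")
tracked_proxies[s] = "proxy_for_s0"   # recorded when the symint was first traced

c = s.clone()                          # e.g. while setting sizes on a wrapper tensor
assert c.expr == s.expr
assert c not in tracked_proxies        # the clone has no proxy -> "untracked proxies"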
Status update on this?!
This is mostly ready, but there are a few proxy tensor test failures that I didn't have a chance to chug through. I'll try to hit them today. Also needed to rebase / rebuild due to merge conflicts.
Darn - the breakage is that there are cases where we DO want to clobber the proxy, namely for inplace ops on tensors. But there are probably also other cases where an inplace op mutates the size/stride/storage_offset and we need to replace the proxy. Since mutation of symints isn't really a thing, maybe we can avoid the clobbering only in the case where we have SymInts/SymFloats, and properly do the clobbering for tensors?
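A rough sketch of the rule being proposed (hypothetical helper, not the actual proxy tensor code; is_symbolic_scalar stands in for an isinstance check against SymInt/SymFloat):

def set_proxy(proxy_table, key, value, new_proxy, is_symbolic_scalar):
    existing = proxy_table.get(key)
    if existing is not None and is_symbolic_scalar(value):
        # SymInts/SymFloats are never mutated in place, so keep the proxy
        # recorded when the value was first produced instead of clobbering it.
        return existing
    # Tensors can be mutated in place (size/stride/storage_offset may change),
    # so here replacing ("clobbering") the proxy is the behavior we want.
    proxy_table[key] = new_proxy
    return new_proxy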
Sgtm
BTW, it doesn't have to be a full clobber for tensors either. When we |
# TODO(whc)- are the saved-tensors/saved-symints correct here?
# i just made the test pass based on what default partition did
[False, True, True, False, False] + [False] * 5 + [True] * 3,
[False, True, True, False, False] + [False] * 4 + [True] * 4,
Hey @wconstab - I guess I'm not too surprised that this test wobbled, since it seems like this test originally wobbled when we added symint cloning. Just wanted to see if you have any thoughts.
I'm trying to recall...
[False, True, True, False, False]
This first block should correspond to the outputs of f, namely return cat, sb, c, mm2, where sb expands to 2 dims of size, hence True, True for is_sym_node.
[False] * 4 + [True] * 4,
This part I can't remember exactly. In the non-clone case we're saving one less non-sym-node for backward, and one more sym-node. I think what I would do is check what these forward/backward graphs look like and what they were like before this PR.
Hmm ok. The new graph does seem strictly better than the old one. To (partially) explain the difference of why we have -1 tensor and +1 symint saved:
- One thing I see is that we return cat twice in the old forward but not the new one.
- We also previously weren't saving one of the symints that was created in the forward, and now we are:
add: Sym(s1 + s2) = sym_size + sym_size_1
Old:
class GraphModule(torch.nn.Module):
def forward(self, primals_1: f32[s0], primals_2: f32[s1, s0], primals_3: f32[s2, s0], primals_4: f32[s0, 1]):
# No stacktrace found for following nodes
sym_size: Sym(s1) = torch.ops.aten.sym_size(primals_2, 0)
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(primals_3, 0)
add: Sym(s1 + s2) = sym_size + sym_size_1
sym_size_2: Sym(s0) = torch.ops.aten.sym_size(primals_1, 0)
expand: f32[s1 + s2, s0] = torch.ops.aten.expand.default(primals_1, [add, sym_size_2]); add = None
cat: f32[2*s1 + 2*s2, s0] = torch.ops.aten.cat.default([expand, primals_2, primals_3])
mm: f32[2*s1 + 2*s2, 1] = torch.ops.aten.mm.default(cat, primals_4)
sym_size_3: Sym(1) = torch.ops.aten.sym_size(mm, 1)
view: f32[1, s0] = torch.ops.aten.view.default(primals_1, [sym_size_3, sym_size_2]); primals_1 = sym_size_3 = None
mm_1: f32[2*s1 + 2*s2, s0] = torch.ops.aten.mm.default(mm, view)
sym_size_5: Sym(s0) = torch.ops.aten.sym_size(primals_2, 1); primals_2 = None
return [cat, sym_size, sym_size_5, primals_3, mm_1, view, mm, primals_4, expand, cat, sym_size_2, sym_size, sym_size_1]
New:
class GraphModule(torch.nn.Module):
def forward(self, primals_1: f32[s0], primals_2: f32[s1, s0], primals_3: f32[s2, s0], primals_4: f32[s0, 1]):
# No stacktrace found for following nodes
sym_size: Sym(s1) = torch.ops.aten.sym_size(primals_2, 0)
sym_size_1: Sym(s2) = torch.ops.aten.sym_size(primals_3, 0)
add: Sym(s1 + s2) = sym_size + sym_size_1
sym_size_2: Sym(s0) = torch.ops.aten.sym_size(primals_1, 0)
expand: f32[s1 + s2, s0] = torch.ops.aten.expand.default(primals_1, [add, sym_size_2])
cat: f32[2*s1 + 2*s2, s0] = torch.ops.aten.cat.default([expand, primals_2, primals_3]); expand = None
mm: f32[2*s1 + 2*s2, 1] = torch.ops.aten.mm.default(cat, primals_4)
sym_size_3: Sym(1) = torch.ops.aten.sym_size(primals_4, 1)
view: f32[1, s0] = torch.ops.aten.view.default(primals_1, [sym_size_3, sym_size_2]); primals_1 = sym_size_3 = None
mm_1: f32[2*s1 + 2*s2, s0] = torch.ops.aten.mm.default(mm, view)
sym_size_4: Sym(s0) = torch.ops.aten.sym_size(primals_2, 1); primals_2 = None
return [cat, sym_size, sym_size_4, primals_3, mm_1, view, mm, cat, primals_4, sym_size_1, sym_size, add, sym_size_2]
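For what it's worth, the new expected list does line up with the new forward's outputs; a quick sanity check is below (names copied from the graph above; the classification rule is just my guess at what is_sym_node checks):

new_fw_returns = [
    "cat", "sym_size", "sym_size_4", "primals_3", "mm_1",   # outputs of f
    "view", "mm", "cat", "primals_4",                       # tensors saved for backward
    "sym_size_1", "sym_size", "add", "sym_size_2",          # symints saved for backward
]
is_sym_node = [n.startswith("sym_size") or n == "add" for n in new_fw_returns]
assert is_sym_node == [False, True, True, False, False] + [False] * 4 + [True] * 4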
Hmm, aren't you returning cat twice in both old/new?
  def clone(self):
-     return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant)
+     return self
Why do we want to return self here?
We're disabling the SymNode cloning. A better refactor would be to eliminate the clone method entirely.
I see. It might be nice to add a warning here, since this behavior seems a little unexpected.
Until Brian removes calls to clone from C++, this method will get (uselessly) called a bunch, so a warning would just be spam.
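(For illustration only: a tiny plain-Python stand-in, not the real SymNode, showing why returning self keeps identity-keyed proxy tracking intact.)

class _Node:
    def clone(self):
        return self  # new behavior: "cloning" preserves identity

n = _Node()
tracked = {n: "proxy_for_n"}   # proxy tracking keyed on the node itself
assert n.clone() in tracked     # the clone still resolves to its proxy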
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / cuda11.6-py3.10-gcc7-sm86 / test (default, 1, 4, linux.g5.4xlarge.nvidia.gpu).
@pytorchbot merge -f "flaky ci only"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):