[jit] dispatch and expose linear op #20039
Conversation
dispatch and expose linear op gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
bwasti
left a comment
generally looks good, one concern re shape inference
}
} else if (node->matches("aten::linear(Tensor input, Tensor weight, Tensor? bias) -> Tensor")) {
  if (auto type = input_type(0)) {
    node->output()->setType(type);
Is `type` here a complete tensor or just a dimensioned tensor? The input shape is not the same as the output shape, right?
Hmm, you are right: it should actually be a complete tensor, and the output shape should be different. It seems like bilinear below is also wrong; will fix it.
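For illustration (a minimal sketch, not part of the patch): the output of linear replaces the last input dimension with weight.size(0), so reusing the input's sizes for the output type would be wrong.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 3)   # input:  (N, in_features)
w = torch.randn(5, 3)   # weight: (out_features, in_features)
out = F.linear(x, w)

print(x.shape)    # torch.Size([4, 3])
print(out.shape)  # torch.Size([4, 5]): the last dim becomes weight.size(0)
```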
zdevito
left a comment
I don't see any problems here, but because this shuffles around where decompositions happen, I am concerned about test coverage and performance.
- Can we make sure the derivative formula for linear is actually being tested? (A sketch of such a check is included below.)
- @apaszke can you take a look at this and see if anything jumps out as a bug? Context is that the mkldnn backend has good fused linear performance and we are splitting the op up before it can be seen.
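As an illustration of the kind of derivative check being asked for (a minimal sketch, not the test added by this PR), gradcheck can compare the analytical gradients of the linear op against numerical estimates:

```python
import torch
import torch.nn.functional as F

# gradcheck wants double-precision inputs with requires_grad=True
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(5, dtype=torch.double, requires_grad=True)

# Compares the analytical gradients of linear against numerical estimates
assert torch.autograd.gradcheck(F.linear, (x, w, b))
```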
apaszke
left a comment
- We should check the AD formulas.
- This likely breaks a lot of the fusions we used to do. We want to make sure that the bias additions in LSTM get fused, so the fuser needs to be taught to decompose addmms. (A sketch of the addmm decomposition is below.)
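For reference, a minimal sketch (not code from this patch) of the addmm decomposition being discussed: splitting addmm into a plain mm plus a broadcasted add is what lets the fuser pull the bias addition into a fusion group.

```python
import torch

bias = torch.randn(5)
x = torch.randn(4, 3)
w = torch.randn(5, 3)

# addmm(bias, x, w.t()) is the fused form that linear lowers to for 2-D input
fused = torch.addmm(bias, x, w.t())

# The fuser-friendly decomposition: a plain mm plus a broadcasted bias add
decomposed = torch.mm(x, w.t()) + bias

assert torch.allclose(fused, decomposed)
```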
SHAPE_ASSERT(weight_type->sizes().size() == 2 && sizes.size() >= 2);
sizes.at(last_dim) = weight_type->sizes()[0];
node->output()->setType(input_type->withSizes(sizes));
return true;
I don't think we need to add complete shape prop, since we're not really using it at this point. Also, this looks like the same case as for aten::mm, so maybe we could avoid duplicating them?
The aten::mm case requires two tensor inputs, whereas aten::linear has three. I think I could also do partial shape prop like the aten::bilinear case defined here, but that requires additional code as well (which is a bit duplicated).
Edit: do you mean the complete shape prop of aten::mm defined here? The formulas have slightly different semantics: linear(input, weight) = input.mm(weight.t())
Fair enough, let's leave it.
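To make the difference concrete, a small sketch (illustrative only, not part of the patch): linear transposes the weight and also accepts inputs with more than two dimensions, which a plain mm does not.

```python
import torch
import torch.nn.functional as F

w = torch.randn(5, 3)

# 2-D case: linear(x, w) == x.mm(w.t())
x2 = torch.randn(4, 3)
assert torch.allclose(F.linear(x2, w), x2.mm(w.t()))

# >2-D case: linear broadcasts over the leading dims; mm only takes 2-D inputs
x3 = torch.randn(2, 4, 3)
assert F.linear(x3, w).shape == (2, 4, 5)
```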
int ndim = input_type->dim();
Value* new_output = nullptr;
if (ndim == 2 && bias->type()->isSubtypeOf(TensorType::get())) {
  // if ndim == 2 and bias is statically defined, dispatch to addmm decomposition
"Statically defined" is not very well defined here. You're only checking that we can refine its type, which doesn't tell you much.
This is something I took from the fuser code, where it tells you whether a value is statically defined (statically undefined needs more checks, though). Because we only check the input value types, I think this is sufficient to distinguish the cases here.
As the next step after this PR, I will move the batchnorm and layernorm decompositions here as well, and then I can reuse the util function defined above.
My point is that the term "statically defined" is not used correctly. I'm not even 100% sure what it is supposed to mean.
From my point of view, a statically defined tensor argument means that the value passed in is neither a None constant nor of Optional[Tensor] type; that is, we statically know the type that was passed in rather than dealing with a dynamically unknown optional type. So isSubtypeOf(TensorType) should guarantee this definition.
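A small TorchScript sketch of the distinction being discussed (illustrative only; the function names here are made up): a bias annotated as a plain Tensor is statically known to be a tensor, while an Optional[Tensor] bias stays optional until it is refined, e.g. by a None check.

```python
from typing import Optional

import torch
import torch.nn.functional as F

@torch.jit.script
def linear_static_bias(x, w, b):
    # b gets the default Tensor annotation: its type is statically a Tensor
    return F.linear(x, w, b)

@torch.jit.script
def linear_optional_bias(x, w, b: Optional[torch.Tensor]):
    # b is Optional[Tensor]: only after the None check is it refined to Tensor
    if b is not None:
        return F.linear(x, w, b)
    return F.linear(x, w)
```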
Thanks @zdevito @apaszke! Re: the concerns about the AD formula and performance:
I checked that the LSTM graph is the same before and after this PR, except that the unique names of values are different:
Performance shows no noticeable difference between master and master + this PR:
On master:
Master + this stacked PR:
So I think this PR has not regressed our current fusion :)
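For context, a minimal sketch of this kind of check (not the actual benchmark script; the hand-written cell below is only one way to exercise the linear-plus-bias pattern): script an LSTM cell, dump its graph on both branches, and diff the dumps.

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def lstm_cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    # The two linear calls are where the bias additions need to stay fusable
    gates = F.linear(x, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    cy = torch.sigmoid(forgetgate) * cx + torch.sigmoid(ingate) * torch.tanh(cellgate)
    hy = torch.sigmoid(outgate) * torch.tanh(cy)
    return hy, cy

# Dump the graph on both branches and diff the output (value names will differ)
print(lstm_cell.graph)
```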
OK, I see you have made changes to how the ops are decomposed, so I'd need to read the prior patches to make sure this works. If you double-checked that we're not losing LSTM fusion + mm batching, then it might be good to go. I'll try to catch up with the whole stack soon.
@wanchaol, have you made any progress on this PR?
@XiaobingSuper this is currently blocked from landing by issues #19769 and #20734. We don't have a good solution for those yet, so this might not land until we can solve them.
@XiaobingSuper sorry, that issue was closed as a duplicate, but the underlying problem, #20734, is not fixed yet. I think we recently figured out a plan for how to fix it, and I will let you know about the progress once we start working on it :)


Stack from ghstack:
Summary:
This exposes linear directly when the functional interface is called; the ATen linear op already does the same thing as the functional interface, so there is really no need to duplicate the code. Also, this exposes the higher-level aten::linear op up until the custom fusion pass, so that different backends can see the high-level information.
Test Plan:
Test that the linear op is correctly decomposed in the decomposition pass, which also means it does not get decomposed before that pass.
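As an illustration of the intended behavior (a sketch, not the actual test from this PR): after this change, scripting torch.nn.functional.linear should leave an aten::linear node in the graph rather than splitting it into addmm/matmul early.

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def fn(x, w, b):
    return F.linear(x, w, b)

# With this change, the graph is expected to keep an aten::linear node until
# the decomposition pass runs.
print(fn.graph)
assert "aten::linear" in str(fn.graph)
```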
Currently being blocked from landing by #19769 and #20734.
Differential Revision: D15190354