
Conversation

@wanchaol (Collaborator) commented May 2, 2019

Stack from ghstack:

Summary:
This exposes aten::linear directly when the functional interface is called. The ATen linear op already does the same thing as the functional interface, so there is no need to duplicate the code. This also keeps the higher-level aten::linear op visible up until the custom fusion pass, so that different backends can see the high-level information.

Test Plan:
Test that the linear op is correctly decomposed in the decomposition pass, which also means it does not get decomposed before that pass.
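As a rough illustration of this test plan (a sketch, not the PR's actual test code), one could check that a scripted call to F.linear surfaces the high-level aten::linear node in the graph before any decomposition runs:

```python
# Hedged sketch: assumes F.linear shows up as aten::linear in the scripted
# graph, which is the point of this PR; not the PR's actual test.
import torch
import torch.nn.functional as F

@torch.jit.script
def fn(x, w, b):
    return F.linear(x, w, b)

# Before the decomposition pass runs, the graph should contain the
# high-level aten::linear op rather than a pre-decomposed addmm/matmul sequence.
assert "aten::linear" in str(fn.graph)
```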

Currently being blocked from landing by #19769 and #20734.

Differential Revision: D15190354

@pytorchbot added labels: oncall: jit, module: internals, module: nn, module: pybind (May 2, 2019)
wanchaol added 2 commits May 1, 2019 17:48
@wanchaol wanchaol requested review from apaszke, bddppq, bwasti and zdevito May 2, 2019 01:10
wanchaol added 4 commits May 1, 2019 18:46
@wanchaol wanchaol changed the title dispatch and expose linear op [jit] dispatch and expose linear op May 2, 2019
wanchaol added 2 commits May 2, 2019 19:12
@bwasti (Contributor) left a comment:

Generally looks good; one concern re: shape inference.

      }
    } else if (node->matches("aten::linear(Tensor input, Tensor weight, Tensor? bias) -> Tensor")) {
      if (auto type = input_type(0)) {
        node->output()->setType(type);
Contributor:

Is the type here a complete tensor or just a dimensioned tensor? The input shape is not the same as the output shape, right?

@wanchaol (Collaborator, Author):

Hmm, you are right. It should actually be a complete tensor, and the output shape should be different. It seems like bilinear below is also wrong; I will fix it.
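For reference, a quick Python illustration (not part of the PR) of the point above: linear's output shape replaces the last dimension of the input with weight.size(0), so the input type cannot simply be propagated to the output.

```python
# Illustration of aten::linear's output shape: (*, in_features) x
# (out_features, in_features) -> (*, out_features).
import torch

x = torch.randn(5, 3, 8)   # input:  (*, in_features)
w = torch.randn(16, 8)     # weight: (out_features, in_features)
b = torch.randn(16)

out = torch.nn.functional.linear(x, w, b)
assert out.shape == (5, 3, 16)  # last dim 8 -> 16, unlike the input shape
```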

wanchaol added 7 commits May 5, 2019 21:15
…r op"

dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
…atch and expose linear op"

dispatch and expose linear op

gh-metadata: pytorch pytorch 20039 gh/wanchaol/4/head
@zdevito (Contributor) left a comment:
I don't see any problems here, but because this shuffles around where decompositions happen, I am concerned about test coverage and performance.

  • Can we make sure the derivative formula for linear is actually being tested?
  • @apaszke, can you take a look at this and see if anything jumps out as a bug? Context: the mkldnn backend has good fused linear performance and we are splitting the op up before it can be seen.

@apaszke (Contributor) left a comment:
  1. We should check the AD formulas
  2. This likely breaks a lot of the fusions we used to do. We want to make sure that the bias additions in LSTM get fused, so the fuser needs to be taught to decompose addmms.

      SHAPE_ASSERT(weight_type->sizes().size() == 2 && sizes.size() >= 2);
      sizes.at(last_dim) = weight_type->sizes()[0];
      node->output()->setType(input_type->withSizes(sizes));
      return true;
Contributor:
I don't think we need to add complete shape prop, since we're not really using it at this point. Also, this looks like the same case as for aten::mm, so maybe we could avoid duplicating them?

@wanchaol (Collaborator, Author) commented May 8, 2019:

The aten::mm section requires the number of tensor inputs to be 2, whereas aten::linear has 3. I could also do partial shape prop like the aten::bilinear case defined here, but that would require additional (and somewhat duplicated) code.

Edit: do you mean the complete shape prop of aten::mm defined here? The semantics of the formula are slightly different: linear(input, weight) = input.mm(weight.t())
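For the 2-D case, that formula can be sanity-checked in a few lines (illustrative only, not part of the PR):

```python
# Quick check of the formula above for the 2-D case:
# linear(input, weight) == input.mm(weight.t())
import torch

inp = torch.randn(4, 8)
weight = torch.randn(16, 8)

assert torch.allclose(torch.nn.functional.linear(inp, weight),
                      inp.mm(weight.t()))
```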

Contributor:
Fair enough, let's leave it.

      int ndim = input_type->dim();
      Value* new_output = nullptr;
      if (ndim == 2 && bias->type()->isSubtypeOf(TensorType::get())) {
        // if ndim == 2 and bias is statically defined, dispatch to addmm decomposition
Contributor:

"Statically defined" is not very well defined here. You're only checking that we can refine its type, which doesn't tell you much.

@wanchaol (Collaborator, Author) commented May 8, 2019:

This is something I borrowed from the fuser code, where it tells you whether a value is statically defined (statically undefined needs more checks, though). Because we only check the input value types here, I think that is sufficient to distinguish the cases.

As a next step after this PR, I will move the batchnorm and layernorm decompositions here as well, and then I can use the util function defined above.

Contributor:

My point is that the term "statically defined" is not used correctly. I'm not even 100% sure what it means.

@wanchaol (Collaborator, Author):

From my point of view, a statically defined tensor argument means that the value passed in is neither a None constant nor an Optional[Tensor] type, i.e. we statically know the type being passed in rather than facing a dynamically unknown optional type. So isSubtypeOf(TensorType) should guarantee this definition.
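For reference, a hedged Python sketch of what the addmm-vs-matmul dispatch above amounts to semantically (the helper name `decomposed_linear` is made up for illustration; the actual pass is written in C++ inside the JIT):

```python
# Sketch of the decomposition's intent, not the pass's code.
import torch

def decomposed_linear(input, weight, bias=None):
    # 2-D input with a statically known Tensor bias: a single fused addmm
    if input.dim() == 2 and bias is not None:
        return torch.addmm(bias, input, weight.t())
    # general case: matmul against the transposed weight, then add the bias
    output = input.matmul(weight.t())
    if bias is not None:
        output = output + bias
    return output

x = torch.randn(4, 8)
w = torch.randn(16, 8)
b = torch.randn(16)
assert torch.allclose(decomposed_linear(x, w, b),
                      torch.nn.functional.linear(x, w, b))
```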

@wanchaol (Collaborator, Author) commented May 7, 2019:

Thanks @zdevito @apaszke! Re: the concerns on the AD formula and performance:

  1. The AD formula is already checked here; it covers the input.dim() == 2 pathway. I will add one more test case to cover the matmul AD pathway (see the gradcheck sketch below).
  2. I believe that by putting the decomposition pass before the fuser, and with this PR adding the decomposition from linear to addmm, the old bias-add fusion behavior is preserved. I checked the forward and backward graphs for LSTM and they remain the same, so fusion should be intact and LSTM performance should not regress. I will make sure to run enough benchmarks before landing this.
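A rough gradcheck sketch (not the PR's actual test) that would cover both the 2-D/addmm and the N-D/matmul AD pathways:

```python
# Hedged sketch: run gradcheck on F.linear for a 2-D input (addmm path)
# and a 3-D input (matmul path), in double precision as gradcheck requires.
import torch
from torch.autograd import gradcheck

w = torch.randn(6, 4, dtype=torch.double, requires_grad=True)
b = torch.randn(6, dtype=torch.double, requires_grad=True)

x2d = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
x3d = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)

assert gradcheck(torch.nn.functional.linear, (x2d, w, b))
assert gradcheck(torch.nn.functional.linear, (x3d, w, b))
```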

wanchaol added 4 commits May 7, 2019 19:25
@wanchaol (Collaborator, Author) commented May 8, 2019:
I checked that the LSTM graph is the same before and after this PR, except that the unique names of values are different:

https://gist.github.com/wanchaol/bcd988809a192c56f55319fc6a305637/revisions

Performance shows no noticeable difference between master and master + this stacked PR:

On master: [benchmark screenshot: before]

Master + this stacked PR: [benchmark screenshot: after]

So I think this PR does not regress our current fusion :)

wanchaol added 2 commits May 8, 2019 13:42
@apaszke (Contributor) commented May 9, 2019:

OK, I see you have made changes to how the ops are decomposed, so I'd need to read the prior patches to make sure this works. If you double-checked that we're not losing LSTM fusion + mm batching, then it might be good to go. I'll try to catch up with the whole stack soon.

@pytorchbot added the module: autograd label (Jun 4, 2019)
@XiaobingSuper (Collaborator):

@wanchaol, do you have any progress on this PR?

@wanchaol (Collaborator, Author):

@XiaobingSuper, this is currently blocked from landing by issues #19769 and #20734. We don't have a good solution to those yet, so this might not land until we solve them.

@XiaobingSuper (Collaborator):

@wanchaol, #19769 seems fixed; do you have any plan to continue this?

@wanchaol (Collaborator, Author):

> @wanchaol, #19769 seems fixed; do you have any plan to continue this?

@XiaobingSuper, sorry, that issue was closed as a duplicate, but the underlying problem, #20734, is not fixed yet.

But I think we recently figured out a plan for how to fix it, and I will let you know about the progress once we start working on it :)


Labels

cla signed, module: autograd, module: internals, module: nn, module: pybind, oncall: jit
