Shape prop view/reshape/as_strided through prim::ListConstructs #11877
Conversation
Previously, aten::view returned a Dynamic type when attr::size was a prim::ListConstruct. See [this gist for a repro](https://gist.github.com/zou3519/cbd610472ba3369f556fa612a7d93b28). This prevented a pre-multiplied LSTM input graph from being fusible (aten::view is necessary to do the pre-multiplication). If aten::view is passed the output of a prim::ListConstruct node, then shape prop should be able to figure out its TensorType, because we statically know the number of inputs to prim::ListConstruct. This PR implements that.
```cpp
auto input_node = node->namedInput(shape_input)->node();
if (input_node->kind() == prim::ListConstruct) {
  // The ListConstruct's input count is statically known, so the
  // output tensor's number of dimensions is too.
  return tensor_types.at(0)->withDim(input_node->inputs().size());
}
```
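To make the fix concrete, here is a minimal sketch of the scenario (the function `fn` below is illustrative, not the PR's actual repro or test body). Because the size argument to `view` depends on a runtime value, the scripted graph feeds aten::view from a prim::ListConstruct node:

```python
import torch

@torch.jit.script
def fn(x):
    # [x.size(0), -1] becomes a prim::ListConstruct in the graph,
    # since x.size(0) is only known at runtime.
    return x.view(x.size(0), -1)

x = torch.randn(3, 1, 5)
graph = fn.graph
# Run the same internal pass the PR's test uses. With this change, the
# output of aten::view is typed with a known number of dimensions
# (two here, the number of ListConstruct inputs) rather than Dynamic.
torch._C._jit_pass_complete_shape_analysis(graph, (x,), False)
print(graph)
```

Since `[x.size(0), -1]` has two elements, shape prop can assign the view's output a TensorType with dim 2 even though the element values are unknown until runtime.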
test/test_jit.py (outdated)
```python
# `fn` is the scripted function under test (defined earlier; elided from this hunk).
x = torch.randn(3, 1, 5, requires_grad=True)
graph = torch.jit.script(fn).graph
torch._C._jit_pass_complete_shape_analysis(graph, (x,), False)
self.assertExpectedGraph(graph)
```
… test with better test
facebook-github-bot left a comment:
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
test/test_jit.py (outdated)
```python
# Same setup as the hunk above: script `fn`, then run full shape analysis.
x = torch.randn(3, 1, 5, requires_grad=True)
graph = torch.jit.script(fn).graph
torch._C._jit_pass_complete_shape_analysis(graph, (x,), False)
```