Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Sep 19, 2018

Previously, aten::view returned a Dynamic type when attr::size is a prim::ListConstruct.
See this for a repro.
This prevented a pre-multipled lstm input graph from being fusible (aten::view is necessary
to do premultiplication).

If aten::view is passed an output of a prim::ListConstruct node, then shape prop should
be able to figure out its TensorType because we statically know the number of inputs to
prim::ListConstruct. This PR implements that.

Previously, aten::view returned a Dynamic type when attr::size is a prim::ListConstruct.
See [this for a repro](https://gist.github.com/zou3519/cbd610472ba3369f556fa612a7d93b28).
This prevented a pre-multipled lstm input graph from being fusible (aten::view is necessary
to do premultiplication).

If aten::view is passed an output of a prim::ListConstruct node, then shape prop should
be able to figure out its TensorType because we statically know the number of inputs to
prim::ListConstruct. This PR implements that.
@pytorchbot pytorchbot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Sep 19, 2018
auto input_node = node->namedInput(shape_input)->node();
if (input_node->kind() == prim::ListConstruct) {
return tensor_types.at(0)->withDim(input_node->inputs().size());
}

This comment was marked as off-topic.

This comment was marked as off-topic.

test/test_jit.py Outdated
x = torch.randn(3, 1, 5, requires_grad=True)
graph = torch.jit.script(fn).graph
torch._C._jit_pass_complete_shape_analysis(graph, (x,), False)
self.assertExpectedGraph(graph)

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

test/test_jit.py Outdated

x = torch.randn(3, 1, 5, requires_grad=True)
graph = torch.jit.script(fn).graph
torch._C._jit_pass_complete_shape_analysis(graph, (x,), False)

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519 zou3519 deleted the view-shapeprop branch September 25, 2018 14:13
@ezyang ezyang added the merged label Jun 26, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants