Skip to content

Conversation

@neginraoof
Copy link
Contributor

Exporting meshgrid op in opset 9 symbolics

@neginraoof neginraoof requested a review from apaszke as a code owner September 11, 2019 18:53
@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: onnx Related to torch.onnx labels Sep 11, 2019
Copy link
Collaborator

@BowenBao BowenBao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some inline comments

for i, t in enumerate(tensors):
if t.isCompleteTensor():
shape_i = [g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))] * len(tensors)
shape_i[i] = g.op("Shape", t)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider reusing tensors_shape above..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not tensors shape. This is [1,1,..,n,1,...]

@BowenBao
Copy link
Collaborator

test if failing for opset7. Expand was added in opset8, please add meshgrid to the blacklist in symbolic_opset7.py.

Copy link
Collaborator

@BowenBao BowenBao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, one last inline comment. Also please update expect file for test_operators

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@izdeby has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@houseroad
Copy link
Member

@pytorchbot rebase this please

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Member

@houseroad houseroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure CIs are green.

@neginraoof
Copy link
Contributor Author

@houseroad please review updates.

@neginraoof
Copy link
Contributor Author

cc @houseroad for review

@neginraoof
Copy link
Contributor Author

@pytorchbot rebase this please

1 similar comment
@neginraoof
Copy link
Contributor Author

@pytorchbot rebase this please

@neginraoof
Copy link
Contributor Author

@pytorchbot retest this please

@neginraoof
Copy link
Contributor Author

@pytorchbot rebase this please

@neginraoof
Copy link
Contributor Author

@houseroad please review the updates

@houseroad
Copy link
Member

The test is still failing.

@neginraoof
Copy link
Contributor Author

@houseroad Thanks. The CI was green before I rebased. The failing test is:
caffe2/python/data_parallel_model_test.py::RecurrentNetworkParallelTest::test_equiv_recurrent
Which I'm not sure if it's related. I'll look into that.

@neginraoof
Copy link
Contributor Author

@pytorchbot retest this please

@neginraoof
Copy link
Contributor Author

@houseroad Tests have passed.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@houseroad merged this pull request in 60d6060.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged module: onnx Related to torch.onnx oncall: jit Add this issue/PR to JIT oncall triage queue open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants