[ONNX] Opset 11 updates #28225
Conversation
…oof/pad
Conflicts:
  torch/onnx/symbolic_helper.py
  torch/onnx/symbolic_opset11.py
Looking into fixing test_onnx_opset test_topk
torch/onnx/symbolic_helper.py
    _unimplemented("TopK", "Out parameter is not supported")
if not _is_value(k):
    k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
k = g.op("Reshape", k, g.op("Constant", value_t=torch.tensor([1])))
nit: Reshape can be put in an else branch. In the if branch construct the Constant with torch.tensor([k], dtype=torch.int64).
That won't work for opset 10 if k is a value. I updated this with a minor change.
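For context, here is a minimal sketch of the if/else shape being discussed, assuming a helper along these lines in symbolic_helper.py (the name _topk_k_to_value and its signature are illustrative, not the exact code in this PR):

```python
import torch
from torch.onnx.symbolic_helper import _is_value, _unimplemented


def _topk_k_to_value(g, k, out=None):
    # Normalize k into the 1-D int64 graph value that ONNX TopK (opset >= 10)
    # expects as its second input.
    if out is not None:
        _unimplemented("TopK", "Out parameter is not supported")
    if not _is_value(k):
        # k is a plain Python int: build the 1-D int64 Constant directly.
        k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64))
    else:
        # k is already a graph value (possibly 0-D), so reshape it to 1-D.
        # Keeping this branch is what the opset 10 path relies on when k is
        # a traced value rather than a constant.
        k = g.op("Reshape", k, g.op("Constant", value_t=torch.tensor([1])))
    return k
```

The per-opset symbolic would then emit the TopK node itself, passing only the attributes (e.g. largest/sorted in opset 11) that its opset version defines.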
Co-Authored-By: Bowen Bao <[email protected]>
@pytorchbot retest this please
spandantiwari left a comment
Overall LGTM. Added one comment that needs attention.
torch/onnx/symbolic_helper.py
                 mode_s='constant',
                 value_f=0.)
else:
    input = g.op("Pad", input,
I am not sure about this pattern of capturing branched implementations of different opsets in a single helper function in symbolic_helper.py. I think it is better to have symbolic implementations of different opsets in their own files, e.g. in this case symbolic_opset10.py and symbolic_opset11.py.
Thanks Spandan. I agree with you for scenarios like this, where the branches differ substantially across opset versions.
In this case, I want to avoid adding the whole avg_pool symbolic function to the opset 11 file, since the diff between the opset 10 and 11 symbolic functions is just this line.
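To illustrate the trade-off, here is a hedged sketch of that single-helper pattern, branching on the export opset version inside symbolic_helper.py; the helper name and its padding argument are assumptions for illustration, not the PR's exact code:

```python
import torch
from torch.onnx import symbolic_helper as sym_help


def _pad_before_avg_pool(g, input, padding):
    # padding is assumed to be a flat Python list of begin/end pad amounts
    # in the order the ONNX Pad op expects.
    if sym_help._export_onnx_opset_version < 11:
        # Opset <= 10: Pad takes the pads and fill value as attributes.
        return g.op("Pad", input, pads_i=padding, mode_s="constant", value_f=0.)
    else:
        # Opset 11: pads become a tensor input instead of an attribute.
        pads = g.op("Constant", value_t=torch.tensor(padding, dtype=torch.int64))
        return g.op("Pad", input, pads, mode_s="constant")
```

The alternative the review suggests is to duplicate the avg_pool symbolic in symbolic_opset10.py and symbolic_opset11.py, at the cost of repeating everything except this Pad call.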
cc @houseroad please review
torch/onnx/symbolic_opset11.py
extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)),
                             g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len)
# Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
paddings = g.op("Concat", pad, g.op("ConstantOfShape", extension, value_t=torch.tensor([0])), axis_i=0)
I get this error from this line:
`Type parameter (T) bound to different types (tensor(int32) and tensor(int64) in node (__Concat_213).`
I changed my code to pass pad.long() to nn.functional.pad instead of pad.int() as a workaround; maybe the code here should convert types?
ONNX only supports int64 pads for now. I added the cast, but you cannot have int pads in ONNX at the moment.
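A rough sketch of the kind of fix described here, assuming the variable names from the snippet above (pad holding a possibly-int32 pad tensor, extension holding the number of trailing zeros to append); the exact cast added in the PR may be placed differently:

```python
import torch
from torch.onnx import symbolic_helper as sym_help


def _build_int64_paddings(g, pad, extension):
    # Cast the incoming pad tensor to int64 so every input to Concat (and the
    # eventual Pad node) shares one element type; ONNX Pad only accepts int64.
    pad = g.op("Cast", pad, to_i=sym_help.cast_pytorch_to_onnx["Long"])
    # Fill the remaining dimensions with int64 zeros.
    zeros = g.op("ConstantOfShape", extension,
                 value_t=torch.tensor([0], dtype=torch.int64))
    # paddings = [dim_n_begin, dim_n_end, ..., 0, 0, ...]
    return g.op("Concat", pad, zeros, axis_i=0)
```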
facebook-github-bot left a comment
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@houseroad Could you please let me know if there are any related failures?
houseroad left a comment
LGTM
@houseroad merged this pull request in ebc216a.
This PR contains:
1. Pad updates for the opset 11 symbolic.
2. avg_pool updates for opset 11.
3. TopK updates for opset 11.