ArgumentStash for int64_t arguments #12939
Conversation
facebook-github-bot left a comment
jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@apaszke does this look OK?
apaszke left a comment
Generally looks ok, but the code is structured in a very weird way
Review comments on torch/csrc/jit/tracer.cpp and torch/csrc/jit/tracing_state.h were marked as off-topic (now outdated).
Force-pushed from 48c8559 to a467241.
facebook-github-bot left a comment
jamesr66a is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Scalars are currently traced as constants; this PR fixes that.
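As a hedged illustration of the general symptom (my own minimal repro, not code from this PR): turning a traced value into a plain Python number freezes it into the graph as a constant, so the trace no longer depends on the runtime value.

    import torch

    # Minimal sketch of the symptom (an illustration, not code from this PR):
    # converting a traced value to a plain Python int bakes it into the graph
    # as a constant, so the trace no longer depends on the runtime value.
    def scale_by_length(x):
        n = int(x.size(0))        # becomes a constant at trace time
        return x * n

    traced = torch.jit.trace(scale_by_length, torch.ones(4))
    print(traced.graph)           # the length typically appears as a prim::Constant
    print(traced(torch.ones(8)))  # still scales by 4, the value captured during tracing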
The ONNX Graph for Test_Full_op() before and after this change:
    import torch
    import torch.nn as nn

    def Test_Full_op():
        class Test_Full(nn.Module):
            def forward(self, x):
                return torch.full((3, 4), x, dtype=torch.long)

        model = Test_Full()
        x = torch.tensor(12)
        output = model(x)
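Graph dumps like the ones below can be obtained, for example, by exporting the module to ONNX with verbose output (a usage sketch under that assumption, not part of the PR):

    import io
    import torch
    import torch.nn as nn

    # Usage sketch (not part of the PR): exporting with verbose=True prints a
    # human-readable graph; the exact text depends on the PyTorch/exporter version.
    class Test_Full(nn.Module):
        def forward(self, x):
            return torch.full((3, 4), x, dtype=torch.long)

    buf = io.BytesIO()
    torch.onnx.export(Test_Full(), (torch.tensor(12),), buf, verbose=True)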
Before this change:

    graph(%input1 : Long()):
      %output1 : Float(3, 4) = onnx::Constant[value=<Tensor>]
      return (%output1)
After this change, the fill value (%input1) remains a graph input instead of being folded into a constant:

    graph(%input1 : Long()):
      %1 : int[] = onnx::Constant[value= 3 4 [ Variable[CPULongType]{2} ]]
      %2 : Tensor = onnx::ConstantOfShape[value={0}]
      %output1 : Float(3, 4) = onnx::Add(%2, %input1)
      return (%output1)
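One way to catch this class of issue from Python is the trace checking built into torch.jit.trace, which re-runs the trace on extra inputs and reports diverging results; a hedged sketch (my own suggestion, not from the PR):

    import torch

    # Hedged sketch (not from the PR): check_inputs makes torch.jit.trace re-run
    # the trace with different values; if the scalar was baked in as a constant,
    # the traced output diverges from eager mode and the check reports it.
    def fill_3x4(x):
        return torch.full((3, 4), x, dtype=torch.long)

    traced = torch.jit.trace(fill_3x4, torch.tensor(12),
                             check_inputs=[(torch.tensor(5),)])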
Similar PR: #12939
Pull Request resolved: #21931
Reviewed By: zrphercule
Differential Revision: D15950066
Pulled By: houseroad
fbshipit-source-id: 3470665d88fa34faa600940ef16b069a06002cd5
Closes #12906. #12580 is still open because the schema is marked as traceable=false in the arg parser constructor, I think.