[jit] Refactor pybind_utils.h #21550
This readies pybind_utils so we can have all our type-inferring stuff in one place.
eellison
left a comment
Looks good, have a few comments
```cpp
  // not create a CompleteTensorType
  return MatchTypeReturn(DimensionedTensorType::create(tensor));
}
return MatchTypeReturn(CompleteTensorType::create(tensor));
```
Can you add a test case where the element / list element is a tensor type, and/or make sure we are not overspecializing here? When we try to use specialized tensor types, it breaks things in a lot of places. For example:

```
self.list = [Float(*, *)]
dim_tensor = self.list[0]
dim_tensor = torch.tensor()
```

Here, `torch.tensor` is not a subtype of `Float(*, *)`, so it would throw. You might need this for the tracer, but for script type inference a call to `unshapedType()` should fix the problem.
We should do that in a follow-up; this is just a refactor, and behavior shouldn't change here.
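To make the overspecialization concern concrete, here is a toy Python sketch of the issue. The class names and the `unshaped` helper below are illustrative assumptions, only loosely analogous in spirit to the JIT's `unshapedType()`; this is not the real type system.

```python
# Toy sketch (not the real JIT type system) of why shape-specialized
# element types reject later reassignments, and how erasing shapes fixes it.

class TensorType:
    """Unshaped tensor type: accepts any tensor value."""
    def accepts(self, other):
        return isinstance(other, TensorType)

class ShapedTensorType(TensorType):
    """Shape-specialized tensor type: only accepts the same static shape."""
    def __init__(self, shape):
        self.shape = shape
    def accepts(self, other):
        return isinstance(other, ShapedTensorType) and other.shape == self.shape

def unshaped(t):
    # illustrative stand-in for unshapedType(): drop shape specialization
    return TensorType()

specialized = ShapedTensorType((2, 2))
new_value = ShapedTensorType((3,))
print(specialized.accepts(new_value))            # reassignment rejected
print(unshaped(specialized).accepts(new_value))  # erasing shapes accepts it
```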
```cpp
  element_types.push_back(*type_match.type);
} else {
  // Forward error message along
  return type_match.errMsg;
```
Can we just return the error here? (^^ relating to my other comment about nesting the error message.)
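The error-forwarding flow in the snippet above can be sketched in Python as follows. The function names and the `(value, error)` return convention are illustrative assumptions, not the real `MatchTypeReturn`/pybind_utils API:

```python
# Illustrative sketch: collect matched element types, forwarding the
# first failure's message unchanged rather than wrapping or re-raising it.

def match_list_elements(py_elems, match_type):
    element_types = []
    for e in py_elems:
        matched, err = match_type(e)
        if err is not None:
            # forward the error message along, as in the C++ snippet
            return None, err
        element_types.append(matched)
    return element_types, None

def match_int(x):
    # toy matcher: succeeds only for ints
    if isinstance(x, int):
        return int, None
    return None, "could not match %r to int" % (x,)

print(match_list_elements([1, 2, 3], match_int))  # all elements match
print(match_list_elements([1, "a"], match_int))   # forwarded error message
```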
```cpp
size_t len = py::len(list);
if (!len) {
  AT_ERROR("List trace inputs must have elements");
  return MatchTypeReturn("List trace inputs must have elements");
```
"Trace" is too specific an error message now.
What do you mean? This is only for lists
It says "List trace inputs".
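The empty-list check being discussed can be sketched in Python. The helper name and return convention here are illustrative assumptions, not the real implementation; the point is simply that an empty list leaves nothing to infer an element type from:

```python
# Illustrative sketch: inferring an element type from a Python list,
# reporting an error (rather than guessing) for empty or mixed lists.

def infer_list_type(pylist):
    if len(pylist) == 0:
        # no elements to infer a type from, so report an error
        return None, "List inputs must have elements"
    elem_types = {type(x) for x in pylist}
    if len(elem_types) != 1:
        return None, "List inputs must be homogeneous"
    return elem_types.pop(), None

print(infer_list_type([1, 2, 3]))  # infers int
print(infer_list_type([]))         # error: no elements
```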
What is the reason for putting the type inference logic on Python objects and not IValues?

The tracer doesn't have …

Is this ready for re-review?
eellison
left a comment
Looks good, a few more comments. I think you need to check the tracer inputs for non-Tensor types before landing. You also have a failing test:

```
what_without_backtrace()).find("forward() expected a value of type 'Tensor' " "for argument 'input' but instead found type 'int'") == 0 INTERNAL ASSERT FAILED at /var/lib/jenkins/workspace/test/custom_operator/test_custom_ops.cpp:85
```
eellison
left a comment
Looks good!
```cpp
}
auto type = *match.type;

if (isTraceableType(type)) {
```
Not sure exactly, but I wonder if we can get away with not checking the types here, since the values are checked at pytorch/torch/csrc/jit/tracer.cpp, line 215 (at 8c57ce8):

```cpp
static IValue addInput(const std::shared_ptr<TracingState>& state, const IValue& input, const TypePtr& type, Value* value) {
```
Still, I think it is fine in this PR and we can look into removing as a follow up.
Summary: This refactors pybind_utils so we can have all our type-inferring stuff in one place (e.g. for #21379). There is some follow-up work to make the error messages better, but I think that's fine to save for another PR.

Pull Request resolved: pytorch/pytorch#21550
Pulled By: driazati
Differential Revision: D15727002
fbshipit-source-id: a6974f2e1e5879f0503a18efc138da31cda7afa2