[jit] Add support for @staticmethod
#27163
Conversation
Resolve static methods as functions

Fixes #26792
Going forward I'm not sure if we should implement a feature for

@eellison it should work for class types too, added a test

Could you add a test for when two different classes define a static method with the same name and both methods are used in the same function?

@pytorchbot rebase this please
eellison
left a comment
Looks good, I just have a couple of questions / nits
    retval = resolver->resolveValue(ident, method, range);
  }

  if (!retval) {
Why was this change made?
IIRC classes can be resolved by both resolveValue and resolveType, so this makes sure resolveValue goes first
Still don't really understand why this was changed
Static methods aren't put onto class types but are resolved into functions directly from Python, so if the class is resolved as a type, its static methods won't be found
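The ordering can be sketched in plain Python (a minimal sketch: `Resolver`, `values`, and `types` are hypothetical stand-ins for the C++ resolveValue/resolveType pair, not the real compiler code):

```python
class Resolver:
    # Hypothetical stand-in for the C++ resolver discussed above.
    def __init__(self, values, types):
        self.values = values  # name -> Python value (e.g. a class object whose static methods are reachable)
        self.types = types    # name -> compiled class type (static methods not reachable here)

    def resolve(self, name):
        # Mirrors the reordering: try value resolution first, and only
        # fall back to type resolution if no value was found.
        retval = self.values.get(name)
        if retval is None:
            retval = self.types.get(name)
        return retval
```

With this ordering, a name bound both ways resolves to its value binding, so static-method lookups still succeed.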
        return x * 100

    def forward(self, x):
        return x - M.my_method(x) + N.my_method(x)
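For context, an eager-mode sketch of the pattern this test exercises — two classes defining a `@staticmethod` with the same name, both called from one `forward` (plain Python stand-in; the real test scripts these with TorchScript, and the exact bodies of `M.my_method` here are an assumption based on the graph output in this thread):

```python
class M:
    @staticmethod
    def my_method(x):
        return x + 100  # assumed body, matching aten::add(x, 100) in the graph

class N:
    @staticmethod
    def my_method(x):
        return x * 100

    def forward(self, x):
        # Both same-named static methods used in one function.
        return x - M.my_method(x) + N.my_method(x)
```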
How does this serialize if we're not inlining methods? Could you print out the source?
We still do inlining:
graph(%self : ClassType<N>,
%x.1 : Tensor):
%23 : int = prim::Constant[value=100]() # test/test_jit.py:3654:27
%22 : int = prim::Constant[value=1]()
%24 : Tensor = aten::add(%x.1, %23, %22) # test/test_jit.py:3654:23
%7 : Tensor = aten::sub(%x.1, %24, %22) # test/test_jit.py:3668:23
%26 : Tensor = aten::mul(%x.1, %23) # test/test_jit.py:3665:23
%12 : Tensor = aten::add(%7, %26, %22) # test/test_jit.py:3668:23
return (%12)
With inlining off, the methods get correctly mangled:
graph(%self : ClassType<N>,
%x.1 : Tensor):
%9 : Function = prim::Constant[name="my_method"]()
%6 : int = prim::Constant[value=1]()
%4 : Function = prim::Constant[name="my_method"]()
%5 : Tensor = prim::CallFunction(%4, %x.1) # test/test_jit.py:3668:27
%7 : Tensor = aten::sub(%x.1, %5, %6) # test/test_jit.py:3668:23
%10 : Tensor = prim::CallFunction(%9, %x.1) # test/test_jit.py:3668:44
%12 : Tensor = aten::add(%7, %10, %6) # test/test_jit.py:3668:23
return (%12)
Archive: a.pt
extracting: a/version
extracting: a/data.pkl
inflating: a/code/__torch__.py
inflating: a/code/__torch__.py.debug_pkl
inflating: a/code/__torch__/___torch_mangle_0.py
inflating: a/code/__torch__/___torch_mangle_0.py.debug_pkl
inflating: a/code/__torch__/___torch_mangle_1.py
inflating: a/code/__torch__/___torch_mangle_1.py.debug_pkl
extracting: a/constants.pkl
[dev] ~/d/pytorch > cat a/code/__torch__/___torch_mangle_0.py
op_version_set = 1
def my_method(x: Tensor) -> Tensor:
return torch.add(x, 100, 1)
[dev] ~/d/pytorch > cat a/code/__torch__/___torch_mangle_1.py
op_version_set = 1
def my_method(x: Tensor) -> Tensor:
return torch.mul(x, 100)
And the tests still pass.
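The disambiguation above can be sketched as a simple collision-driven renamer (illustrative only: `assign_qualified_names` is a hypothetical helper, not the real serializer):

```python
def assign_qualified_names(functions):
    # Give each function a unique qualified name. On a name collision,
    # synthesize a fresh ___torch_mangle_N module, echoing the archive
    # layout shown above. (Hypothetical helper, not the real serializer.)
    seen = set()
    counter = 0
    names = {}
    for fn in functions:
        name = f"__torch__.{fn.__name__}"
        while name in seen:
            name = f"__torch__.___torch_mangle_{counter}.{fn.__name__}"
            counter += 1
        seen.add(name)
        names[fn] = name
    return names
```

Two functions both named `my_method` end up in distinct modules, which is why the two `cat` outputs above can each define their own `my_method`.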
eellison
left a comment
Looks good
Summary: Type objects in Python have an attribute `__abstractmethods__` that throws when it is accessed, so we were failing with an AttributeError whenever a type was used in TorchScript. This PR prevents that error from happening. We can't just throw when a type is used, because it could be used to access a static method: #27163

Pull Request resolved: #28053
Differential Revision: D18332347
Pulled By: eellison
fbshipit-source-id: 9c7f2220f92674ad4d903621d9762cecc566ab0d
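The `__abstractmethods__` behavior can be seen in plain Python (nothing PyTorch-specific here — this is just a sketch of the failure mode the summary describes):

```python
import abc

class Plain:
    pass

class Abstract(abc.ABC):
    @abc.abstractmethod
    def f(self):
        ...

# On a class that never went through ABCMeta, __abstractmethods__ is
# unset and accessing it raises AttributeError -- the error TorchScript
# was hitting whenever it probed attributes of an arbitrary type object.
try:
    Plain.__abstractmethods__
    raised = False
except AttributeError:
    raised = True
```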
Summary: Resolve static methods as functions. Fixes pytorch#26792

Pull Request resolved: pytorch#27163
Pulled By: driazati
Differential Revision: D17695094
fbshipit-source-id: 4671cae1a92526a35c83b8d9c12a50aa5442412b
Resolve static methods as functions
Fixes #26792
Differential Revision: D17695094