-
Notifications
You must be signed in to change notification settings - Fork 26.3k
extend torch.jit._overload to module methods #24259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
Note: This should be landed ONLY after #24259
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
| return self.hello("hi"), self.hello(.5) | ||
|
|
||
| w = CompileOverloadError() | ||
| with self.assertRaisesRegex(Exception, "but instead found type \'str\'"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't a str work here since there's an overload for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an error here because you're trying to do x + 1 with a str
torch/_jit_internal.py
Outdated
| class_name_map = {} | ||
| _overloaded_methods[qual_name] = class_name_map | ||
|
|
||
| class_name = get_class_name(func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem like a safe way to globally identify classes since it will break for nested classes with the same class name and same function names since the name is not fully qualified
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, i can add a check that the line number is the same. This is still internal-only tho.
|
|
||
| methods = methods + tuple(exported) | ||
|
|
||
| overload_name_mappings = dict(getattr(mod, "__overloads__", {})) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you delete the __overloads__ and replace it with these new decorators instead of adding both together? (or do that in a stacked PR)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea i was going to do that as a follow up, since that likely breaks the ScriptModules in quantization
|
|
||
| for orig_fn, overload_fns in overloads: | ||
| orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule") | ||
| names = list(map(lambda i: orig_ast.name().name + "__" + str(i), range(len(overload_fns)))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the overloaded methods aren't directly callable why do they need to be added as normal methods with this mangled name to the ScriptModule? The class could hold its schemas for each overload and check that in OverloadedMethodValue instead of what it's doing now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The normal name isn't being added, only the mangled ones are being added.
Defining the mangled ones lazily would be an improvement but I don't think it's necessary for this PR. It would also complicate export logic / possibly other things, and would require removing the other __overload__ mechanism first.
Note: This should be landed ONLY after #24259 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
Note: This should be landed ONLY after #24259 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ``` Differential Revision: [D16921304](https://our.internmc.facebook.com/intern/diff/D16921304)
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ``` Differential Revision: [D16921304](https://our.internmc.facebook.com/intern/diff/D16921304)
Note: This should be landed ONLY after #24259 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Pull Request resolved: #24447 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Pull Request resolved: #24447 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Pull Request resolved: #24447 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Pull Request resolved: #24447 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Note: This should be landed ONLY after #24259 Pull Request resolved: #24447 Differential Revision: [D16846006](https://our.internmc.facebook.com/intern/diff/D16846006)
Stack from ghstack:
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266.
The usage is:
Differential Revision: D16921304