[JIT] add support for overloading functions #23886
Conversation
```
self.assertEqual(out, torch.tensor(6.0))

def test_function_overloads(self):
    # TODO: pyflakes currently does not compose @overload annotation with other
```
File an issue to update so this doesn't get forgotten
Will do when this lands
```
compiled_fns.append(compiled_fn)

# cache compilation, remove information stored to do compilation
_compiled_overloaded_fns[qual_name] = compiled_fns
```
This looks like a potential memory leak: if the functions this compiles go out of scope, their compiled versions will stick around here. @Chillee found a similar issue recently; I don't remember if he found a fix or not.
Discussed in person: this is not a public API yet and will have very limited usage here, so we are going to leave it as is for now.
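For reference, the pattern under discussion, reduced to a generic sketch (this is not the actual `_compiled_overloaded_fns` code): a module-level dict holds strong references to whatever it caches, so the compiled functions can never be collected; a weakref-based cache would be one possible mitigation if this ever becomes public.

```
import weakref

# Generic sketch of the concern, not the real JIT cache.
_strong_cache = {}                           # strong refs: entries live forever
_weak_cache = weakref.WeakValueDictionary()  # entries disappear with their values

class CompiledFn:
    def __init__(self, qual_name):
        self.qual_name = qual_name

def compile_and_cache(qual_name):
    compiled = CompiledFn(qual_name)
    _strong_cache[qual_name] = compiled  # leaks: kept alive even after all other refs are gone
    _weak_cache[qual_name] = compiled    # collected once no one else holds it
    return compiled
```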
This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266. It adds support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading). The usage is:

```
@torch.jit.overload
def add(x: int, y: int) -> int: ...

@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```

Follow up PRs:
- Add the same API for methods
- A couple of cleanups for functions:
  - don't require default params to be specified on the overload as well
  - potentially error if an invocation could be matched to multiple overloads; right now it just chooses the first one, which is what mypy currently does

Differential Revision: [D16694863](https://our.internmc.facebook.com/intern/diff/D16694863)
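To make the intended behavior concrete, a quick usage sketch (my own example, not from the PR, assuming `add` is declared as above in the same module): the compiler picks an overload from the static types at each call site.

```
import torch

@torch.jit.script
def use_add(a: int, b: float) -> float:
    i = add(a, a)  # matches the (int, int) -> int overload
    f = add(b, b)  # matches the (float, float) -> float overload
    return i + f
```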
Also, can you add a full example somewhere? I'm having trouble getting this to work: what I have throws an error, and specifically it's not clear to me what "wrapping it with typed inputs" means. EDIT: I figured it out by looking at the screenshot you initially sent me. I'd suggest rephrasing the error message to something less technically involved. An example might even be a good idea.
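For the next person who hits this, here is roughly what fixed it for me (a hypothetical sketch; I'm assuming "wrapping it with typed inputs" means scripting a small annotated wrapper so the compiler has concrete types to match against one overload):

```
import torch

# torch.jit.script(add)  # fails: 'add' alone doesn't say which overload to compile

@torch.jit.script
def add_floats(x: float, y: float) -> float:
    # The annotations give the compiler typed inputs, so it can pick the
    # (float, float) -> float overload of 'add'.
    return add(x, y)
```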
Thanks for the comments. Updated the error message; let me know if you think it's easier to understand. Btw, I suspect this is a very uncommon case (compiling a single function from nn/functional by itself).
driazati left a comment:
Looks good, remaining stuff is mostly just structural
```
def schema_match_failure():
    return identity((1, 2))

thrown = False
```
Why not self.assertRaisesRegex?
Couldn't figure out the regex lol
Would look something like this if you wanna do it with regex (gotta use lookaheads though): https://regexr.com/4irtv
`(?=.*of type 'str')(?=.*of type 'float').*`
Might not be worth it.
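Concretely, it would read something like this in the test (a sketch; I'm assuming the failure comes out of `torch.jit.script(schema_match_failure)` as a `RuntimeError`, and you may need a leading `(?s)` if the message spans multiple lines):

```
# Lookaheads let assertRaisesRegex require both substrings regardless of order.
with self.assertRaisesRegex(RuntimeError, r"(?=.*of type 'str')(?=.*of type 'float')"):
    torch.jit.script(schema_match_failure)
```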
```
py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
if (py::cast<bool>(isFunction)) {
  auto overloads =
```
Looks like the infrastructure for overloaded functions is only implemented for Python. Do we have a plan for resolving overloaded functions defined from the string frontend (for example, how will we serialize overloaded functions)?
Any overloaded function that needs to be serialized will already have been instantiated as one of its overloads, and that function gets serialized normally.
```
qual_name = _qualified_name(obj)
global _compiled_overloaded_fns
compiled_overloads = _compiled_overloaded_fns.get(qual_name, None)
if compiled_overloads is not None:
```
This doesn't work if we call `_get_overloads()` and then add another overload afterward: the second overload will never get compiled. Do we need a cache here? We don't have anything similar for other recursively compiled functions.
What is the use case you are imagining? This is still an internal-only API, so I don't think it applies currently, but it's something to think about if we make it public. We don't strictly need a cache, it's true.
If we add a new overload after using the overload in the scripted function, it will not be registered. Let's remove the cache, or make the existing caching logic correct. Even if it's unlikely I think it'll be very surprising if we ever run into this.
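To spell out the staleness scenario in isolation (a generic sketch of the caching pattern, not the actual `_get_overloads` code):

```
# Generic illustration of why caching by qualified name goes stale.
_overload_registry = {}   # qual_name -> declared overload signatures
_compiled_cache = {}      # qual_name -> compiled overloads, filled lazily

def add_overload(qual_name, sig):
    _overload_registry.setdefault(qual_name, []).append(sig)

def get_overloads(qual_name):
    if qual_name in _compiled_cache:       # early return: anything registered
        return _compiled_cache[qual_name]  # after the first call is invisible
    compiled = [("compiled", sig) for sig in _overload_registry.get(qual_name, [])]
    _compiled_cache[qual_name] = compiled
    return compiled

add_overload("m.add", "(int, int) -> int")
print(get_overloads("m.add"))   # [('compiled', '(int, int) -> int')]
add_overload("m.add", "(float, float) -> float")
print(get_overloads("m.add"))   # stale: the float overload never shows up
```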
```
[](Module& self, std::string name, TypePtr type, py::object value) {
  auto unshaped = unshapedType(type);
  self.register_attribute(name, unshaped, toIValue(value, type));
  self.register_attribute(
```
Why was this changed?
Clang-tidy complained, and I accidentally removed the `unshaped` argument.
Follow up to #23886: add the same overload API specified in PEP 484 to module methods, to reduce the friction of adding method overloads that was brought up in #23266. The usage is:

```
@torch.jit.overload
def add(self, y: int) -> int: ...

@torch.jit.overload
def add(self, y: float) -> float: ...

def add(self, y):
    ...
```

Differential Revision: [D16921304](https://our.internmc.facebook.com/intern/diff/D16921304)
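A fuller sketch of how this could look on a module, based on the usage above (hypothetical module and method names; the scripting call is commented out since it depends on this follow-up landing):

```
import torch
import torch.nn as nn

class Adder(nn.Module):
    @torch.jit.overload
    def add(self, y: int) -> int: ...

    @torch.jit.overload
    def add(self, y: float) -> float: ...

    def add(self, y):
        return y + y

    def forward(self, a: int, b: float) -> float:
        # The int overload is chosen for 'a', the float overload for 'b'.
        return self.add(a) + self.add(b)

# scripted = torch.jit.script(Adder())  # hypothetical usage once this lands
```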
Summary: Pull Request resolved: #24259 (follow up to #23886; same description as above).
Test Plan: Imported from OSS
Differential Revision: D16921304
Pulled By: eellison
fbshipit-source-id: 784e2f26f7ca9a330a434a603c86b53725c3dc71