Conversation

eellison (Contributor) commented Aug 6, 2019

This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266.

Support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading).

The usage is:
```
@torch.jit.overload
def add(x: int, y: int) -> int: ...
@torch.jit.overload
def add(x: float, y: float) -> float: ...

def add(x, y):
    return x + y
```
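
A minimal sketch of how the overloads above might be called (the wrapper name `use_add` is illustrative, not part of this PR, and it assumes `import torch` plus the `add` definitions above): the overloaded function is invoked from a scripted function whose argument types are statically known, and the matching overload is selected.

```
@torch.jit.script
def use_add(x: int) -> int:
    # x is statically an int here, so the (int, int) -> int overload of add is chosen
    return add(x, 2)
```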

Follow-up PRs:

- Add the same API for methods
- A couple of cleanups for functions:
  - Don't require default params to be specified on the overload as well
  - Potentially error if an invocation could be matched to multiple overloads; right now it just chooses the first one, which is what mypy currently does

Differential Revision: [D16694863](https://our.internmc.facebook.com/intern/diff/D16694863)

pytorchbot added the oncall: jit and module: pybind labels Aug 6, 2019
eellison pushed a commit that referenced this pull request Aug 6, 2019
ghstack-source-id: efbb867
Pull Request resolved: #23886
eellison requested a review from suo August 6, 2019 18:37
eellison changed the title from "add support for overloading functions" to "[JIT] add support for overloading functions" Aug 6, 2019
eellison pushed a commit that referenced this pull request Aug 6, 2019
ghstack-source-id: 34ded50
Pull Request resolved: #23886
```
self.assertEqual(out, torch.tensor(6.0))

def test_function_overloads(self):
    # TODO: pyflakes currently does not compose @overload annotation with other
```
Contributor:
File an issue to update so this doesn't get forgotten

Contributor Author (eellison):
Will do when this lands

```
compiled_fns.append(compiled_fn)

# cache compilation, remove information stored to do compilation
_compiled_overloaded_fns[qual_name] = compiled_fns
```
Contributor:
This looks like a potential memory leak, if the functions this compiles go out of scope, their compiled versions will stick around here. @Chillee found a similar issue recently, I don't remember if he found a fix or not

Contributor Author (eellison):
Discussed in person: this is not a public API yet and will have very limited usage here, so we are going to leave it as is for now.

eellison requested a review from driazati August 6, 2019 22:56
eellison pushed a commit that referenced this pull request Aug 6, 2019
ghstack-source-id: 84e91f6
Pull Request resolved: #23886
Chillee (Collaborator) commented Aug 7, 2019

Also, can you add a full example somewhere? I'm having trouble getting this to work.

I have:

```
@torch.jit.overload  # noqa: F811
def identity(x1):  # noqa: F811
    # type: (str) -> str
    pass

@torch.jit.overload  # noqa: F811
def identity(x1=1.0):  # noqa: F811
    # type: (float) -> float
    pass

def identity(x1=1.0):  # noqa: F811
    return x1

print(torch.jit.script(identity))
```

This throws an error:

```
RuntimeError: Function <function identity at 0x7fba94ef41e0> cannot be directly compiled because it is overloaded. Please wrap it with typed inputs of the signature you wish to compile instead
```

Specifically, it's not clear to me what "wrapping it with typed inputs" means.

EDIT: I figured it out by looking at the screenshot you initially sent me. I'd suggest rephrasing the error as something like:

```
RuntimeError: Function <qualname> cannot be directly compiled because it is overloaded. It must be used in the context of a function where `isinstance` checks can statically determine the type.
```

Perhaps less technically involved. An example might even be a good idea.
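
For reference, a minimal sketch of the pattern the error is pointing at (the wrapper name `call_identity` is illustrative and assumes the `identity` overloads from the snippet above): the overloaded function is not scripted directly, but called from a scripted function whose input type is statically known, so the matching overload can be picked.

```
@torch.jit.script
def call_identity(x1: float) -> float:
    # x1 is statically a float here, so the (float) -> float overload of identity is used
    return identity(x1)
```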

eellison pushed a commit that referenced this pull request Aug 7, 2019
ghstack-source-id: 715ecba
Pull Request resolved: #23886
eellison (Contributor Author) commented Aug 7, 2019

> Also, can you add a full example somewhere? I'm having trouble getting this to work. […]

Thanks for the comments. Updated the error message, let me know if you think it's easier to understand. Btw, I suspect this is a very uncommon case (compiling a single function from nn/functional by itself).

driazati (Contributor) left a comment:

Looks good, remaining stuff is mostly just structural

```
def schema_match_failure():
    return identity((1, 2))

thrown = False
```
Contributor:
Why not self.assertRaisesRegex?

Contributor Author (eellison):
Couldn't figure out the regex lol

Collaborator:
Would look something like this if you wanna do it with regex (gotta use lookaheads though): https://regexr.com/4irtv

`(?=.*of type 'str')(?=.*of type 'float').*`

Might not be worth it.
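
For what it's worth, a sketch of how that suggestion might look inside the test above (illustrative only; the `(?s)` inline flag is added so `.` can also match across line breaks in a multi-line error message):

```
with self.assertRaisesRegex(RuntimeError, r"(?s)(?=.*of type 'str')(?=.*of type 'float').*"):
    torch.jit.script(schema_match_failure)
```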

zou3519 deleted the gh/eellison/8/head branch August 8, 2019 02:21

```
py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
if (py::cast<bool>(isFunction)) {
  auto overloads =
```
Member:
Looks like the infrastructure for overloaded functions is only implemented for Python. Do we have a plan for resolving overloaded functions defined from the string frontend (for example, how will we serialize overloaded functions)?

Contributor Author (eellison):
Any overloaded function that needs to be serialized will already be instantiated in one of its overloads; that function gets serialized normally.

```
qual_name = _qualified_name(obj)
global _compiled_overloaded_fns
compiled_overloads = _compiled_overloaded_fns.get(qual_name, None)
if compiled_overloads is not None:
```
Member:
This doesn't work if we call _get_overloads() then add another overload after. The second overload will not ever get compiled. Do we need a cache here? We don't have anything similar for other recursively compiled functions.

Contributor Author (eellison), Aug 8, 2019:

What is the use case you are imagining? This is still an internal-only API, so I don't think it applies currently, but it's something to think about if we make it public. We don't strictly need a cache, it's true.

Member:
If we add a new overload after using the overload in the scripted function, it will not be registered. Let's remove the cache, or make the existing caching logic correct. Even if it's unlikely I think it'll be very surprising if we ever run into this.

```
[](Module& self, std::string name, TypePtr type, py::object value) {
  auto unshaped = unshapedType(type);
  self.register_attribute(name, unshaped, toIValue(value, type));
  self.register_attribute(
```
Member:
Why was this changed?

Contributor Author (eellison), Aug 8, 2019:

Clang-tidy complained, and I accidentally removed the `unshaped` conversion.

facebook-github-bot (Contributor):
@eellison merged this pull request in 451fc51.

eellison pushed a commit that referenced this pull request Aug 15, 2019
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266.

The usage is:
```
@torch.jit.overload
def add(self, y: int) -> int: ...
@torch.jit.overload
def add(self, y: float) -> float: ...
def add():
   ... 
```
facebook-github-bot pushed a commit that referenced this pull request Aug 20, 2019
Summary:
Pull Request resolved: #24259

Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266.

The usage is:
```
@torch.jit.overload
def add(self, y: int) -> int: ...
@torch.jit.overload
def add(self, y: float) -> float: ...
def add():
   ...
```

Test Plan: Imported from OSS

Differential Revision: D16921304

Pulled By: eellison

fbshipit-source-id: 784e2f26f7ca9a330a434a603c86b53725c3dc71
Labels: Merged, module: pybind, oncall: jit