-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT] Plumb type annotations through script compilation #9405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
ce55be3 to
f88188c
Compare
f88188c to
8a01c6b
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jamesr66a has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool -- I have a bunch of organization comments to keep the complexity of this change down. Right now it adds too many different pathways to the compiler and I have some suggestions about how to reduce that.
| DefAndTypes(Def def, std::vector<TypePtr> arg_types, TypePtr return_type) | ||
| : def(std::move(def)), arg_types(arg_types), return_type(return_type) {} | ||
| Def def; | ||
| std::vector<TypePtr> arg_types; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| bool allow_varargs; | ||
| }; | ||
|
|
||
| struct DefAndTypes { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| const std::vector<Resolver>& resolvers, /* determines how we handle free variables in each definition*/ | ||
| std::shared_ptr<SugaredValue> self /* if non-null, the first argument to each def, is bound to this value */ | ||
| std::shared_ptr<SugaredValue> self, /* if non-null, the first argument to each def, is bound to this value */ | ||
| bool pure_func=false /* If true, expect a single def which will become the 'forward' method on the module */ |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| mod = ScriptModule() | ||
| rcb = createResolutionCallback(_frames_up + 1) | ||
| ast = get_jit_ast(fn) | ||
| arg_types, ret_type = annotations.get_signature(fn, ast.num_params(), 0) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| graph = _script_graph(fn, _frames_up=_frames_up + 1) | ||
| mod = ScriptModule() | ||
| mod._create_method_from_graph('forward', graph) | ||
| _script_pure_function(mod, fn, _frames_up=_frames_up + 1) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| auto& name = (*it).ident().name(); | ||
| arguments.push_back({name, DynamicType::get()}); | ||
| TypePtr arg_type = DynamicType::get(); | ||
| if (def_and_types.arg_types.size()) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| ast = get_jit_ast(fn) | ||
| arg_types, return_type = annotations.get_signature(fn, ast.num_params() - 1, 0) | ||
| # Dumb handling for `self` | ||
| if len(arg_types) == ast.num_params(): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| torch._C.ScriptModule.__init__(self) | ||
| original_init(self, *args, **kwargs) | ||
| asts = [m.ast for m in methods] | ||
| defs = [torch._C.DefAndTypes(m.ast, m.arg_types, m.return_type) for m in methods] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| for(Def def : definitions) { | ||
| const std::string& name = def.name().name(); | ||
| for(DefAndTypes def : definitions) { | ||
| const std::string& name = pure_func ? "forward" : def.def.name().name(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| std::vector<TypePtr> flattened_return_types; | ||
| if (def_and_types.return_type) { | ||
| if (def_and_types.return_type->kind() == TypeKind::TupleType) { | ||
| const auto &tuple_type_elmts = def_and_types.return_type->cast<TupleType>()->elements(); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| wrap_list(r, std::move(body))); | ||
| })); | ||
| })) | ||
| .def("num_params", [](Def& self) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| FunctionTable table; | ||
| JIT_ASSERT(definitions.size() == resolvers.size()); | ||
| if (pure_func) { | ||
| JIT_ASSERT(definitions.size() == 1); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| bool allow_varargs; | ||
| }; | ||
|
|
||
| struct DefAndTypes { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
|
|
||
| def _script_graph(fn, _frames_up=0): | ||
| def _compile_fn(fn, _frames_up=0): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| _jit_script_compile_pure_fn(mod, torch._C.DefAndTypes(ast, arg_types, ret_type), rcb) | ||
|
|
||
|
|
||
| def _script_graph(fn, _frames_up=0): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
|
||
| m.def("_jit_script_compile", [](Def def, ResolutionCallback rcb) { | ||
| return compileFunction(def, pythonResolver(rcb)); | ||
| m.def("_jit_script_compile_pure_fn", [](Module &m, DefAndTypes def, ResolutionCallback rcb) { |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Superseded by #9547 |
Summary: Supersedes pytorch#9405 Pull Request resolved: pytorch#9547 Reviewed By: zdevito Differential Revision: D8900327 Pulled By: jamesr66a fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
Summary: Supersedes pytorch#9405 Pull Request resolved: pytorch#9547 Reviewed By: zdevito Differential Revision: D8900327 Pulled By: jamesr66a fbshipit-source-id: a00a94615af4fbaec98ee3ede0cb54bcfd9108dd
No description provided.