-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] Add weak script mode for script functions #11963
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary:
Weak scripts allow you to compile a graph from Python code at runtime by
annotating with torch.jit.weak_script for use in the JIT without
affecting Python execution. Scripts are compiled lazily on the first
call in a graph to avoid long Python startup times.
```python
import torch
@torch.jit.weak_script
def fn(x):
return x + 5
@torch.jit.script
def run_in_graph(x):
return fn(x) + 1
x = torch.randn(4)
# this should cause fn to compile and run graph
script_out = run_in_graph(x)
# this should run fn from Python (without calling into C++)
expected_out = fn(x)
```
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
torch/nn/functional.py
Outdated
|
|
||
| _VF = torch._C._VariableFunctions | ||
|
|
||
| def weak_script(fn): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comments for suggestions about how to structure the code in the compiler. In particular, it should mirror how (strong) script functions flow through the compiler closely.
torch/csrc/jit/interned_strings.h
Outdated
| static Symbol aten(const std::string & s); | ||
| static Symbol onnx(const std::string & s); | ||
| static Symbol prim(const std::string & s); | ||
| static Symbol weak(const std::string & s); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/init.cpp
Outdated
| return std::make_shared<SimpleValue>(v); | ||
| } | ||
|
|
||
| static py::function compile_func_; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/init.cpp
Outdated
| const Def& def = py::cast<const Def&>(tuple[0]); | ||
| ResolutionCallback rcb = py::cast<ResolutionCallback>(tuple[1]); | ||
| auto m = std::make_shared<Module>(); | ||
| defineMethodsInModule(*m, {def}, {pythonResolver(rcb)}, nullptr); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/script/init.cpp
Outdated
| std::string cconv(inputs.size(), 'd'); | ||
| Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp( | ||
| THPObjectPtr(func.release().ptr()), cconv, {})); | ||
|
|
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| def weak_script(fn): | ||
| # register op for jit | ||
| # no op if called from python | ||
| fn.is_weak = True |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
torch/csrc/jit/script/init.cpp
Outdated
|
|
||
| if (is_weak) { | ||
| auto compiled_fn = | ||
| py::module::import("torch.jit").attr("script")(obj, true, 0, true); |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| return fn | ||
|
|
||
|
|
||
| def script(fn, optimize=True, _frames_up=0, is_weak=False): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approach looks good. I have minor comments on how the code is factored.
torch/jit/__init__.py
Outdated
|
|
||
| def script(fn, optimize=True, _frames_up=0): | ||
| def weak_script(fn): | ||
| fn._jit_is_weak_script = True |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| return fn | ||
|
|
||
|
|
||
| def script(fn, optimize=True, _frames_up=0, _is_weak=False): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
|
|
||
|
|
||
| def weak_script(fn): | ||
| compiled_weak_fns[fn] = "pending"; |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
I'm also pretty sure that we need to create resolution callbacks at the time when |
torch/jit/__init__.py
Outdated
| if entry is None: | ||
| return None | ||
| if entry["status"] == COMPILATION_PENDING: | ||
| entry["status"] = COMPILED |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| entry["status"] = COMPILED | ||
| compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"]) | ||
| compiled_weak_fns[fn]["compiled_fn"] = compiled_fn | ||
| return compiled_fn |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/jit/__init__.py
Outdated
| return compiled_fn | ||
|
|
||
|
|
||
| def script(fn, optimize=True, _frames_up=0, rcb=None): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@pytorchbot retest this please |
zdevito
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good. Minor nits to keep the amount of data we hold around lower.
| } else if (py::isinstance<py::module>(obj)) { | ||
| return std::make_shared<PythonModuleValue>(obj); | ||
| } | ||
|
|
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"]) | ||
| compiled_weak_fns[fn]["compiled_fn"] = compiled_fn | ||
| entry["status"] = COMPILED | ||
| return compiled_fn |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| compiled_weak_fns[fn] = { | ||
| "status": COMPILATION_PENDING, | ||
| "compiled_fn": None, | ||
| "rcb": createResolutionCallback(_frames_up + 1) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This PR is the start of weak script mode for functions
Weak scripts allow you to compile a graph from Python code at runtime by annotating with
torch.jit.weak_scriptfor use in the JIT without affecting eager execution. Scripts are compiled lazily on the first call in a graph to avoid long Python startup times.@apaszke @zdevito @ezyang