Skip to content

Conversation

@driazati
Copy link
Contributor

@driazati driazati commented Sep 22, 2018

This PR is the start of weak script mode for functions

Weak scripts allow you to compile a graph from Python code at runtime by annotating with torch.jit.weak_script for use in the JIT without affecting eager execution. Scripts are compiled lazily on the first call in a graph to avoid long Python startup times.

@apaszke @zdevito @ezyang

David Riazati added 6 commits September 21, 2018 13:24
Summary:
Weak scripts allow you to compile a graph from Python code at runtime by
annotating with torch.jit.weak_script for use in the JIT without
affecting Python execution. Scripts are compiled lazily on the first
call in a graph to avoid long Python startup times.

```python
    import torch

    @torch.jit.weak_script
    def fn(x):
        return x + 5

    @torch.jit.script
    def run_in_graph(x):
        return fn(x) + 1

    x = torch.randn(4)

    # this should cause fn to compile and run graph
    script_out = run_in_graph(x)

    # this should run fn from Python (without calling into C++)
    expected_out = fn(x)
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

_VF = torch._C._VariableFunctions

def weak_script(fn):

This comment was marked as off-topic.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See comments for suggestions about how to structure the code in the compiler. In particular, it should mirror how (strong) script functions flow through the compiler closely.

static Symbol aten(const std::string & s);
static Symbol onnx(const std::string & s);
static Symbol prim(const std::string & s);
static Symbol weak(const std::string & s);

This comment was marked as off-topic.

return std::make_shared<SimpleValue>(v);
}

static py::function compile_func_;

This comment was marked as off-topic.

const Def& def = py::cast<const Def&>(tuple[0]);
ResolutionCallback rcb = py::cast<ResolutionCallback>(tuple[1]);
auto m = std::make_shared<Module>();
defineMethodsInModule(*m, {def}, {pythonResolver(rcb)}, nullptr);

This comment was marked as off-topic.

std::string cconv(inputs.size(), 'd');
Node* new_node = m.graph()->insertNode(m.graph()->createPythonOp(
THPObjectPtr(func.release().ptr()), cconv, {}));

This comment was marked as off-topic.

def weak_script(fn):
# register op for jit
# no op if called from python
fn.is_weak = True

This comment was marked as off-topic.

@ailzhang ailzhang added the oncall: jit Add this issue/PR to JIT oncall triage queue label Sep 25, 2018
David Riazati added 5 commits September 26, 2018 16:57
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@driazati driazati changed the title [jit][WIP] Add weak script mode for script functions [jit] Add weak script mode for script functions Sep 27, 2018

if (is_weak) {
auto compiled_fn =
py::module::import("torch.jit").attr("script")(obj, true, 0, true);

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

return fn


def script(fn, optimize=True, _frames_up=0, is_weak=False):

This comment was marked as off-topic.

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approach looks good. I have minor comments on how the code is factored.


def script(fn, optimize=True, _frames_up=0):
def weak_script(fn):
fn._jit_is_weak_script = True

This comment was marked as off-topic.

return fn


def script(fn, optimize=True, _frames_up=0, _is_weak=False):

This comment was marked as off-topic.



def weak_script(fn):
compiled_weak_fns[fn] = "pending";

This comment was marked as off-topic.

This comment was marked as off-topic.

@apaszke
Copy link
Contributor

apaszke commented Oct 4, 2018

I'm also pretty sure that we need to create resolution callbacks at the time when weak_script is called and not when it's compiled. That's what we do for script_methods at least.

if entry is None:
return None
if entry["status"] == COMPILATION_PENDING:
entry["status"] = COMPILED

This comment was marked as off-topic.

entry["status"] = COMPILED
compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
return compiled_fn

This comment was marked as off-topic.

return compiled_fn


def script(fn, optimize=True, _frames_up=0, rcb=None):

This comment was marked as off-topic.

@driazati
Copy link
Contributor Author

driazati commented Oct 4, 2018

@pytorchbot retest this please

Copy link
Contributor

@zdevito zdevito left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. Minor nits to keep the amount of data we hold around lower.

} else if (py::isinstance<py::module>(obj)) {
return std::make_shared<PythonModuleValue>(obj);
}

This comment was marked as off-topic.

compiled_fn = torch.jit.script(fn, True, 0, entry["rcb"])
compiled_weak_fns[fn]["compiled_fn"] = compiled_fn
entry["status"] = COMPILED
return compiled_fn

This comment was marked as off-topic.

compiled_weak_fns[fn] = {
"status": COMPILATION_PENDING,
"compiled_fn": None,
"rcb": createResolutionCallback(_frames_up + 1)

This comment was marked as off-topic.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

driazati has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants