Skip to content

[jit] Add lazy script decorator#34935

Closed
driazati wants to merge 6 commits intomasterfrom
driazati/lazyscript
Closed

[jit] Add lazy script decorator#34935
driazati wants to merge 6 commits intomasterfrom
driazati/lazyscript

Conversation

@driazati
Copy link
Contributor

@driazati driazati commented Mar 18, 2020

Stacked PRs

Some users maintain libraries of code that is largely trace-able but not
script-able. However, some functions may need to be @torch.jit.scripted if
they contain control flow so the tracer will use the compiler version.
This however impacts library start up time as in #33418, so this PR adds
a workaround in the form of a @torch.jit. _script_if_tracing
that will only initialize the compiler if the function is called while
actually tracing.

Differential Revision: D20569778

Some users maintain libraries of code that is largely trace-able but not
script-able. However, some functions may need to be `@torch.jit.script`ed if
they contain control flow so the tracer will use the compiler version.
This however impacts library start up time as in #33418, so this PR adds
a workaround in the form of a `@torch.jit._lazy_script_while_tracing`
that will only initialize the compiler if the function is called while
actually tracing.
@driazati driazati requested a review from apaszke as a code owner March 18, 2020 02:12
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 18, 2020
Your Name added 2 commits March 17, 2020 19:25
@driazati driazati requested review from eellison and suo March 18, 2020 02:34
Copy link
Contributor

@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. One of your tests was wrong so if that's not working this may not work right now.

def fn2(x):
return untraceable(x)

with self.capture_stdout():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why capture stdout? you can @suppress_warnings if warnings are being supressed ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are prints in the test

test/test_jit.py Outdated
with self.capture_stdout():
traced = torch.jit.trace(fn, [torch.ones(2, 2)])

FileCheck().check_not("goodbye").check_not("hello").run(traced.graph)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be check and not check_not ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, this was only passing since the inline mode was wrong

@eellison
Copy link
Contributor

Also, not sure _lazy_script is the most clear. Maybe: jit._script_if_tracing ? jit._compile_if_tracing ?

@driazati
Copy link
Contributor Author

_script_if_tracing sounds best, I'll change it

Your Name added 2 commits March 19, 2020 17:15
facebook-github-bot pushed a commit that referenced this pull request Mar 24, 2020
Summary:
Stacked PRs
 * **#34938 - [jit] Remove stray `script`**
 * #34935 - [jit] Add lazy script decorator

Pull Request resolved: #34938

Pulled By: driazati

Differential Revision: D20569793

fbshipit-source-id: 1f126646f7bd7c4ea972e15023eaa60f0e301351
@facebook-github-bot
Copy link
Contributor

@driazati merged this pull request in 44622bb.

Returns ``True`` in tracing (if a function is called during the tracing of
code with ``torch.jit.trace``) and ``False`` otherwise.
"""
return torch._C._is_tracing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns the function, and not the bool torch._C._is_tracing(). _script_if_tracing doesn't work on 1.6.0 but is fixed on master.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed here for reference.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Merged oncall: jit Add this issue/PR to JIT oncall triage queue

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants