TLDR
For TorchScript, we want to make using script mode as easy as tracing, so we're proposing changing the script API to mirror the tracing API and reduce the code changes necessary to start using TorchScript.
- `a_scripted_module = torch.jit.script(an_nn_module)`
  - `forward` is compiled by default; other methods are lazily compiled and added to `a_scripted_module` as they are called from `forward`.
  - To add methods to `a_scripted_module` that you don't want to explicitly call, decorate them with `@torch.jit.export`.
  - To stop the compiler from trying to compile a method, use `@torch.jit.ignore`.
- `torch.jit.Attribute` is removed and attributes are added automatically
  - Type annotations may be necessary
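For comparison, here is a minimal sketch of the proposed call shape as it works in current PyTorch releases, which accept `torch.jit.script(nn.Module)` and `@torch.jit.ignore`. The `Gate` module and its methods are made up for illustration. One difference from the proposal as written: current releases compile `forward` and the methods it calls up front, when `torch.jit.script` runs, rather than lazily on first call.

```python
import torch
import torch.nn as nn


class Gate(nn.Module):
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        # Not decorated: compiled because forward calls it
        return x + 10

    @torch.jit.ignore
    def python_only(self, x: torch.Tensor) -> torch.Tensor:
        # Left as a Python call; set literals are not supported by TorchScript
        return x * sum({1, 2})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: scripting keeps both paths,
        # tracing records only the path taken by the example input
        if x.sum() > 0:
            return self.helper(x)
        return self.python_only(x)


gate = Gate()
scripted = torch.jit.script(gate)              # same shape as the tracing call...
traced = torch.jit.trace(gate, torch.ones(3))  # ...which also needs an example input

negative = -torch.ones(3)
print(scripted(negative))  # takes the python_only branch (x * 3)
print(traced(negative))    # replays the recorded helper branch (x + 10)
```

Tracing emits a `TracerWarning` here because the `if` converts a tensor to a Python bool; that warning is exactly the case where scripting and tracing diverge.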
How to use changes
- You shouldn't have to make changes to the original model code to make it compile, except for making members `Final` and changing the code so it's TorchScript-compatible
- To make a `ScriptModule`, pass the `nn.Module` to `torch.jit.script` (e.g. `torch.jit.script(torchvision.models.resnet18())`)
- Add variables as `torch.jit.Final[T]` class annotations instead of adding their names to `__constants__`
- Attribute type inference works for all types except `Tensor`s or types containing `Tensor`s, but support for these will be added soon. Attributes whose types can't be inferred should have their types added as class annotations.
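A minimal sketch of both annotation rules as they work in current PyTorch: a `torch.jit.Final[int]` class annotation marks a constant, an int attribute is inferred, and an empty list gets a class annotation because its element type can't be inferred. The `Scaler` module is hypothetical.

```python
from typing import List

import torch
import torch.nn as nn


class Scaler(nn.Module):
    # Constant: replaces listing "factor" in __constants__
    factor: torch.jit.Final[int]
    # An empty list's element type can't be inferred, so annotate it
    history: List[float]

    def __init__(self):
        super().__init__()
        self.factor = 3
        self.offset = 2  # no annotation needed: inferred as int
        self.history = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor + self.offset


scripted = torch.jit.script(Scaler())
print(scripted(torch.ones(2)))  # each element is 1 * 3 + 2
```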
TorchScript API
Currently, to convert an `nn.Module` to a `torch.jit.ScriptModule`, the `nn.Module` must be altered to inherit from `ScriptModule`, and `@torch.jit.script_method` must be added to every method that should be included in compilation. This creates a lot of pressure to modify the original module code in a way that is difficult to undo.
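To make the contrast concrete, here is the same small module written both ways. This is a sketch against current PyTorch releases, where the `ScriptModule` subclass style with `@torch.jit.script_method` still works alongside `torch.jit.script(nn.Module)`; the class and method names are made up for illustration.

```python
import torch
import torch.nn as nn


# Current style: inherit from ScriptModule and mark every compiled method
class OldStyle(torch.jit.ScriptModule):
    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.helper(x) * 2


# Proposed style: a plain nn.Module; only the extra entry point is marked
class NewStyle(nn.Module):
    def helper(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # compiled because forward calls it

    @torch.jit.export
    def extra_entry_point(self, x: torch.Tensor) -> torch.Tensor:
        return x - 1  # not called from forward, so it needs @export

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.helper(x) * 2


old = OldStyle()
new = torch.jit.script(NewStyle())
x = torch.zeros(2)
print(old(x), new(x), new.extra_entry_point(x))
```

`NewStyle` remains an ordinary `nn.Module`, so it can still be used eagerly without any scripting.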
Modules
- Allow `torch.jit.script` to take an `nn.Module` and produce a `ScriptModule`, or allow `@torch.jit.script` to be used on an `nn.Module`, similar to how functions work today (during development we'd name this something else to avoid breaking existing code and to allow incremental switching to the new APIs)
- Assume `forward` is scriptable; compile it and any other methods that it calls
- Deprecate `@torch.jit.script_method`
- To specify an entry point other than `forward`, use the `@torch.jit.export` decorator (`export` has the same functionality as `script_method`)
- To leave functions as Python, use `@torch.jit.ignore`
  - For Python functions that should be removed on export and replaced with an erroring node (for code that is never called, e.g. training-only code), use `@torch.jit.ignore(drop_on_export=True)`
  - `@torch.jit.ignore` will now default to the same behavior as `@torch.jit.ignore(drop_on_export=False)`, which will throw an error if you try to `torch.jit.save()` it
- Provide APIs in `torch.jit.script` to specify what is ignored/exported in case the original module code can't be modified
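The rules above decide which methods get compiled: start from `forward` plus every exported method, follow the methods they call, and stop at ignored ones. A minimal pure-Python sketch of that selection, assuming hypothetical `export`/`ignore` stand-ins that only tag functions (they are not `torch.jit.export`/`torch.jit.ignore`). Call discovery via `__code__.co_names` is a crude approximation: it picks up any attribute name a method references, not only real calls.

```python
import inspect
from typing import Callable, Dict, List, Set


def export(fn: Callable) -> Callable:
    """Hypothetical stand-in for @torch.jit.export: only tags the function."""
    fn._sketch_mode = "export"
    return fn


def ignore(fn: Callable) -> Callable:
    """Hypothetical stand-in for @torch.jit.ignore: only tags the function."""
    fn._sketch_mode = "ignore"
    return fn


def methods_to_compile(cls: type) -> Set[str]:
    """Names of the methods a compiler following the proposed rules would compile."""
    methods: Dict[str, Callable] = {
        name: member for name, member in vars(cls).items() if inspect.isfunction(member)
    }
    # Entry points: forward plus anything explicitly exported
    stack: List[str] = [
        name for name, fn in methods.items()
        if name == "forward" or getattr(fn, "_sketch_mode", None) == "export"
    ]
    compiled: Set[str] = set()
    while stack:
        name = stack.pop()
        fn = methods[name]
        if name in compiled or getattr(fn, "_sketch_mode", None) == "ignore":
            continue  # already handled, or left as Python
        compiled.add(name)
        # Follow every method name that this method's bytecode references
        for referenced in fn.__code__.co_names:
            if referenced in methods:
                stack.append(referenced)
    return compiled


class Demo:
    @ignore
    def debug(self, x):
        print(x)

    @export
    def entry(self, x):
        return self.forward(x + 20)

    def helper(self, y):
        return y + 25

    def unused(self, y):
        return y

    def forward(self, x):
        x = self.helper(x)
        self.debug(x)
        return x


print(sorted(methods_to_compile(Demo)))  # ['entry', 'forward', 'helper']
```

`unused` is never reached from an entry point, so it is skipped, and `debug` is reached but ignored.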
The comments in `forward()` describe what happens during compilation (e.g. when someone calls `my_script_module = torch.jit.script(MyModule())`). Most of these are "advanced features"; in typical usage the original module code won't need to be modified at all.
```python
import torch
import torch.nn as nn


class Submodule(nn.Module):
    def __init__(self):
        super(Submodule, self).__init__()

    def forward(self, x):
        return x + 2


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.submodule1 = Submodule()
        self.submodule2 = torch.jit.script(Submodule())

    @torch.jit.ignore
    def some_debugging_function(self, x):
        import pdb
        pdb.set_trace()

    @torch.jit.ignore(drop_on_export=True)
    def training_only_code(self, x):
        import numpy
        # some non-scriptable training code
        return x  # forward adds the result to x, so return a tensor

    @torch.jit.export
    def an_explicit_entry_point(self, x):
        return self.forward(x + 20)

    # Not exported: compiled lazily the first time it is called (see below)
    def a_called_model_entry_point(self, x):
        return self.forward(x + 20)

    def some_function(self, y):
        return y + 25

    def forward(self, x):
        # Compiles submodule1 into a ScriptModule and registers
        # it on the module that is being created
        x += self.submodule1(x)
        # submodule2 is already a ScriptModule
        x += self.submodule2(x)
        # this compiles some_function and adds it to the ScriptModule
        x += self.some_function(x)
        if self.training:
            x += self.training_only_code(x)
        # If this line is not removed, torch.jit.script(MyModule()).save()
        # will raise an error
        self.some_debugging_function(x)
        return x


scripted_module = torch.jit.script(MyModule())
# this saves the functions 'forward' and 'an_explicit_entry_point'
# and both submodules
scripted_module.save("out.pt")
# This will compile this function and recursively compile
# anything it needs
scripted_module.a_called_model_entry_point(torch.ones(2, 2))
# this is the same as before, but now also includes
# 'a_called_model_entry_point'
scripted_module.save("out.pt")
```
Functions
Functions don't change much: the `@ignore` and `@export` APIs can be used on them without raising an error (though doing so doesn't make much sense).
```python
import torch


# same behavior as before
@torch.jit.script
def some_fn():
    return 2


# just marks a function as ignored; if nothing
# ever calls it, then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2


# doesn't do anything; this function is already
# the main entry point
@torch.jit.export
def some_fn3():
    return 2
```
Classes
Everything in a user-defined class is exported by default; functions can be ignored if needed.
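For user-defined classes, current PyTorch already accepts `@torch.jit.script` on a class. A minimal sketch, with the made-up names `Accumulator` and `sum_range`, in which both of the class's methods are compiled without any per-method marker:

```python
import torch


@torch.jit.script
class Accumulator:
    def __init__(self, start: int):
        self.total = start  # attribute type inferred from __init__

    def add(self, amount: int) -> int:
        self.total = self.total + amount
        return self.total


@torch.jit.script
def sum_range(n: int) -> int:
    # The compiled function can construct and use the compiled class
    acc = Accumulator(0)
    for i in range(n):
        acc.add(i)
    return acc.total


print(sum_range(5))  # 0 + 1 + 2 + 3 + 4
```

Decorating the class returns the original Python class, so `Accumulator` can still be used directly from Python as well.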
Attributes
Currently, attributes are added as follows:
```python
from typing import List

import torch


class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_int_list = torch.jit.Attribute([2], List[int])
        self.my_float_list = torch.jit.Attribute([], List[float])
        self.my_int = torch.jit.Attribute(2, int)
```
`torch.jit.Attribute()` is cumbersome and non-standard. This should be changed to infer attribute types when possible and to use PEP 526-style annotations when not possible.
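The "infer when possible, annotate when not" rule can be sketched in pure Python. This is a hypothetical illustration of the rule, not TorchScript's actual inference code: `infer_type` and `attribute_types` are made-up helpers, and like the proposal they do not handle tensors.

```python
from typing import Any, Dict, List, Optional


def infer_type(value: Any) -> Optional[Any]:
    """Hypothetical inference rule: return a typing type for value, or None."""
    if isinstance(value, bool):  # check bool first: bool is a subclass of int
        return bool
    if isinstance(value, (int, float, str)):
        return type(value)
    if isinstance(value, list):
        if not value:
            return None  # empty list: the element type is unknown
        element_types = {infer_type(v) for v in value}
        if len(element_types) != 1 or None in element_types:
            return None  # mixed or unknown element types
        return List[element_types.pop()]
    return None  # e.g. tensors: not handled in this sketch


def attribute_types(obj: Any) -> Dict[str, Any]:
    """Map each instance attribute to a type: class annotation first, else inference."""
    annotations = getattr(type(obj), "__annotations__", {})
    types: Dict[str, Any] = {}
    for name, value in vars(obj).items():
        if name in annotations:
            types[name] = annotations[name]
            continue
        inferred = infer_type(value)
        if inferred is None:
            raise TypeError(
                f"cannot infer the type of attribute {name!r}; add a class annotation"
            )
        types[name] = inferred
    return types


class Example:
    my_float_list: List[float]  # empty at construction, so annotated

    def __init__(self):
        self.my_int_list = [2]
        self.my_float_list = []
        self.my_int = 2


print(attribute_types(Example()))
```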
```python
from typing import List

import torch


class MyModule(torch.jit.ScriptModule):
    my_empty_int_list: List[int]  # the type is specified out of line

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_int_list = [2]  # the type can be inferred since it has elements
        self.my_float_list: List[float] = []  # the type is specified manually
        self.my_empty_int_list = []  # the type comes from the class annotation
        self.my_int = 2
```
As in the PEP, class-level annotations are stored in the class's `__annotations__` dict. For Python 2, you can specify it manually:
```python
from typing import List

import torch


class MyModule(torch.jit.ScriptModule):
    __annotations__ = {'my_float_list': List[float]}

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_int_list = [2]  # the type can be inferred since it has elements
        self.my_float_list = []  # the type is specified manually in __annotations__
        self.my_int = 2
```
Constants
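Constants are also currently declared in a non-standard way, by listing their names in `__constants__`. PEP 591 adds a `Final` type constructor; class members annotated as `Final` can be made into constants.
```python
from typing import Final  # PEP 591, Python 3.8+

import torch


class MyModule(torch.jit.ScriptModule):
    my_constant: Final[int]

    def __init__(self):
        super(MyModule, self).__init__()
        self.my_constant = 2
```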
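A minimal pure-Python sketch of how a compiler could find `Final` members, using only the standard `typing` helpers `get_origin` and `get_args`. The helper `collect_constants` and the class `PlainExample` are hypothetical; the type check only covers plain classes such as `int`.

```python
from typing import Any, Dict, Final, get_args, get_origin


def collect_constants(obj: Any) -> Dict[str, Any]:
    """Hypothetical: treat every member annotated Final[T] as a constant."""
    constants: Dict[str, Any] = {}
    for name, annotation in getattr(type(obj), "__annotations__", {}).items():
        if get_origin(annotation) is not Final:
            continue  # ordinary attribute, not a constant
        value = getattr(obj, name)
        (expected,) = get_args(annotation)  # the T in Final[T]
        if isinstance(expected, type) and not isinstance(value, expected):
            raise TypeError(f"constant {name!r} should be {expected.__name__}")
        constants[name] = value
    return constants


class PlainExample:
    my_constant: Final[int]
    my_counter: int

    def __init__(self):
        self.my_constant = 2
        self.my_counter = 0


print(collect_constants(PlainExample()))  # {'my_constant': 2}
```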