[jit] Changes to TorchScript API #20939

@driazati

Description

TLDR

For TorchScript we want to make script mode as easy to use as tracing, so we propose changing the script API to mirror the tracing API and reduce the code changes needed to start using TorchScript.

a_scripted_module = torch.jit.script(an_nn_module)
  • forward is compiled by default, methods are lazily compiled and added to a_scripted_module as they are called from forward.
    • To compile methods of a_scripted_module that forward never calls (i.e. extra entry points), decorate them with @torch.jit.export
    • To stop the compiler from trying to compile a method, use @torch.jit.ignore
  • torch.jit.Attribute is removed and attributes are added automatically
    • Type annotations may be necessary

How to use changes

  • You shouldn't have to change the original model code to make it compile, aside from marking constant members Final and making the code TorchScript compatible
  • To make a ScriptModule, pass the nn.Module to torch.jit.script (e.g. torch.jit.script(torchvision.models.resnet18()))
  • Add variables as torch.jit.Final[T] class annotations instead of adding their names to __constants__
  • Attribute type inference works for all types except Tensors and types containing Tensors (support for these will be added soon). Attributes whose types can't be inferred should have their types added as class annotations.
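The workflow above can be sketched end to end. This is a minimal example assuming the proposed API, with a hypothetical MyCell module; torch.jit.Final is the spelling the constant annotation ended up with in released PyTorch:

```python
from typing import List

import torch
import torch.nn as nn


class MyCell(nn.Module):  # hypothetical example module
    scale: torch.jit.Final[int]  # replaces __constants__ = ['scale']
    history: List[int]           # empty in __init__, so annotated here

    def __init__(self):
        super().__init__()
        self.scale = 2
        self.history = []

    def forward(self, x):
        return x * self.scale


# No changes to MyCell itself are needed to compile it
scripted = torch.jit.script(MyCell())
print(scripted(torch.ones(2)))  # tensor([2., 2.])
```

Note that the class is unchanged: the only additions are the two annotations, which are valid plain Python as well.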

TorchScript API

Currently to convert a nn.Module to a torch.jit.ScriptModule, the nn.Module must be altered to inherit from ScriptModule and @torch.jit.script_method added to all the methods that should be included in compilation. This creates a lot of pressure to modify the original module code in a manner that is difficult to undo.

Modules

  • Allow torch.jit.script to take an nn.Module and produce a ScriptModule, or @torch.jit.script on an nn.Module, similar to how functions work today (during development we'd name this something else to not break existing code and allow incremental switching to the new APIs)
  • Assume forward is scriptable; compile it and any other methods it calls
  • Deprecate @torch.jit.script_method
  • To specify an entry point other than forward, use the @torch.jit.export decorator (the functionality of export is the same as script_method)
  • To leave functions as Python, use @torch.jit.ignore
    • For Python functions that should be removed on export and replaced with an erroring node (for code that is never called, e.g. for training), use @torch.jit.ignore(drop_on_export=True)
    • @torch.jit.ignore will now default to the same behavior as @torch.jit.ignore(drop_on_export=False), which will throw an error if you try to torch.jit.save() it
  • Provide APIs in torch.jit.script to specify what is ignored/exported in case the original module code can't be modified

The comments in forward() below describe what happens during compilation (e.g. when someone calls my_script_module = torch.jit.script(MyModule())). Most of these are "advanced features"; in typical usage the original module code won't need to be modified at all.

class Submodule(nn.Module):
  def __init__(self):
    super(Submodule, self).__init__()

  def forward(self, x):
    return x + 2

class MyModule(nn.Module):
  def __init__(self):
    super(MyModule, self).__init__()
    self.submodule1 = Submodule()
    self.submodule2 = torch.jit.script(Submodule())

  @torch.jit.ignore
  def some_debugging_function(self, x):
    import pdb
    pdb.set_trace()

  @torch.jit.ignore(drop_on_export=True)
  def training_only_code(self, x):
    import numpy
    # some non-scriptable training code
    return x

  @torch.jit.export
  def an_explicit_entry_point(self, x):
    return self.forward(x + 20)

  # not exported; compiled lazily the first time it is called
  def a_called_model_entry_point(self, x):
    return self.forward(x + 20)

  def some_function(self, y):
    return y + 25

  def forward(self, x):
    # Compiles submodule1 into a ScriptModule and registers
    # it on the module that is being created
    x += self.submodule1(x)

    # submodule2 is already a ScriptModule
    x += self.submodule2(x)

    # this compiles some_function and adds it to the ScriptModule
    x += self.some_function(x)

    if self.training:
      x += self.training_only_code(x)

    # If this line is not removed, torch.jit.script(MyModule()).save()
    # will raise an error
    self.some_debugging_function(x)
    return x

scripted_module = torch.jit.script(MyModule())

# this saves functions 'forward' and 'an_explicit_entry_point'
# and both submodules
scripted_module.save("out.pt")

# This will compile this function and recursively compile
# anything it needs
scripted_module.a_called_model_entry_point(torch.ones(2, 2))

# this is the same as before but now also includes
# 'a_called_model_entry_point'
scripted_module.save("out.pt")
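The saved archive carries the compiled code, so it can be loaded back without the original Python class. A minimal sketch using the existing torch.jit.load API (hypothetical Doubler module, saved to an in-memory buffer instead of a file):

```python
import io

import torch
import torch.nn as nn


class Doubler(nn.Module):  # hypothetical minimal module
    def forward(self, x):
        return x * 2


# Save the compiled module to an in-memory buffer
buf = io.BytesIO()
torch.jit.save(torch.jit.script(Doubler()), buf)
buf.seek(0)

# Loading needs no reference to the Doubler class at all
loaded = torch.jit.load(buf)
print(loaded(torch.ones(2)))  # tensor([2., 2.])
```

The same archive can also be loaded from C++ via torch::jit::load.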

Functions

Functions don't change much; the @ignore and @export decorators can be applied to them without raising an error (though doing so doesn't make much sense).

# same behavior as before
@torch.jit.script
def some_fn():
    return 2

# just marks a function as ignored, if nothing
# ever calls it then this has no effect
@torch.jit.ignore
def some_fn2():
    return 2

# doesn't do anything, this function is already 
# the main entry point
@torch.jit.export
def some_fn3():
    return 2

Classes

Everything in a user-defined class is exported by default; methods can be ignored if needed.
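A minimal sketch of a TorchScript class (Pair is a hypothetical example): every method is compiled, and the class can be used from other scripted code:

```python
import torch


@torch.jit.script  # compiles every method of the class
class Pair:  # hypothetical example class
    def __init__(self, a: int, b: int):
        self.a = a
        self.b = b

    def total(self) -> int:
        return self.a + self.b


@torch.jit.script
def use_pair(x: int) -> int:
    # scripted code can construct and use the class directly
    return Pair(x, x + 1).total()
```

Type annotations on the methods are required here, since there is no example input to infer them from.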

Attributes

Currently attributes are added as below:

class MyModule(torch.jit.ScriptModule):
    def __init__(self):
        super(MyModule, self).__init__()
        self.my_int_list = torch.jit.Attribute([2], List[int])
        self.my_float_list = torch.jit.Attribute([], List[float])
        self.my_int = torch.jit.Attribute(2, int)

torch.jit.Attribute() is cumbersome and non-standard. This should be changed to infer attribute types when possible and use PEP 526-style annotations when not possible.

class MyModule(torch.jit.ScriptModule):
  my_int_list: List[int]

  def __init__(self):
    self.my_int_list = []  # empty, so the type comes from the class annotation
    self.my_float_list: List[float] = []  # the type is specified inline
    self.my_int = 2  # the type can be inferred from the value

As in the PEP, this adds to the class's __annotations__ attribute. For Python 2, where annotation syntax isn't available, you can populate __annotations__ manually.

class MyModule(torch.jit.ScriptModule):
  __annotations__ = {'my_float_list': List[float]}

  def __init__(self):
    self.my_int_list = [2]  # the type can be inferred since it has elements
    self.my_float_list = []  # the type comes from __annotations__ above
    self.my_int = 2
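A plain-Python sketch (no torch required) showing that both spellings populate the same __annotations__ mapping, which is where the compiler looks; Annotated and Manual are hypothetical names:

```python
from typing import List


class Annotated:  # PEP 526 style (Python 3)
    my_float_list: List[float]


class Manual:  # Python 2 style
    __annotations__ = {'my_float_list': List[float]}


# Both end up with the same class-level annotation mapping
print(Annotated.__annotations__ == Manual.__annotations__)  # True
```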

Constants

Constants are also currently non-standard. PEP 591 adds a Final type constructor. Class members that are Final can be made into constants.

class MyModule(torch.jit.ScriptModule):
  my_constant: Final[int]

  def __init__(self):
    self.my_constant = 2
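A sketch of the Final-based spelling in use (WithConstant is a hypothetical module; typing.Final needs Python 3.8+, and torch.jit.Final works as well). Because the member is a constant, the compiler can resolve branches on it at compile time:

```python
from typing import Final

import torch
import torch.nn as nn


class WithConstant(nn.Module):  # hypothetical example module
    use_relu: Final[bool]

    def __init__(self, use_relu: bool):
        super().__init__()
        self.use_relu = use_relu

    def forward(self, x):
        # use_relu is Final, so this branch can be specialized away
        if self.use_relu:
            return torch.relu(x)
        return x


m = torch.jit.script(WithConstant(True))
print(m(torch.tensor([-1.0, 1.0])))  # tensor([0., 1.])
```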
