-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Open
Labels
module: __torch_function__ · module: lazy · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
This code sample shows how to use `__torch_function__` to perform automatic eager-mode convolution–ReLU fusion with "one step" of lazy evaluation. It needs a lot of improvement, and there are some internal APIs we could make better, but this is a use case for `__torch_function__`/`__torch_dispatch__` that we ought to be paying attention to.
import torch
# TODO: tree_map is pretty slow, once you figure out what API you want
# there are easy optimizations available
from torch.utils._pytree import tree_map
# TODO: Obviously you will want to generalize this to work for arbitrary
# operators
class DelayedConv2d:
    """A ``torch.conv2d`` call whose execution is postponed by one step.

    The positional and keyword arguments of the convolution are stored and
    only evaluated when the result is actually needed (see ``force``).
    Because the convolution has not run yet when the next torch operation
    is applied to this object, ``__torch_function__`` can see a
    conv2d -> relu pair. That is the point where a fused conv+relu could be
    substituted. Currently it only prints a message and runs the two ops
    separately. Any other operation simply forces the convolution.
    """

    def __init__(self, args, kwargs):
        # ``args``/``kwargs`` are passed unchanged to torch.conv2d by force(),
        # so callers should hand over already-unwrapped values.
        self.args = args
        self.kwargs = kwargs
        self.val = None  # cached conv2d result once forced

    def force(self):
        """Run the delayed convolution (at most once) and return its result."""
        if self.val is None:
            self.val = torch.conv2d(*self.args, **self.kwargs)
            # Drop the argument references once they are no longer needed.
            self.args = None
            self.kwargs = None
        return self.val

    # NB: Using __torch_function__ here means that if you want to identify
    # quantization fusions in backwards, this is not possible. I don't
    # think you need it.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept torch calls made on a ``DelayedConv2d``.

        A ReLU applied to a not-yet-forced convolution is reported as a fusion
        opportunity. Every other call forces all ``DelayedConv2d`` arguments
        and re-dispatches ``func`` on the concrete values.
        """
        if kwargs is None:
            kwargs = {}

        # torch.nn.functional.relu reaches us as func=F.relu (with an explicit
        # ``inplace`` kwarg) rather than torch.relu, so recognise it as well.
        # In-place ReLU mutates its input and is left to the generic path.
        is_relu = func is torch.relu or (
            func is torch.nn.functional.relu and not kwargs.get("inplace", False)
        )
        if is_relu:
            # torch.relu may also be called as torch.relu(input=...).
            conv = args[0] if args else kwargs.get("input")
            if isinstance(conv, DelayedConv2d):
                if conv.val is not None:
                    # Already materialised by an earlier consumer: nothing to fuse.
                    return torch.relu(conv.val)
                print("fusion is possible here!")
                return torch.relu(conv.force())

        # No fusion possible: force every delayed conv and run func eagerly.
        def unwrap(t):
            return t.force() if isinstance(t, DelayedConv2d) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
# TODO: Tracer should probably be a mode; this would avoid having to deal with
# wrapping/unwrapping tracers as happens here. Other possibilities are for this
# to be a tensor subclass (which means that methods like tracer.relu() would work)
class Tracer:
    """Wrap a value and intercept the torch functions called on it.

    ``torch.conv2d`` is not executed immediately. It returns a
    ``DelayedConv2d`` so that the following operation can be inspected for
    a fusion opportunity. All other functions run eagerly on the unwrapped
    values, and tensor results are re-wrapped so tracing continues.
    """

    def __init__(self, val):
        self.val = val  # the wrapped value (a tensor in the demo below)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Unwrap ``Tracer`` arguments, run or delay ``func``, re-wrap tensors."""
        if kwargs is None:
            kwargs = {}

        def unwrap(t):
            return t.val if isinstance(t, Tracer) else t

        def wrap(t):
            # Only tensors are re-wrapped. Wrapping plain Python results
            # (ints, bools, torch.Size, None, ...) would hide them behind a Tracer.
            return Tracer(t) if isinstance(t, torch.Tensor) else t

        plain_args = tree_map(unwrap, args)
        plain_kwargs = tree_map(unwrap, kwargs)
        if func is torch.conv2d:
            # Delay the conv so the next op (e.g. relu) can see it.
            return DelayedConv2d(plain_args, plain_kwargs)
        return tree_map(wrap, func(*plain_args, **plain_kwargs))
# Demo: a traced conv2d followed by relu. The conv is delayed, so the relu
# sees it, reports the fusion opportunity, and then runs eagerly.
filters = torch.randn(8, 4, 3, 3)
inputs = Tracer(torch.randn(1, 4, 5, 5))
# torch.nn.functional.conv2d is the same builtin as torch.conv2d, so this call
# reaches Tracer.__torch_function__ with func=torch.conv2d and is delayed.
result = torch.nn.functional.conv2d(inputs, filters, padding=1)
# relu on a DelayedConv2d returns a plain tensor, so .size() works here;
# with padding=1 and 3x3 kernels this prints torch.Size([1, 8, 5, 5]).
print(torch.relu(result).size())
Metadata
Assignees
Labels
module: __torch_function__ · module: lazy · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)