Skip to content

[POC] Convolution-relu fusion via torch function and one-step lazy evaluation #65753

@ezyang

Description

@ezyang

This code sample shows how to use torch function to do automatic eager-mode convolution-relu fusion with "one step" of lazy evaluation. It needs a lot of improvement, and there are some internal APIs we could make better, but this is a use case for torch function/dispatch that we ought to be paying attention to.

import torch

# TODO: tree_map is pretty slow; once you figure out what API you want,
# there are easy optimizations available
from torch.utils._pytree import tree_map

# TODO: Obviously you will want to generalize this to work for arbitrary
# operators
class DelayedConv2d:
    """A ``torch.conv2d`` call whose evaluation is deferred by one step.

    The convolution runs only when some consumer needs its result.  If that
    consumer is an out-of-place relu, the conv + relu pair is reported as a
    fusion opportunity; any other consumer simply forces the convolution.

    Attributes:
        args, kwargs: positional / keyword arguments for ``torch.conv2d``;
            reset to ``None`` once the convolution has been run.
        val: cached convolution output, or ``None`` until ``force()`` runs.
    """

    def __init__(self, args, kwargs):
        self.args = args
        self.kwargs = kwargs
        self.val = None

    def force(self):
        """Run the delayed ``torch.conv2d`` (at most once) and return its output."""
        if self.val is None:
            self.val = torch.conv2d(*self.args, **self.kwargs)
            # The result is cached, so the stored arguments are no longer needed.
            self.args = None
            self.kwargs = None
        return self.val

    # NB: Using __torch_function__ here means that if you want to identify
    # quantization fusions in backwards, this is not possible.  I don't
    # think you need it.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept torch calls that receive a DelayedConv2d argument.

        An out-of-place relu (``torch.relu``, or ``torch.nn.functional.relu``
        with ``inplace=False``) applied directly to a DelayedConv2d is the
        fusion point.  Every other call forces all DelayedConv2d arguments
        and runs ``func`` eagerly on the resulting tensors.
        """
        if kwargs is None:
            kwargs = {}

        # In-place relu must mutate the materialized conv output, so it is not
        # treated as a fusion point and takes the generic path below.
        is_relu = func is torch.relu or (
            func is torch.nn.functional.relu and not kwargs.get("inplace", False)
        )
        if is_relu:
            # relu's tensor may be passed positionally or as ``input=``.
            target = args[0] if args else kwargs.get("input")
            if isinstance(target, cls):
                if target.val is not None:
                    return torch.relu(target.val)
                print("fusion is possible here!")
                return torch.relu(target.force())

        # No fusion possible: materialize every delayed conv and run eagerly.
        def unwrap(t):
            return t.force() if isinstance(t, DelayedConv2d) else t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

# TODO: Tracer should probably be a mode; this would avoid having to deal with
# wrapping/unwrapping tracers as happens here.  Another possibility is for this
# to be a tensor subclass (which means that methods like tracer.relu() would work)
class Tracer:
    """Wrap a tensor so torch functions called on it go through ``__torch_function__``.

    ``torch.conv2d`` calls are turned into a ``DelayedConv2d`` (one step of
    lazy evaluation).  Every other call runs eagerly on the unwrapped
    tensors, and tensor results are re-wrapped so tracing continues.

    Attributes:
        val: the wrapped value (normally a ``torch.Tensor``).
    """

    def __init__(self, val):
        self.val = val

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def wrap(t):
            # Only tensors are traced; ints, bools, sizes, None, ... are
            # returned unchanged so callers can use them directly.
            return Tracer(t) if isinstance(t, torch.Tensor) else t

        def unwrap(t):
            return t.val if isinstance(t, Tracer) else t

        plain_args = tree_map(unwrap, args)
        plain_kwargs = tree_map(unwrap, kwargs)

        if func is torch.conv2d:
            return DelayedConv2d(plain_args, plain_kwargs)
        return tree_map(wrap, func(*plain_args, **plain_kwargs))

# Demo: trace an input through conv2d -> relu.
filters = torch.randn(8, 4, 3, 3)
inputs = Tracer(torch.randn(1, 4, 5, 5))

# torch.nn.functional.conv2d is an alias of torch.conv2d; with a Tracer input,
# Tracer.__torch_function__ intercepts it and returns a DelayedConv2d.
result = torch.conv2d(inputs, filters, padding=1)

# relu on the DelayedConv2d is where DelayedConv2d reports a fusion opportunity.
print(torch.relu(result).shape)

cc @hameerabbasi @rgommers @peterbell10

Metadata

Metadata

Assignees

No one assigned

    Labels

    module: __torch_function__, module: lazy, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions