
Allow partial fallback when frame recompiles fail or exceed the cache size limit #135458

@jerrychenhf

Description


🚀 The feature, motivation and pitch

Currently, if the number of recompiles for a frame exceeds torch._dynamo.config.cache_size_limit, the recompilation fails and the frame is marked as SKIP_CODE, which abandons all the existing compiled cache entries on the frame's extra state. All subsequent calls to the frame then run in eager mode, regardless of whether a matching compiled entry already existed.

In large models and complex environments, unless cache_size_limit can be set high enough to cover every case, this handling produces an all-or-nothing outcome.
For example, suppose a frame actually needs 9 recompiles but cache_size_limit is, for whatever reason, set to 8. The frame then executes entirely in eager mode. A better outcome would be to run the 8 already-compiled cases with their compiled graphs and fall back to eager only for the last case. Tuning cache_size_limit so that all 9 cases compile is of course the way to get the best performance here, but given how complex real workloads are, it is still very valuable to be able to run a frame partially compiled and partially in eager (eager only for calls that do not match an existing cache entry, with no recompile attempted for those calls).

So we suggest an improvement to the frame evaluation design to support this.
Basically, there could be a setting that controls whether, when the recompile limit is hit or a recompile fails, the whole frame is marked SKIP_CODE or keeps running partially compiled (similar to the existing run-only mode).

To achieve partial execution, instead of abandoning the cache entries and marking the frame as SKIP_CODE, we could set a flag in the frame's ExtraState data that marks the frame as RUN ONLY MODE.
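
A minimal sketch of the proposed lookup order, written as plain Python for illustration only: the real change would live in Dynamo's frame evaluation and ExtraState handling in C, and names such as run_only and partial_fallback_on_limit are hypothetical, not existing options.

class ExtraState:
    def __init__(self, cache_size_limit):
        self.cache_entries = {}          # guard key -> compiled callable
        self.cache_size_limit = cache_size_limit
        self.run_only = False            # proposed flag, used instead of SKIP_CODE

def eval_frame(extra, guard_key, compile_fn, eager_fn, partial_fallback_on_limit=True):
    # Cache hit: always run the existing compiled entry.
    if guard_key in extra.cache_entries:
        return extra.cache_entries[guard_key]()
    # Run-only mode: never recompile, fall back to eager for misses.
    if extra.run_only:
        return eager_fn()
    # Cache miss within the limit: compile a new entry as today.
    if len(extra.cache_entries) < extra.cache_size_limit:
        extra.cache_entries[guard_key] = compile_fn()
        return extra.cache_entries[guard_key]()
    # Limit exceeded: instead of dropping the cache and marking the frame
    # SKIP_CODE, keep the entries and switch the frame to run-only mode.
    if partial_fallback_on_limit:
        extra.run_only = True
        return eager_fn()
    # Current behavior: abandon the cache entries and run eager from now on.
    extra.cache_entries.clear()
    return eager_fn()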

Alternatives

No response

Additional context

The following simple case demonstrates the use case.
f needs 3 recompiles but the cache size limit is set to 2 (just for the demo). With such a setting, ideally y = f(x, g1) and y = f(x, g2) would run with compiled graphs while only y = f(x, g3) falls back to eager.

import torch
import torch._dynamo
import torch.utils.checkpoint
from torch._functorch.aot_autograd import aot_module_simplified

torch._dynamo.config.cache_size_limit = 2
torch._dynamo.config.accumulated_cache_size_limit = 512

class MyGraphModule(torch.nn.Module):
    def __init__(self, graph_module):
        super().__init__()
        self._graph_module = graph_module

    def forward(self, *args):
        print("Calling MyGraphModule", id(self))
        return self._graph_module(*args)

def my_backend(gm, sample_inputs):
    def my_compiler(gm, sample_inputs):
        print("Compiling graph:", gm.code)
        mygm = MyGraphModule(gm)
        return mygm.forward

    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=my_compiler
    )


def f(x, g):
    x = x * x
    x = x + 1
    out = torch.utils.checkpoint.checkpoint(g, x)
    return out

def g1(x):
    w = x.sin()
    z = w.sin()
    return z

def g2(x):
    w = x.sin()
    z = w.sin()
    return z

def g3(x):
    w = x.sin()
    z = w.sin()
    return z


f = torch._dynamo.optimize(backend=my_backend)(f)

x = torch.ones(2, requires_grad=True)

print("Calling g1", f)
y = f(x, g1)
print("Calling g2", f)
y = f(x, g2)
print("Calling g3", f)
y = f(x, g3)
print("Calling g1", f)
y = f(x, g1)
print("Calling g2", f)
y = f(x, g2)
print("Calling g3", f)
y = f(x, g3)
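
With the current behavior, the third call (y = f(x, g3)) exceeds cache_size_limit, the frame is marked SKIP_CODE, and every call after that, including the repeated g1 and g2 calls, runs in eager mode. Under the proposed partial fallback, the repeated g1 and g2 calls would still hit their existing cache entries and run the compiled graphs, and only the g3 calls would run eager.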

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec

Labels

module: dynamo, oncall: pt2, triaged
