🚀 The feature, motivation and pitch
Currently, if the number of recompiles for a frame exceeds torch._dynamo.config.cache_size_limit, recompilation fails and the frame is marked as SKIP_CODE, which abandons all existing compiled cache entries held in the frame's extra state. All subsequent calls to the frame then run in eager mode, regardless of whether a matching compiled entry previously existed.
In large models and complex environments, unless cache_size_limit can be set large enough to cover every case, this handling produces an all-or-nothing outcome.
For example, suppose a specific frame actually needs 9 recompiles but cache_size_limit is, for whatever reason, set to 8. The frame then runs entirely in eager mode. A better outcome would be to keep serving the 8 already-compiled cases from their compiled graphs and fall back to eager only for the last case. Tuning cache_size_limit so that all 9 cases compile is of course the way to get the best performance here, but given how complex real workloads are, it is still very valuable for usability to be able to run a frame partially in compiled graphs and partially in eager mode (eager only for calls that have no match in the cache entries, without attempting a recompile for them).
So we suggest an improvement to the frame evaluation design to support such handling.
Basically, we could have a setting that controls what happens when the recompile limit is hit or a recompilation fails: either mark the whole frame as SKIP_CODE (the current behavior), or keep running it partially compiled (similar to the existing run-only mode).
To achieve the partial behavior, instead of abandoning the cache entries and marking the frame as SKIP_CODE, we set a flag in the frame's ExtraState data that marks the frame as RUN ONLY MODE; a rough sketch of this decision point follows.
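Below is a minimal sketch of that decision point, assuming a hypothetical config flag (run_only_on_cache_limit_hit) and illustrative ExtraState helpers; none of these names are an existing torch._dynamo API, they only show the intended control flow.

def handle_cache_size_limit_exceeded(frame, extra_state, config):
    # Hypothetical sketch, not actual torch._dynamo code.
    if config.run_only_on_cache_limit_hit:  # proposed flag, illustrative name
        # Keep the already-compiled cache entries and stop compiling new ones:
        # calls that match an existing entry run the compiled graph, calls that
        # miss fall back to eager without attempting a recompile.
        extra_state.run_only_mode = True
    else:
        # Current behavior: drop all cache entries and skip the frame entirely,
        # so every subsequent call runs in eager mode.
        extra_state.invalidate_all_entries()
        mark_frame_skip_code(frame)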
Alternatives
No response
Additional context
The following simple case demonstrates the use case.
f recompiles 3 times, but the cache size limit is set to 2 (just for the demo). With such a setting, ideally we could run y = f(x, g1) and y = f(x, g2) with the compiled graphs while only y = f(x, g3) runs in eager.
import torch
import torch.utils.checkpoint  # needed for torch.utils.checkpoint.checkpoint below
from torch._functorch.aot_autograd import aot_module_simplified

# Deliberately small limit so the demo hits it on the third distinct call.
torch._dynamo.config.cache_size_limit = 2
torch._dynamo.config.accumulated_cache_size_limit = 512


class MyGraphModule(torch.nn.Module):
    def __init__(self, graph_module):
        super().__init__()
        self._graph_module = graph_module

    def forward(self, *args):
        print("Calling MyGraphModule", id(self))
        return self._graph_module(*args)


def my_backend(gm, sample_inputs):
    def my_compiler(gm, sample_inputs):
        print("Compiling graph:", gm.code)
        mygm = MyGraphModule(gm)
        return mygm.forward

    return aot_module_simplified(
        gm,
        sample_inputs,
        fw_compiler=my_compiler,
    )


def f(x, g):
    x = x * x
    x = x + 1
    out = torch.utils.checkpoint.checkpoint(g, x)
    return out


# g1, g2 and g3 are distinct functions, so each f(x, gN) variant guards on a
# different callable and triggers a recompile of f.
def g1(x):
    w = x.sin()
    z = w.sin()
    return z


def g2(x):
    w = x.sin()
    z = w.sin()
    return z


def g3(x):
    w = x.sin()
    z = w.sin()
    return z


f = torch._dynamo.optimize(backend=my_backend)(f)
x = torch.ones(2, requires_grad=True)

print("Calling g1", f)
y = f(x, g1)
print("Calling g2", f)
y = f(x, g2)
print("Calling g3", f)  # third variant exceeds cache_size_limit
y = f(x, g3)
print("Calling g1", f)
y = f(x, g1)
print("Calling g2", f)
y = f(x, g2)
print("Calling g3", f)
y = f(x, g3)
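For comparison, the run-only mode referenced above already exists as a user-facing, whole-call switch: torch._dynamo.run wraps a function so that previously compiled graphs are reused and no new compilation is triggered (to my understanding, calls with no cached entry simply run in eager). The difference in this proposal is that only the affected frame would be switched into such a mode, automatically, once cache_size_limit is exceeded.

# Existing global run-only mode (continuing the example above).
f_run_only = torch._dynamo.run(f)
y = f_run_only(x, g1)  # reuses a cached compiled graph if one matches this call
y = f_run_only(x, g3)  # no matching entry: runs eager, never attempts a recompile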
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames @rec