Description
repro:
```python
import torch

@torch.compile(backend="aot_eager", dynamic=True)
def f(x):
    if x.shape[0] > 5:
        return x.sin()
    else:
        return x.cos()

x1 = torch.randn(4)
x2 = torch.randn(6)
x3 = torch.randn(3)

with torch.profiler.profile(record_shapes=True) as prof:
    out1 = f(x1)
    out2 = f(x2)
    out3 = f(x3)

# Print the compile-id context attached to each compiled-region event.
print([e.kwinputs['context'] for e in prof.events() if 'Compiled' in e.name])
```
This calls the compiled function 3 times, and we should end up with:
(1) the first call needs to compile, giving compile id 0/0;
(2) the second call (size 6 takes the other branch) needs to recompile, giving 0/1;
(3) the third call re-uses the same graph as (1), so the profiler should show 0/0.
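To make the expected sequence concrete, here is a toy model of guard-based dispatch. It is a hypothetical simplification, not Dynamo's actual cache: each entry stores a guard derived from the `x.shape[0] > 5` branch together with the compile id it was created under, and a cache miss appends a new entry numbered `frame_id/entry_index`.

```python
# Toy model (hypothetical, not Dynamo's real cache): each cache entry pairs
# a guard on the input length with the compile id it was created under.
from typing import Callable, List, Tuple

Entry = Tuple[Callable[[int], bool], str]


def dispatch(sizes: List[int]) -> List[str]:
    frame_id = 0
    cache: List[Entry] = []
    ids: List[str] = []
    for n in sizes:
        # Look for an existing entry whose guard accepts this input size.
        hit = next((cid for guard, cid in cache if guard(n)), None)
        if hit is None:
            # Cache miss: "compile" a new entry. The branch on x.shape[0] > 5
            # becomes the guard, so the entry covers one side of the branch.
            took_sin = n > 5
            guard = lambda m, s=took_sin: (m > 5) == s
            hit = f"{frame_id}/{len(cache)}"
            cache.append((guard, hit))
        ids.append(hit)
    return ids


print(dispatch([4, 6, 3]))  # ['0/0', '0/1', '0/0']
```

The third input (size 3) satisfies the guard of the first entry, so a correct per-invocation context would report `0/0` for it.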
But when I print the profiler events, I get:
['0/0', '0/1', '0/1']
It looks like we set up the "compile id context" at the end of compile time here. Instead, shouldn't we set this context when a compiled artifact is invoked, so that we can look up that artifact's compile id and set the context accordingly?
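A minimal sketch of the difference between the two approaches, using a hypothetical toy model rather than the actual Dynamo/AOTAutograd code. In the first variant a single "current compile id" is overwritten whenever a compile finishes and every call reports it; in the second, each compiled artifact captures its own id when it is built and reports that id when invoked. `ToyCompiledFn`, `run`, and the `events` list (a stand-in for the profiler's per-call context) are all illustrative names.

```python
from typing import Callable, Dict, List, Optional


class ToyCompiledFn:
    """Hypothetical toy model of a compiled function with a recompile cache.

    Not Dynamo's real implementation: the cache key is simply which side of
    the x.shape[0] > 5 branch the input takes, and `events` stands in for the
    profiler's per-call "context" records.
    """

    def __init__(self, per_artifact_context: bool) -> None:
        self.per_artifact_context = per_artifact_context
        self.cache: Dict[bool, Callable[[], None]] = {}
        self.current_compile_id: Optional[str] = None
        self.events: List[Optional[str]] = []

    def _compile(self, compile_id: str) -> Callable[[], None]:
        # Context recorded once, at the end of compilation (reported behaviour).
        self.current_compile_id = compile_id

        def artifact() -> None:
            if self.per_artifact_context:
                # Proposed: use the id captured when this artifact was compiled.
                self.events.append(compile_id)
            else:
                # Reported: use whichever compile finished most recently.
                self.events.append(self.current_compile_id)

        return artifact

    def __call__(self, n: int) -> None:
        key = n > 5
        if key not in self.cache:
            # Cache miss: compile a new artifact with the next frame compile id.
            self.cache[key] = self._compile(f"0/{len(self.cache)}")
        self.cache[key]()


def run(per_artifact_context: bool) -> List[Optional[str]]:
    f = ToyCompiledFn(per_artifact_context)
    for n in (4, 6, 3):  # the input sizes from the repro
        f(n)
    return f.events


print(run(per_artifact_context=False))  # ['0/0', '0/1', '0/1']
print(run(per_artifact_context=True))   # ['0/0', '0/1', '0/0']
```

The first variant reproduces the pattern in the observed output; the second gives the expected one. What a real artifact would capture (a `CompileId` or similar) is an assumption of this sketch.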
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @robieta @chaekit @aaronenyeshi @guotuofeng @guyang3532 @dzhulgakov @davidberard98 @briancoutinho @sraikund16 @sanrise @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @rec