
profiler for PT2 can give wrong compilation frame ID #136235

@bdhirsh

Description

context here: https://github.com/pytorch/pytorch/pull/132765/files#r1764212475

repro:

```python
import torch

@torch.compile(backend="aot_eager", dynamic=True)
def f(x):
    if x.shape[0] > 5:
        return x.sin()
    else:
        return x.cos()

x1 = torch.randn(4)  # takes the cos branch
x2 = torch.randn(6)  # takes the sin branch
x3 = torch.randn(3)  # takes the cos branch again

with torch.profiler.profile(record_shapes=True) as prof:
    out1 = f(x1)
    out2 = f(x2)
    out3 = f(x3)

print([e.kwinputs['context'] for e in prof.events() if 'Compiled' in e.name])
```

This calls the compiled function 3 times but should only compile twice:
(1) the first call (size 4) compiles a new graph, 0/0
(2) the second call (size 6) fails the guards from (1) and recompiles, 0/1
(3) the third call (size 3) re-uses the graph from (1), so the profiler should show 0/0.
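A toy model of the expected dispatch may help. This is a sketch, not Dynamo's real cache: it assumes each compiled entry is stored together with the guard it was specialized under and a compile ID of the form "frame/recompile" (the format is inferred from the IDs in this issue). `CacheEntry`, `lookup`, and `run` are hypothetical names.

```python
# Toy model of guard-based cache dispatch; NOT torch._dynamo's implementation.
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class CacheEntry:
    compile_id: str                # "frame_id/recompile_index", e.g. "0/1"
    guard: Callable[[int], bool]   # guard on the input's size along dim 0


def lookup(cache: List[CacheEntry], size: int) -> Optional[CacheEntry]:
    # Return the first cached entry whose guard passes, or None (recompile).
    for entry in cache:
        if entry.guard(size):
            return entry
    return None


def run(sizes: List[int], frame_id: int = 0) -> List[str]:
    cache: List[CacheEntry] = []
    ids: List[str] = []
    for size in sizes:
        entry = lookup(cache, size)
        if entry is None:
            # Cache miss: "compile" a new entry guarded on the branch taken.
            took_sin_branch = size > 5
            entry = CacheEntry(
                compile_id=f"{frame_id}/{len(cache)}",
                guard=lambda s, b=took_sin_branch: (s > 5) == b,
            )
            cache.append(entry)
        ids.append(entry.compile_id)
    return ids


print(run([4, 6, 3]))  # ['0/0', '0/1', '0/0']
```

The third input hits the entry created for the first one, which is why 0/0 is the expected ID for `out3`.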

But when I print the profiler events, I get:

['0/0', '0/1', '0/1']

It looks like we set up the "compile id context" at the end of compilation. Shouldn't we instead set this context when a compiled artifact is invoked, so that we can look up that artifact's compile ID and set the context accordingly?
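The difference between the two approaches can be sketched with a hypothetical model (plain Python, no PyTorch). It assumes the profiler simply reads whatever compile ID is "current" when a compiled region runs; `current_compile_id` and `simulate` are illustrative names, not PyTorch APIs.

```python
# Hypothetical model of the reported bug and the proposed fix.
import contextvars
from typing import Dict, List

# Stand-in for the "compile id context" that the profiler reads.
current_compile_id = contextvars.ContextVar("current_compile_id", default="<none>")


def simulate(sizes: List[int], set_on_invocation: bool) -> List[str]:
    cache: Dict[bool, str] = {}  # branch taken -> compile ID of its graph
    recorded: List[str] = []
    for size in sizes:
        branch = size > 5
        if branch not in cache:
            compile_id = f"0/{len(cache)}"
            cache[branch] = compile_id
            if not set_on_invocation:
                # Current behavior: context set once, at the end of compilation,
                # and never updated when an existing graph is re-used.
                current_compile_id.set(compile_id)
        if set_on_invocation:
            # Proposed behavior: set the context from the artifact being invoked,
            # then restore the previous value afterwards.
            token = current_compile_id.set(cache[branch])
            recorded.append(current_compile_id.get())
            current_compile_id.reset(token)
        else:
            recorded.append(current_compile_id.get())
    return recorded


print(simulate([4, 6, 3], set_on_invocation=False))  # ['0/0', '0/1', '0/1']
print(simulate([4, 6, 3], set_on_invocation=True))   # ['0/0', '0/1', '0/0']
```

Setting the ID only at compile time leaves the last compiled ID in place, which matches the `['0/0', '0/1', '0/1']` output above. Setting it per invocation reports the ID of the graph that actually ran.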

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @robieta @chaekit @aaronenyeshi @guotuofeng @guyang3532 @dzhulgakov @davidberard98 @briancoutinho @sraikund16 @sanrise @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @amjames @rec

Metadata

Labels: high priority, module: dynamo, oncall: profiler, oncall: pt2, triaged
