[compile][transformers] Recompilation with mark_static_address with cudagraphs #156377

@anijain2305

Description

🐛 Describe the bug

The recompilation happens because Dynamo installs an ID_MATCH guard on tensors marked with mark_static_address. I think this is done so that cudagraphs don't have to copy those tensors into a separate location before recording the cudagraph, and Dynamo reflects that fixed mapping by guarding on the tensor's identity.
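
Conceptually (a minimal sketch, not Dynamo's actual guard machinery), an ID_MATCH guard behaves like an identity check on the Python object, so a different tensor fails the guard even when shape and dtype match:

import torch

static_y1 = torch.randn(4)
static_y2 = torch.randn(4)

# Hypothetical stand-in for Dynamo's ID_MATCH guard: record the object
# identity at compile time, then require the exact same object later.
recorded_id = id(static_y1)

def id_match_guard(t: torch.Tensor) -> bool:
    return id(t) == recorded_id

assert id_match_guard(static_y1)      # same object: guard passes
assert not id_match_guard(static_y2)  # new object: guard fails -> recompile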

This makes sense, but I was wondering whether we can avoid the recompilation. I think we already have logic that allows multiple cudagraphs for the same inductor graph. Is it possible to reuse that here? Basically, re-record the cudagraph if the address changes, as in the sketch below.
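
To sketch the idea (hypothetical code, not the actual cudagraph-tree implementation): instead of guarding on object identity in Dynamo, the runtime could key recordings on the data pointers of the static inputs and re-record on a miss, so a new static tensor costs one re-recording rather than a full recompile:

import torch
from typing import Callable, Dict, Tuple

class AddressKeyedRunner:
    # Hypothetical dispatcher: one compiled inductor graph, one cudagraph
    # recording per distinct set of static-input data pointers.
    def __init__(self, compiled_fn: Callable):
        self.compiled_fn = compiled_fn
        self.recordings: Dict[Tuple[int, ...], Callable] = {}

    def __call__(self, x: torch.Tensor, *static_inputs: torch.Tensor):
        key = tuple(t.data_ptr() for t in static_inputs)
        if key not in self.recordings:
            # New addresses: re-record against these buffers instead of
            # recompiling. (Stand-in for actual CUDA graph capture.)
            self.recordings[key] = self.compiled_fn
        return self.recordings[key](x, *static_inputs)

runner = AddressKeyedRunner(lambda x, y: x + y)
x = torch.randn(4)
y1, y2 = torch.randn(4), torch.randn(4)
runner(x, y1)  # records for y1's address
runner(x, y2)  # re-records for y2's address; no Dynamo recompile needed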

This would hugely benefit transformers models, where we could move compile to a smaller region and get almost as good a speedup as compiling the full model, with very small compile time.

import torch

# Set the recompile limit to 1 so the recompilation surfaces as an error
torch._dynamo.config.recompile_limit = 1

@torch.compile(mode="reduce-overhead", fullgraph=True)
def fn(x, static_y):
    return x + static_y

x = torch.randn(4)
static_y1 = torch.randn(4)
static_y2 = torch.randn(4)
torch._dynamo.mark_static_address(static_y1)
torch._dynamo.mark_static_address(static_y2)

fn(x, static_y1)  # first call: compiles and guards on id(static_y1)
fn(x, static_y2)  # different object/address: ID_MATCH fails, recompilation exceeds the limit
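
To see which guard fails, recompilation logging can be enabled before the second call; torch._logging.set_logs is the standard API for this, though the exact log text varies across versions:

import torch

# Print the reason for every recompilation; with the repro above, the
# failure should point at an ID_MATCH (identity) guard on static_y.
torch._logging.set_logs(recompiles=True)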

Error logs

No response

Versions

N/A

cc @mcarilli @ezyang @eellison @penguinwu @BoyuanFeng @chauhang

Metadata

Labels

module: cuda graphs (ability to capture and then replay streams of CUDA kernels)
oncall: pt2
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
