Skip to content

Conversation

@xmfan
Copy link
Member

@xmfan xmfan commented Aug 9, 2024

Stack from ghstack (oldest at bottom):

Partially addresses #132939. Adds the AOT ID after the CompiledFunctionBackward annotation in verbose compiled autograd logging

default (no change):
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_xw3ktsi_.log/index.html

TORCH_LOGS="compiled_autograd_verbose":
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_gsc9q_43.log/index.html

# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward1 (NodeCall 2)
clone: "f32[4]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
cos: "f32[4]" = torch.ops.aten.cos.default(getitem_1);  getitem_1 = None
mul: "f32[4]" = torch.ops.aten.mul.Tensor(clone, cos);  clone = cos = None

# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 3)
cos_1: "f32[4]" = torch.ops.aten.cos.default(getitem_2)
mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

@pytorch-bot
Copy link

pytorch-bot bot commented Aug 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/133115

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 7c48ddd with merge base e7b870c (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Aug 9, 2024
Partially addresses #132939

default (no change):
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_xw3ktsi_.log/index.html

TORCH_LOGS="compiled_autograd_verbose":
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_gsc9q_43.log/index.html

```python
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward1 (NodeCall 2)
clone: "f32[4]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
cos: "f32[4]" = torch.ops.aten.cos.default(getitem_1);  getitem_1 = None
mul: "f32[4]" = torch.ops.aten.mul.Tensor(clone, cos);  clone = cos = None

# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 3)
cos_1: "f32[4]" = torch.ops.aten.cos.default(getitem_2)
mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Aug 9, 2024
@xmfan xmfan marked this pull request as ready for review August 9, 2024 20:15
@xmfan xmfan requested review from bdhirsh and jansel August 9, 2024 20:15
yield model[2].bias.grad

logs, ctx = logs_to_string(
torch._dynamo.compiled_autograd.__name__, "compiled_autograd_verbose"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we always enable it instead of only under compiled_autograd_verbose? I believe this feature would be very useful for debugging internal jobs in the future, and if this aot id logging is on by default then we don't need to ask user to go back and run the job again with verbose enabled. (Or alternatively we can also ask user to always turn on compiled_autograd_verbose)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah this one is not on the hot path so we can do that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also just thinking:

(1) do compiled autograd logs gets turned on for most jobs internally?

(2) is there anything else we can put into tlparse from compiled autograd that we aren't right now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for 1, tlparse seems to have non-verbose logs by default

2, there's probably a bunch of cross-artifacts analysis we can do in tlparse. but i feel like adding logs to the compiled autograd artifacts is sufficient for the problems we're seeing rn

Copy link
Contributor

@yf225 yf225 Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdhirsh another issue we saw is that when going from CA Dynamo graph to AOTAutograd joint graph, the source of each FX node in AOTAutograd joint graph is lost (i.e. we don't know which Dynamo graph FX node it comes from anymore). Curious is there a way to preserve this information?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm do you have an example, or a repro you can file?

I tried running the repro below, and when I the logs from TORCH_LOGS="aot,+compiled_autograd", it seems like the stacktraces are lost when compiled autograd first gets a graph, before it sends it to AOTDispatcher. log output: P1528486054

code:

import torch


@torch.compile
def f(x):
    return torch.matmul(x, x).sin()

x = torch.randn(4, 4, requires_grad=True)
with torch._dynamo.utils.maybe_enable_compiled_autograd(True):
    out = f(x)
    out.sum().backward()

Copy link
Member Author

@xmfan xmfan Aug 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to use compiled_autograd_verbose instead of +compiled_autograd (it's a separate artifact since artifacts don't support verbosity levels)

e.g. P1528511564

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed here #133567, I'll take a look

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to fix it #133574

Partially addresses #132939. Adds the AOT ID after the CompiledFunctionBackward annotation in verbose compiled autograd logging

default (no change):
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_xw3ktsi_.log/index.html

TORCH_LOGS="compiled_autograd_verbose":
https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp8WCSLf/dedicated_log_torch_trace_gsc9q_43.log/index.html

```python
# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward1 (NodeCall 2)
clone: "f32[4]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
cos: "f32[4]" = torch.ops.aten.cos.default(getitem_1);  getitem_1 = None
mul: "f32[4]" = torch.ops.aten.mul.Tensor(clone, cos);  clone = cos = None

# File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:361 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 3)
cos_1: "f32[4]" = torch.ops.aten.cos.default(getitem_2)
mul_1: "f32[4]" = torch.ops.aten.mul.Tensor(mul, cos_1);  mul = cos_1 = None
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this pull request Aug 17, 2024
FIXES #132939

Compiled autograd's trace of the AOT backward may result in some additional ops e.g. clone to make contiguous, trace_wrapped HOPs, so the graphs may be slightly offset from each other

hf_Whisper example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNv89Pu/index.html
fsdp2 example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpPdKssS/rank_0/index.html
Unit test example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpvoQsnl/index.html
```python
 ===== Compiled autograd graph =====
 <eval_with_key>.14 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        # No stacktrace found for following nodes
        getitem: "f32[]cpu" = inputs[0]
        aot1_primals_1: "f32[4]cpu" = inputs[1]
        aot1_primals_2: "f32[4]cpu" = inputs[2]
        aot0_sin: "f32[4]cpu" = inputs[3]
        aot0_cos: "f32[4]cpu" = inputs[4]
        getitem_5: "f32[4]cpu" = inputs[5];  inputs = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: SumBackward0 (NodeCall 1)
        expand: "f32[4]cpu" = torch.ops.aten.expand.default(getitem, [4]);  getitem = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: CompiledFunctionBackward1 (NodeCall 2)
        aot1_tangents_1: "f32[4]cpu" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        aot1_sin_1: "f32[4]cpu" = torch.ops.aten.sin.default(aot1_primals_2);  aot1_primals_2 = None
        aot1_neg: "f32[4]cpu" = torch.ops.aten.neg.default(aot1_sin_1);  aot1_sin_1 = None
        aot0_tangents_2: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot1_tangents_1, aot1_neg);  aot1_neg = None
        aot1_cos_1: "f32[4]cpu" = torch.ops.aten.cos.default(aot1_primals_1);  aot1_primals_1 = None
        aot0_tangents_1: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot1_tangents_1, aot1_cos_1);  aot1_tangents_1 = aot1_cos_1 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 3)
        aot0_neg: "f32[4]cpu" = torch.ops.aten.neg.default(aot0_sin);  aot0_sin = None
        aot0_mul: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_neg);  aot0_tangents_2 = aot0_neg = None
        aot0_mul_1: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_cos);  aot0_tangents_1 = aot0_cos = None
        aot0_add: "f32[4]cpu" = torch.ops.aten.add.Tensor(aot0_mul, aot0_mul_1);  aot0_mul = aot0_mul_1 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 4)
        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_5, aot0_add);  getitem_5 = aot0_add = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub();  _exec_final_callbacks_stub = None
        return []
```

where aot1 is
```python
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4][1]cpu", primals_2: "f32[4][1]cpu", tangents_1: "f32[4][1]cpu"):
         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2233 in torch_dynamo_resume_in_f_at_2232, code: return tmp1.sin() + tmp2.cos()
        sin_1: "f32[4][1]cpu" = torch.ops.aten.sin.default(primals_2);  primals_2 = None
        neg: "f32[4][1]cpu" = torch.ops.aten.neg.default(sin_1);  sin_1 = None
        mul: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, neg);  neg = None
        cos_1: "f32[4][1]cpu" = torch.ops.aten.cos.default(primals_1);  primals_1 = None
        mul_1: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, cos_1);  tangents_1 = cos_1 = None
        return (mul_1, mul)
```

and aot0 is
```python
class GraphModule(torch.nn.Module):
    def forward(self, sin: "f32[4][1]cpu", cos: "f32[4][1]cpu", tangents_1: "f32[4][1]cpu", tangents_2: "f32[4][1]cpu"):
         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2231 in f, code: tmp2 = x.cos()
        neg: "f32[4][1]cpu" = torch.ops.aten.neg.default(sin);  sin = None
        mul: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_2, neg);  tangents_2 = neg = None

         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2230 in f, code: tmp1 = x.sin()
        mul_1: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2230 in f, code: tmp1 = x.sin()
        add: "f32[4][1]cpu" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        return (add,)
```

Pull Request resolved: #133148
Approved by: https://github.com/jansel
ghstack dependencies: #133115
@github-actions github-actions bot deleted the gh/xmfan/75/head branch September 17, 2024 01:57
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 26, 2024
…ch#133148)

FIXES pytorch#132939

Compiled autograd's trace of the AOT backward may result in some additional ops e.g. clone to make contiguous, trace_wrapped HOPs, so the graphs may be slightly offset from each other

hf_Whisper example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNv89Pu/index.html
fsdp2 example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpPdKssS/rank_0/index.html
Unit test example: https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpvoQsnl/index.html
```python
 ===== Compiled autograd graph =====
 <eval_with_key>.14 class CompiledAutograd(torch.nn.Module):
    def forward(self, inputs, sizes, scalars, hooks):
        # No stacktrace found for following nodes
        getitem: "f32[]cpu" = inputs[0]
        aot1_primals_1: "f32[4]cpu" = inputs[1]
        aot1_primals_2: "f32[4]cpu" = inputs[2]
        aot0_sin: "f32[4]cpu" = inputs[3]
        aot0_cos: "f32[4]cpu" = inputs[4]
        getitem_5: "f32[4]cpu" = inputs[5];  inputs = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: SumBackward0 (NodeCall 1)
        expand: "f32[4]cpu" = torch.ops.aten.expand.default(getitem, [4]);  getitem = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: CompiledFunctionBackward1 (NodeCall 2)
        aot1_tangents_1: "f32[4]cpu" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        aot1_sin_1: "f32[4]cpu" = torch.ops.aten.sin.default(aot1_primals_2);  aot1_primals_2 = None
        aot1_neg: "f32[4]cpu" = torch.ops.aten.neg.default(aot1_sin_1);  aot1_sin_1 = None
        aot0_tangents_2: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot1_tangents_1, aot1_neg);  aot1_neg = None
        aot1_cos_1: "f32[4]cpu" = torch.ops.aten.cos.default(aot1_primals_1);  aot1_primals_1 = None
        aot0_tangents_1: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot1_tangents_1, aot1_cos_1);  aot1_tangents_1 = aot1_cos_1 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: CompiledFunctionBackward0 (NodeCall 3)
        aot0_neg: "f32[4]cpu" = torch.ops.aten.neg.default(aot0_sin);  aot0_sin = None
        aot0_mul: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_neg);  aot0_tangents_2 = aot0_neg = None
        aot0_mul_1: "f32[4]cpu" = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_cos);  aot0_tangents_1 = aot0_cos = None
        aot0_add: "f32[4]cpu" = torch.ops.aten.add.Tensor(aot0_mul, aot0_mul_1);  aot0_mul = aot0_mul_1 = None

         # File: /data/users/xmfan/a/pytorch/torch/_dynamo/compiled_autograd.py:444 in set_node_origin, code: torch::autograd::AccumulateGrad (NodeCall 4)
        accumulate_grad_ = torch.ops.inductor.accumulate_grad_.default(getitem_5, aot0_add);  getitem_5 = aot0_add = accumulate_grad_ = None
        _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub();  _exec_final_callbacks_stub = None
        return []
```

where aot1 is
```python
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4][1]cpu", primals_2: "f32[4][1]cpu", tangents_1: "f32[4][1]cpu"):
         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2233 in torch_dynamo_resume_in_f_at_2232, code: return tmp1.sin() + tmp2.cos()
        sin_1: "f32[4][1]cpu" = torch.ops.aten.sin.default(primals_2);  primals_2 = None
        neg: "f32[4][1]cpu" = torch.ops.aten.neg.default(sin_1);  sin_1 = None
        mul: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, neg);  neg = None
        cos_1: "f32[4][1]cpu" = torch.ops.aten.cos.default(primals_1);  primals_1 = None
        mul_1: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, cos_1);  tangents_1 = cos_1 = None
        return (mul_1, mul)
```

and aot0 is
```python
class GraphModule(torch.nn.Module):
    def forward(self, sin: "f32[4][1]cpu", cos: "f32[4][1]cpu", tangents_1: "f32[4][1]cpu", tangents_2: "f32[4][1]cpu"):
         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2231 in f, code: tmp2 = x.cos()
        neg: "f32[4][1]cpu" = torch.ops.aten.neg.default(sin);  sin = None
        mul: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_2, neg);  tangents_2 = neg = None

         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2230 in f, code: tmp1 = x.sin()
        mul_1: "f32[4][1]cpu" = torch.ops.aten.mul.Tensor(tangents_1, cos);  tangents_1 = cos = None

         # File: /data/users/xmfan/a/pytorch/test/inductor/test_compiled_autograd.py:2230 in f, code: tmp1 = x.sin()
        add: "f32[4][1]cpu" = torch.ops.aten.add.Tensor(mul, mul_1);  mul = mul_1 = None
        return (add,)
```

Pull Request resolved: pytorch#133148
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#133115
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants