
Conversation

@sraikund16
Contributor

@sraikund16 sraikund16 commented Aug 6, 2024

Summary:
We want to add compile IDs and frames to each Torch-Compiled Region to help users cross-reference the section they are inspecting with data obtained from tools such as tlparse.
This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using the Python C API. After this, we can create a new interface in the cpp shim that wraps record_function so that we can pass in the new keyword argument for "context".

Test Plan:
Enhanced test_profiler_dynamo_compiled_region to check kwinputs as well as the event name, verifying that the context is now labeled. Also changed the test to run a graph with more contexts so that we exercise a wider range of profiling.
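
As a rough illustration of what the test checks (a minimal sketch, assuming a toy compiled function; the kwinputs["context"] access mirrors the repro later in this thread):

import torch

@torch.compile
def fn(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(8)

with torch.profiler.profile(record_shapes=True) as prof:
    fn(x)

# With this change, each "Torch-Compiled Region" event carries a "context"
# keyword input holding the frame id / compile id (e.g. "0/0").
print([e.kwinputs["context"] for e in prof.events() if "Compiled" in e.name])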

Differential Revision: D60803317

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec

@pytorch-bot

pytorch-bot bot commented Aug 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132765

Note: Links to docs will display an error until the docs builds have been completed.

❌ 31 New Failures

As of commit 474f822 with merge base a23dae2:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D60803317

@sraikund16 sraikund16 marked this pull request as draft August 6, 2024 17:09
@sraikund16 sraikund16 changed the title from "First attempt at adding source information context to traces" to "[PT2/Profiler] Add Context Info to Torch-Compiled Regions" Aug 8, 2024
@sraikund16 sraikund16 marked this pull request as ready for review August 8, 2024 16:50
@davidberard98
Contributor

cc @williamwen42 @yanboliang for the dynamo parts

Also, if you can get a local dynamo benchmark setup running and test the impact this has on the BERT_pytorch model, that would be great! If not, you can try the performance benchmark dashboard (https://hud.pytorch.org/benchmark/compilers). Either would be a good quick test to see whether there's any perf impact from this change.


…2765)

Summary:
We want to add compile IDs and frames to each Torch-Compiled Region to help users cross-reference the section they are inspecting with data obtained from tools such as tlparse.

This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using the Python C API. After this, we can create a new interface in the cpp shim that wraps record_function so that we can pass in the new keyword argument for "context".

Pull Request resolved: pytorch#132765

Test Plan: Enhanced test_profiler_dynamo_compiled_region to check kwinputs as well as the event name, verifying that the context is now labeled. Also changed the test to run a graph with more contexts so that we exercise a wider range of profiling.

Reviewed By: anijain2305

Differential Revision: D60803317

@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

pytorch-bot bot pushed a commit that referenced this pull request Sep 13, 2024
Summary:
We want to add compile IDs and frames to each Torch-Compiled Region to help users cross-reference the section they are inspecting with data obtained from tools such as tlparse.
This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using the Python C API. After this, we can create a new interface in the cpp shim that wraps record_function so that we can pass in the new keyword argument for "context".

Test Plan:
Enhanced test_profiler_dynamo_compiled_region to check kwinputs as well as the event name, verifying that the context is now labeled. Also changed the test to run a graph with more contexts so that we exercise a wider range of profiling.

Differential Revision: D60803317

Pull Request resolved: #132765
Approved by: https://github.com/anijain2305
finally:
    if context is not None:
        if context.compile_id is not None:
            set_context_frame(

(btw this feature is awesome 😄 )

Sidecar comment, maybe cc @anijain2305 - won't this function get called (and set the context) during compilation? It seems like we would want this state to be set to the "previously compiled frame" every time a given compiled object is invoked at runtime.


hmm yeah, repro here:

import torch

@torch.compile(backend="aot_eager", dynamic=True)
def f(x):
    if x.shape[0] > 5:
        return x.sin()
    else:
        return x.cos()

x1 = torch.randn(4)
x2 = torch.randn(6)
x3 = torch.randn(3)

with torch.profiler.profile(record_shapes=True) as prof:
    out1 = f(x1)
    out2 = f(x2)
    out3 = f(x3)

print([e.kwinputs['context'] for e in prof.events() if 'Compiled' in e.name])

this should print:

['0/0', '0/1', '0/0']

since we are re-running the first compiled object in the third case. But instead, I get:

['0/0', '0/1', '0/1']

I can file an issue


Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
…2765)

Summary:
We want to add compile IDs and frames to each Torch-Compiled Region to help users cross-reference the section they are inspecting with data obtained from tools such as tlparse.
This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using the Python C API. After this, we can create a new interface in the cpp shim that wraps record_function so that we can pass in the new keyword argument for "context".

Test Plan:
Enhanced test_profiler_dynamo_compiled_region to check kwinputs as well as the event name, verifying that the context is now labeled. Also changed the test to run a graph with more contexts so that we exercise a wider range of profiling.

Differential Revision: D60803317

Pull Request resolved: pytorch#132765
Approved by: https://github.com/anijain2305
PyObject* guard_error_hook = NULL;
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";

static char compile_context[MAX_COMPILE_CONTEXT_SIZE];

We should NOT be maintaining global state like this. Most obviously, it is not thread safe. But more conceptually, the compile id of a compiled object is associated with the compiled object itself; it's not something that should be set at runtime. When you compile some code, the compile product is what stores the compile id (0/0, whatever), and you should be pulling the compile id out of the compile product itself.
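
For illustration, a minimal Python sketch of the design being suggested (hypothetical CompiledProduct/run_compiled names, not dynamo's actual structures; the id is folded into the region name here, whereas this PR passes it as a keyword argument through the C++ shim):

import torch
from dataclasses import dataclass
from typing import Callable

@dataclass
class CompiledProduct:
    compile_id: str   # e.g. "0/0"; fixed when the graph is compiled
    fn: Callable      # the compiled callable

def run_compiled(product: CompiledProduct, *args):
    # The annotation comes from the compile product itself, so re-running a
    # previously compiled object reports its own id rather than whatever
    # frame happened to be compiled last (no mutable global involved).
    with torch.profiler.record_function(
        f"Torch-Compiled Region ({product.compile_id})"
    ):
        return product.fn(*args)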

{"unsupported", unsupported, METH_VARARGS, NULL},
{"skip_code", skip_code, METH_O, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
{"set_context_frame", set_context_frame, METH_O, NULL},

When you get rid of global state, you can delete this function entirely. Instead...

return;
}
kwinputs_ = *kwargs;
kwinputs_ = std::unordered_map<std::string, IValue>(*kwargs);

What's going on here?

PyCodeObject* code,
int throw_flag,
int free_vars_copied) {
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter("Torch-Compiled Region");

What I want here is a const char* with the name of the compile region passed in to this function, and then we use the old _pytorch_record_function_enter. Let me check the call sites to ensure you can do this...

@ezyang
Contributor

ezyang commented Sep 20, 2024

Concretely, I think we should revert this PR and redo it under a better design that doesn't have the problem @bdhirsh pointed out.

@ezyang
Contributor

ezyang commented Sep 20, 2024

@pytorchbot revert -c weird -m "implementation is not correct, needs full rewrite"

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@sraikund16 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Sep 20, 2024
zou3519 added a commit that referenced this pull request Sep 25, 2024
Revert "[PT2][Inductor][Optmus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)"

This reverts commit e184391.

Revert "Fix clang-tidy warnings in torch/csrc/lazy (#134655)"

This reverts commit 0287146.

Revert "Remove duplicate line (#136383)"

This reverts commit 0b91e7e.

Revert "[TF32] Account for TF32 in `test_conv_double_backward` (#135716)"

This reverts commit 29f7b8d.

Revert "Fix `Vectorized<double>::next_after` SVE compilation (#136388)"

This reverts commit 7936584.

Revert "Upgrade pybind11 API calls for 3.13t (#136370)"

This reverts commit 067d203.

Revert "[AOTI][Tooling] Filter out kernels based off lowercase names (#135395)"

This reverts commit 1a10751.

Revert "Add decomps for max_unpool (#133146)"

This reverts commit 0c936c3.

Revert "add TORCH_CUDA_CPP_API for AutoNcclGroup (#130012)"

This reverts commit 293fccf.

Revert "Use cpython declaration of _PyWeakref_ClearRef (#136300)"

This reverts commit d2455b9.

Revert "fix mypi in utils/_sympy/functions.py (#136339)"

This reverts commit 7f9c064.

Revert "[Inductor] Fix test_profiler_mark_wrapper_call_cuda_cuda_wrapper (#136356)"

This reverts commit f53a0f9.

Revert "Add more distributed examples (#130427)"

This reverts commit 5997354.

Revert "return instead of using skipTest (#136244)"

This reverts commit 29affa6.

Reapply "[PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765)"

This reverts commit 783c5ba.

Revert "Enable torch build with SLEEF on ARM by default (#133339)"

This reverts commit 4842f0f.

Revert "[inductor] Relax the conditions for loop split (#135335)"

This reverts commit 687e5cf.

[ghstack-poisoned]
zou3519 added a commit that referenced this pull request Sep 25, 2024
Revert "[PT2][Inductor][Optmus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)"

This reverts commit e184391.

Revert "Fix clang-tidy warnings in torch/csrc/lazy (#134655)"

This reverts commit 0287146.

Revert "Remove duplicate line (#136383)"

This reverts commit 0b91e7e.

Revert "[TF32] Account for TF32 in `test_conv_double_backward` (#135716)"

This reverts commit 29f7b8d.

Revert "Fix `Vectorized<double>::next_after` SVE compilation (#136388)"

This reverts commit 7936584.

Revert "Upgrade pybind11 API calls for 3.13t (#136370)"

This reverts commit 067d203.

Revert "[AOTI][Tooling] Filter out kernels based off lowercase names (#135395)"

This reverts commit 1a10751.

Revert "Add decomps for max_unpool (#133146)"

This reverts commit 0c936c3.

Revert "add TORCH_CUDA_CPP_API for AutoNcclGroup (#130012)"

This reverts commit 293fccf.

Revert "Use cpython declaration of _PyWeakref_ClearRef (#136300)"

This reverts commit d2455b9.

Revert "fix mypi in utils/_sympy/functions.py (#136339)"

This reverts commit 7f9c064.

Revert "[Inductor] Fix test_profiler_mark_wrapper_call_cuda_cuda_wrapper (#136356)"

This reverts commit f53a0f9.

Revert "Add more distributed examples (#130427)"

This reverts commit 5997354.

Revert "return instead of using skipTest (#136244)"

This reverts commit 29affa6.

Reapply "[PT2/Profiler] Add Context Info to Torch-Compiled Regions (#132765)"

This reverts commit 783c5ba.

Revert "Enable torch build with SLEEF on ARM by default (#133339)"

This reverts commit 4842f0f.

Revert "[inductor] Relax the conditions for loop split (#135335)"

This reverts commit 687e5cf.

ghstack-source-id: b0fb91e
Pull Request resolved: #136668
@sraikund16 sraikund16 closed this Oct 4, 2024